:ticket:`4926`
+.. _change_7433:
+
+Session raises proactively when illegal concurrent or reentrant access is detected
+----------------------------------------------------------------------------------
+
+The :class:`_orm.Session` can now trap more errors related to illegal concurrent
+state changes within multithreaded or other concurrent scenarios as well as for
+event hooks which perform unexpected state changes.
+
+One error that's been known to occur when a :class:`_orm.Session` is used in
+multiple threads simultaneously is
+``AttributeError: 'NoneType' object has no attribute 'twophase'``, which is
+completely cryptic. This error occurs when a thread calls
+:meth:`_orm.Session.commit` which internally invokes the
+:meth:`_orm.SessionTransaction.close` method to end the transactional context,
+at the same time that another thread is in progress running a query
+as from :meth:`_orm.Session.execute`. Within :meth:`_orm.Session.execute`,
+the internal method that acquires a database connection for the current
+transaction first begins by asserting that the session is "active", but
+after this assertion passes, the concurrent call to :meth:`_orm.Session.close`
+interferes with this state which leads to the undefined condition above.
+
+The change applies guards to all state-changing methods surrounding the
+:class:`_orm.SessionTransaction` object so that in the above case, the
+:meth:`_orm.Session.commit` method will instead fail as it will seek to change
+the state to one that is disallowed for the duration of the already-in-progress
+method that wants to get the current connection to run a database query.
+
+Using the test script illustrated at :ticket:`7433`, the previous
+error case looks like::
+
+ Traceback (most recent call last):
+ File "/home/classic/dev/sqlalchemy/test3.py", line 30, in worker
+ sess.execute(select(A)).all()
+ File "/home/classic/tmp/sqlalchemy/lib/sqlalchemy/orm/session.py", line 1691, in execute
+ conn = self._connection_for_bind(bind)
+ File "/home/classic/tmp/sqlalchemy/lib/sqlalchemy/orm/session.py", line 1532, in _connection_for_bind
+ return self._transaction._connection_for_bind(
+ File "/home/classic/tmp/sqlalchemy/lib/sqlalchemy/orm/session.py", line 754, in _connection_for_bind
+ if self.session.twophase and self._parent is None:
+ AttributeError: 'NoneType' object has no attribute 'twophase'
+
+Where the ``_connection_for_bind()`` method isn't able to continue since
+concurrent access placed it into an invalid state. Using the new approach, the
+originator of the state change throws the error instead::
+
+ File "/home/classic/dev/sqlalchemy/lib/sqlalchemy/orm/session.py", line 1785, in close
+ self._close_impl(invalidate=False)
+ File "/home/classic/dev/sqlalchemy/lib/sqlalchemy/orm/session.py", line 1827, in _close_impl
+ transaction.close(invalidate)
+ File "<string>", line 2, in close
+ File "/home/classic/dev/sqlalchemy/lib/sqlalchemy/orm/session.py", line 506, in _go
+ raise sa_exc.InvalidRequestError(
+ sqlalchemy.exc.InvalidRequestError: Method 'close()' can't be called here;
+ method '_connection_for_bind()' is already in progress and this would cause
+ an unexpected state change to symbol('CLOSED')
+
+The state transition checks intentionally don't use explicit locks to detect
+concurrent thread activity, instead relying upon simple attribute set / value
+test operations that inherently fail when unexpected concurrent changes occur.
+The rationale is that the approach can detect illegal state changes that occur
+entirely within a single thread, such as an event handler that runs on session
+transaction events calls a state-changing method that's not expected, or under
+asyncio if a particular :class:`_orm.Session` were shared among multiple
+asyncio tasks, as well as when using patching-style concurrency approaches
+such as gevent.
+
+:ticket:`7433`
+
.. _migration_20_overview:
--- /dev/null
+.. change::
+ :tags: feature, orm
+ :tickets: 7433
+
+ The :class:`_orm.Session` (and by extension :class:`.AsyncSession`) now has
+ new state-tracking functionality that will proactively trap any unexpected
+ state changes which occur as a particular transactional method proceeds.
+ This is to allow situations where the :class:`_orm.Session` is being used
+ in a thread-unsafe manner, where event hooks or similar may be calling
+ unexpected methods within operations, as well as potentially under other
+ concurrency situations such as asyncio or gevent to raise an informative
+ message when the illegal access first occurs, rather than passing silently
+ leading to secondary failures due to the :class:`_orm.Session` being in an
+ invalid state.
+
+ .. seealso::
+
+ :ref:`change_7433`
\ No newline at end of file
"""
+class IllegalStateChangeError(InvalidRequestError):
+ """An object that tracks state encountered an illegal state change
+ of some kind.
+
+ .. versionadded:: 2.0
+
+ """
+
+
class NoInspectionAvailable(InvalidRequestError):
"""A subject passed to :func:`sqlalchemy.inspection.inspect` produced
no context for inspection."""
from .base import object_mapper
from .base import object_state
from .base import state_str
+from .state_changes import _StateChange
+from .state_changes import _StateChangeState
+from .state_changes import _StateChangeStates
from .unitofwork import UOWTransaction
from .. import engine
from .. import exc as sa_exc
return object_session(instance)
-ACTIVE = util.symbol("ACTIVE")
-PREPARED = util.symbol("PREPARED")
-COMMITTED = util.symbol("COMMITTED")
-DEACTIVE = util.symbol("DEACTIVE")
-CLOSED = util.symbol("CLOSED")
+class SessionTransactionState(_StateChangeState):
+ ACTIVE = 1
+ PREPARED = 2
+ COMMITTED = 3
+ DEACTIVE = 4
+ CLOSED = 5
+
+
+# backwards compatibility
+ACTIVE, PREPARED, COMMITTED, DEACTIVE, CLOSED = tuple(SessionTransactionState)
class ORMExecuteState(util.MemoizedSlots):
]
-class SessionTransaction(TransactionalContext):
+class SessionTransaction(_StateChange, TransactionalContext):
"""A :class:`.Session`-level transaction.
:class:`.SessionTransaction` is produced from the
self.nested = nested
if nested:
self._previous_nested_transaction = session._nested_transaction
- self._state = ACTIVE
+ self._state = SessionTransactionState.ACTIVE
if not parent and nested:
raise sa_exc.InvalidRequestError(
"Can't start a SAVEPOINT transaction when no existing "
self.session.dispatch.after_transaction_create(self.session, self)
+ def _raise_for_prerequisite_state(self, operation_name, state):
+ if state is SessionTransactionState.DEACTIVE:
+ if self._rollback_exception:
+ raise sa_exc.PendingRollbackError(
+ "This Session's transaction has been rolled back "
+ "due to a previous exception during flush."
+ " To begin a new transaction with this Session, "
+ "first issue Session.rollback()."
+ f" Original exception was: {self._rollback_exception}",
+ code="7s2a",
+ )
+ else:
+ raise sa_exc.InvalidRequestError(
+ "This session is in 'inactive' state, due to the "
+ "SQL transaction being rolled back; no further SQL "
+ "can be emitted within this transaction."
+ )
+ elif state is SessionTransactionState.CLOSED:
+ raise sa_exc.ResourceClosedError("This transaction is closed")
+ else:
+ raise sa_exc.InvalidRequestError(
+ f"This session is in '{state.name.lower()}' state; no "
+ "further SQL can be emitted within this transaction."
+ )
+
@property
def parent(self):
"""The parent :class:`.SessionTransaction` of this
@property
def is_active(self):
- return self.session is not None and self._state is ACTIVE
-
- def _assert_active(
- self,
- prepared_ok=False,
- rollback_ok=False,
- deactive_ok=False,
- closed_msg="This transaction is closed",
- ):
- if self._state is COMMITTED:
- raise sa_exc.InvalidRequestError(
- "This session is in 'committed' state; no further "
- "SQL can be emitted within this transaction."
- )
- elif self._state is PREPARED:
- if not prepared_ok:
- raise sa_exc.InvalidRequestError(
- "This session is in 'prepared' state; no further "
- "SQL can be emitted within this transaction."
- )
- elif self._state is DEACTIVE:
- if not deactive_ok and not rollback_ok:
- if self._rollback_exception:
- raise sa_exc.PendingRollbackError(
- "This Session's transaction has been rolled back "
- "due to a previous exception during flush."
- " To begin a new transaction with this Session, "
- "first issue Session.rollback()."
- " Original exception was: %s"
- % self._rollback_exception,
- code="7s2a",
- )
- elif not deactive_ok:
- raise sa_exc.InvalidRequestError(
- "This session is in 'inactive' state, due to the "
- "SQL transaction being rolled back; no further "
- "SQL can be emitted within this transaction."
- )
- elif self._state is CLOSED:
- raise sa_exc.ResourceClosedError(closed_msg)
+ return (
+ self.session is not None
+ and self._state is SessionTransactionState.ACTIVE
+ )
@property
def _is_transaction_boundary(self):
return self.nested or not self._parent
+ @_StateChange.declare_states(
+ (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE
+ )
def connection(self, bindkey, execution_options=None, **kwargs):
- self._assert_active()
bind = self.session.get_bind(bindkey, **kwargs)
return self._connection_for_bind(bind, execution_options)
+ @_StateChange.declare_states(
+ (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE
+ )
def _begin(self, nested=False):
- self._assert_active()
return SessionTransaction(self.session, self, nested=nested)
def _iterate_self_and_parents(self, upto=None):
self._parent._deleted.update(self._deleted)
self._parent._key_switches.update(self._key_switches)
+ @_StateChange.declare_states(
+ (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE
+ )
def _connection_for_bind(self, bind, execution_options):
- self._assert_active()
if bind in self._connections:
if execution_options:
)
self._prepare_impl()
+ @_StateChange.declare_states(
+ (SessionTransactionState.ACTIVE,), SessionTransactionState.PREPARED
+ )
def _prepare_impl(self):
- self._assert_active()
+
if self._parent is None or self.nested:
self.session.dispatch.before_commit(self.session)
with util.safe_reraise():
self.rollback()
- self._state = PREPARED
+ self._state = SessionTransactionState.PREPARED
+ @_StateChange.declare_states(
+ (SessionTransactionState.ACTIVE, SessionTransactionState.PREPARED),
+ SessionTransactionState.CLOSED,
+ )
def commit(self, _to_root=False):
- self._assert_active(prepared_ok=True)
- if self._state is not PREPARED:
- self._prepare_impl()
+ if self._state is not SessionTransactionState.PREPARED:
+ with self._expect_state(SessionTransactionState.PREPARED):
+ self._prepare_impl()
if self._parent is None or self.nested:
for conn, trans, should_commit, autoclose in set(
if should_commit:
trans.commit()
- self._state = COMMITTED
+ self._state = SessionTransactionState.COMMITTED
self.session.dispatch.after_commit(self.session)
self._remove_snapshot()
- self.close()
+ with self._expect_state(SessionTransactionState.CLOSED):
+ self.close()
if _to_root and self._parent:
return self._parent.commit(_to_root=True)
return self._parent
+ @_StateChange.declare_states(
+ (
+ SessionTransactionState.ACTIVE,
+ SessionTransactionState.DEACTIVE,
+ SessionTransactionState.PREPARED,
+ ),
+ SessionTransactionState.CLOSED,
+ )
def rollback(self, _capture_exception=False, _to_root=False):
- self._assert_active(prepared_ok=True, rollback_ok=True)
stx = self.session._transaction
if stx is not self:
boundary = self
rollback_err = None
- if self._state in (ACTIVE, PREPARED):
+ if self._state in (
+ SessionTransactionState.ACTIVE,
+ SessionTransactionState.PREPARED,
+ ):
for transaction in self._iterate_self_and_parents():
if transaction._parent is None or transaction.nested:
try:
for t in set(transaction._connections.values()):
t[1].rollback()
- transaction._state = DEACTIVE
+ transaction._state = SessionTransactionState.DEACTIVE
self.session.dispatch.after_rollback(self.session)
except:
rollback_err = sys.exc_info()
finally:
- transaction._state = DEACTIVE
+ transaction._state = SessionTransactionState.DEACTIVE
transaction._restore_snapshot(
dirty_only=transaction.nested
)
boundary = transaction
break
else:
- transaction._state = DEACTIVE
+ transaction._state = SessionTransactionState.DEACTIVE
sess = self.session
)
boundary._restore_snapshot(dirty_only=boundary.nested)
- self.close()
+ with self._expect_state(SessionTransactionState.CLOSED):
+ self.close()
if self._parent and _capture_exception:
self._parent._rollback_exception = sys.exc_info()[1]
return self._parent.rollback(_to_root=True)
return self._parent
+ @_StateChange.declare_states(
+ _StateChangeStates.ANY, SessionTransactionState.CLOSED
+ )
def close(self, invalidate=False):
if self.nested:
self.session._nested_transaction = (
if autoclose:
connection.close()
- self._state = CLOSED
- self.session.dispatch.after_transaction_end(self.session, self)
+ self._state = SessionTransactionState.CLOSED
+ sess = self.session
self.session = None
self._connections = None
+ sess.dispatch.after_transaction_end(sess, self)
+
def _get_subject(self):
return self.session
def _transaction_is_active(self):
- return self._state is ACTIVE
+ return self._state is SessionTransactionState.ACTIVE
def _transaction_is_closed(self):
- return self._state is CLOSED
+ return self._state is SessionTransactionState.CLOSED
def _rollback_can_be_called(self):
return self._state not in (COMMITTED, CLOSED)
--- /dev/null
+# orm/state_changes.py
+# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""State tracking utilities used by :class:`_orm.Session`.
+
+"""
+
+import contextlib
+from enum import Enum
+from typing import Any
+from typing import Callable
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+from .. import exc as sa_exc
+from .. import util
+from ..util.typing import Literal
+
+
+class _StateChangeState(Enum):
+ pass
+
+
+class _StateChangeStates(_StateChangeState):
+ ANY = 1
+ NO_CHANGE = 2
+ CHANGE_IN_PROGRESS = 3
+
+
+class _StateChange:
+ """Supplies state assertion decorators.
+
+ The current use case is for the :class:`_orm.SessionTransaction` class. The
+ :class:`_StateChange` class itself is agnostic of the
+ :class:`_orm.SessionTransaction` class so could in theory be generalized
+ for other systems as well.
+
+ """
+
+ _next_state: _StateChangeState = _StateChangeStates.ANY
+ _state: _StateChangeState = _StateChangeStates.NO_CHANGE
+ _current_fn: Optional[Callable] = None
+
+ def _raise_for_prerequisite_state(self, operation_name, state):
+ raise sa_exc.IllegalStateChangeError(
+ f"Can't run operation '{operation_name}()' when Session "
+ f"is in state {state!r}"
+ )
+
+ @classmethod
+ def declare_states(
+ cls,
+ prerequisite_states: Union[
+ Literal[_StateChangeStates.ANY], Tuple[_StateChangeState, ...]
+ ],
+ moves_to: _StateChangeState,
+ ) -> Callable[..., Any]:
+ """Method decorator declaring valid states.
+
+ :param prerequisite_states: sequence of acceptable prerequisite
+ states. Can be the single constant _State.ANY to indicate no
+ prerequisite state
+
+ :param moves_to: the expected state at the end of the method, assuming
+ no exceptions raised. Can be the constant _State.NO_CHANGE to
+ indicate state should not change at the end of the method.
+
+ """
+ assert prerequisite_states, "no prequisite states sent"
+ has_prerequisite_states = (
+ prerequisite_states is not _StateChangeStates.ANY
+ )
+
+ expect_state_change = moves_to is not _StateChangeStates.NO_CHANGE
+
+ @util.decorator
+ def _go(fn, self, *arg, **kw):
+
+ current_state = self._state
+
+ if (
+ has_prerequisite_states
+ and current_state not in prerequisite_states
+ ):
+ self._raise_for_prerequisite_state(fn.__name__, current_state)
+
+ next_state = self._next_state
+ existing_fn = self._current_fn
+ expect_state = moves_to if expect_state_change else current_state
+
+ if (
+ # destination states are restricted
+ next_state is not _StateChangeStates.ANY
+ # method seeks to change state
+ and expect_state_change
+ # destination state incorrect
+ and next_state is not expect_state
+ ):
+ if existing_fn and next_state in (
+ _StateChangeStates.NO_CHANGE,
+ _StateChangeStates.CHANGE_IN_PROGRESS,
+ ):
+ raise sa_exc.IllegalStateChangeError(
+ f"Method '{fn.__name__}()' can't be called here; "
+ f"method '{existing_fn.__name__}()' is already "
+ f"in progress and this would cause an unexpected "
+ f"state change to {moves_to!r}"
+ )
+ else:
+ raise sa_exc.IllegalStateChangeError(
+ f"Cant run operation '{fn.__name__}()' here; "
+ f"will move to state {moves_to!r} where we are "
+ f"expecting {next_state!r}"
+ )
+
+ self._current_fn = fn
+ self._next_state = _StateChangeStates.CHANGE_IN_PROGRESS
+ try:
+ ret_value = fn(self, *arg, **kw)
+ except:
+ raise
+ else:
+ if self._state is expect_state:
+ return ret_value
+
+ if self._state is current_state:
+ raise sa_exc.IllegalStateChangeError(
+ f"Method '{fn.__name__}()' failed to "
+ "change state "
+ f"to {moves_to!r} as expected"
+ )
+ elif existing_fn:
+ raise sa_exc.IllegalStateChangeError(
+ f"While method '{existing_fn.__name__}()' was "
+ "running, "
+ f"method '{fn.__name__}()' caused an "
+ "unexpected "
+ f"state change to {self._state!r}"
+ )
+ else:
+ raise sa_exc.IllegalStateChangeError(
+ f"Method '{fn.__name__}()' caused an unexpected "
+ f"state change to {self._state!r}"
+ )
+
+ finally:
+ self._next_state = next_state
+ self._current_fn = existing_fn
+
+ return _go
+
+ @contextlib.contextmanager
+ def _expect_state(self, expected: _StateChangeState):
+ """called within a method that changes states.
+
+ method must also use the ``@declare_states()`` decorator.
+
+ """
+ assert self._next_state is _StateChangeStates.CHANGE_IN_PROGRESS, (
+ "Unexpected call to _expect_state outside of "
+ "state-changing method"
+ )
+
+ self._next_state = expected
+ try:
+ yield
+ except:
+ raise
+ else:
+ if self._state is not expected:
+ raise sa_exc.IllegalStateChangeError(
+ f"Unexpected state change to {self._state!r}"
+ )
+ finally:
+ self._next_state = _StateChangeStates.CHANGE_IN_PROGRESS
sa_exceptions.InvalidatePoolError,
sa_exceptions.TimeoutError,
sa_exceptions.InvalidRequestError,
+ sa_exceptions.IllegalStateChangeError,
sa_exceptions.NoInspectionAvailable,
sa_exceptions.PendingRollbackError,
sa_exceptions.ResourceClosedError,
--- /dev/null
+from sqlalchemy import exc as sa_exc
+from sqlalchemy.orm import state_changes
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
+from sqlalchemy.testing import fixtures
+
+
+class StateTestChange(state_changes._StateChangeState):
+ a = 1
+ b = 2
+ c = 3
+
+
+class StateMachineTest(fixtures.TestBase):
+ def test_single_change(self):
+ """test single method that declares and invokes a state change"""
+ _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE
+
+ class Machine(state_changes._StateChange):
+ @state_changes._StateChange.declare_states(
+ (StateTestChange.a, _NO_CHANGE), StateTestChange.b
+ )
+ def move_to_b(self):
+ self._state = StateTestChange.b
+
+ m = Machine()
+ eq_(m._state, _NO_CHANGE)
+ m.move_to_b()
+ eq_(m._state, StateTestChange.b)
+
+ def test_single_incorrect_change(self):
+ """test single method that declares a state change but changes to the
+ wrong state."""
+ _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE
+
+ class Machine(state_changes._StateChange):
+ @state_changes._StateChange.declare_states(
+ (StateTestChange.a, _NO_CHANGE), StateTestChange.b
+ )
+ def move_to_b(self):
+ self._state = StateTestChange.c
+
+ m = Machine()
+ eq_(m._state, _NO_CHANGE)
+ with expect_raises_message(
+ sa_exc.IllegalStateChangeError,
+ r"Method 'move_to_b\(\)' "
+ r"caused an unexpected state change to <StateTestChange.c: 3>",
+ ):
+ m.move_to_b()
+
+ def test_single_failed_to_change(self):
+ """test single method that declares a state change but didn't do
+ the change."""
+ _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE
+
+ class Machine(state_changes._StateChange):
+ @state_changes._StateChange.declare_states(
+ (StateTestChange.a, _NO_CHANGE), StateTestChange.b
+ )
+ def move_to_b(self):
+ pass
+
+ m = Machine()
+ eq_(m._state, _NO_CHANGE)
+ with expect_raises_message(
+ sa_exc.IllegalStateChangeError,
+ r"Method 'move_to_b\(\)' failed to change state "
+ "to <StateTestChange.b: 2> as "
+ "expected",
+ ):
+ m.move_to_b()
+
+ def test_change_from_sub_method_with_declaration(self):
+ """test successful state change by one method calling another that
+ does the change.
+
+ """
+ _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE
+
+ class Machine(state_changes._StateChange):
+ @state_changes._StateChange.declare_states(
+ (StateTestChange.a, _NO_CHANGE), StateTestChange.b
+ )
+ def _inner_move_to_b(self):
+ self._state = StateTestChange.b
+
+ @state_changes._StateChange.declare_states(
+ (StateTestChange.a, _NO_CHANGE), StateTestChange.b
+ )
+ def move_to_b(self):
+ with self._expect_state(StateTestChange.b):
+ self._inner_move_to_b()
+
+ m = Machine()
+ eq_(m._state, _NO_CHANGE)
+ m.move_to_b()
+ eq_(m._state, StateTestChange.b)
+
+ def test_method_and_sub_method_no_change(self):
+ """test methods that declare the state should not change"""
+ _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE
+
+ class Machine(state_changes._StateChange):
+ @state_changes._StateChange.declare_states(
+ (StateTestChange.a,), _NO_CHANGE
+ )
+ def _inner_do_nothing(self):
+ pass
+
+ @state_changes._StateChange.declare_states(
+ (StateTestChange.a,), _NO_CHANGE
+ )
+ def do_nothing(self):
+ self._inner_do_nothing()
+
+ m = Machine()
+ eq_(m._state, _NO_CHANGE)
+ m._state = StateTestChange.a
+ m.do_nothing()
+ eq_(m._state, StateTestChange.a)
+
+ def test_method_w_no_change_illegal_inner_change(self):
+ _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE
+
+ class Machine(state_changes._StateChange):
+ @state_changes._StateChange.declare_states(
+ (StateTestChange.a, _NO_CHANGE), StateTestChange.c
+ )
+ def _inner_move_to_c(self):
+ self._state = StateTestChange.c
+
+ @state_changes._StateChange.declare_states(
+ (StateTestChange.a,), _NO_CHANGE
+ )
+ def do_nothing(self):
+ self._inner_move_to_c()
+
+ m = Machine()
+ eq_(m._state, _NO_CHANGE)
+ m._state = StateTestChange.a
+
+ with expect_raises_message(
+ sa_exc.IllegalStateChangeError,
+ r"Method '_inner_move_to_c\(\)' can't be called here; "
+ r"method 'do_nothing\(\)' is already in progress and this "
+ r"would cause an unexpected state change to "
+ "<StateTestChange.c: 3>",
+ ):
+ m.do_nothing()
+ eq_(m._state, StateTestChange.a)
+
+ def test_change_from_method_sub_w_no_change(self):
+ """test methods that declare the state should not change"""
+ _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE
+
+ class Machine(state_changes._StateChange):
+ @state_changes._StateChange.declare_states(
+ (StateTestChange.a,), _NO_CHANGE
+ )
+ def _inner_do_nothing(self):
+ pass
+
+ @state_changes._StateChange.declare_states(
+ (StateTestChange.a,), StateTestChange.b
+ )
+ def move_to_b(self):
+ self._inner_do_nothing()
+ self._state = StateTestChange.b
+
+ m = Machine()
+ eq_(m._state, _NO_CHANGE)
+ m._state = StateTestChange.a
+ m.move_to_b()
+ eq_(m._state, StateTestChange.b)
+
+ def test_invalid_change_from_declared_sub_method_with_declaration(self):
+ """A method uses _expect_state() to call a sub-method, which must
+ declare that state as its destination if no exceptions are raised.
+
+ """
+ _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE
+
+ class Machine(state_changes._StateChange):
+ # method declares StateTestChange.c so can't be called under
+ # expect_state(StateTestChange.b)
+ @state_changes._StateChange.declare_states(
+ (StateTestChange.a, _NO_CHANGE), StateTestChange.c
+ )
+ def _inner_move_to_c(self):
+ self._state = StateTestChange.c
+
+ @state_changes._StateChange.declare_states(
+ (StateTestChange.a, _NO_CHANGE), StateTestChange.b
+ )
+ def move_to_b(self):
+ with self._expect_state(StateTestChange.b):
+ self._inner_move_to_c()
+
+ m = Machine()
+ eq_(m._state, _NO_CHANGE)
+ with expect_raises_message(
+ sa_exc.IllegalStateChangeError,
+ r"Cant run operation '_inner_move_to_c\(\)' here; will move "
+ r"to state <StateTestChange.c: 3> where we are "
+ "expecting <StateTestChange.b: 2>",
+ ):
+ m.move_to_b()
+
+ def test_invalid_change_from_invalid_sub_method_with_declaration(self):
+ """A method uses _expect_state() to call a sub-method, which must
+ declare that state as its destination if no exceptions are raised.
+
+ Test an error is raised if the sub-method doesn't change to the
+ correct state.
+
+ """
+ _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE
+
+ class Machine(state_changes._StateChange):
+ # method declares StateTestChange.b, but is doing the wrong
+ # change, so should fail under expect_state(StateTestChange.b)
+ @state_changes._StateChange.declare_states(
+ (StateTestChange.a, _NO_CHANGE), StateTestChange.b
+ )
+ def _inner_move_to_c(self):
+ self._state = StateTestChange.c
+
+ @state_changes._StateChange.declare_states(
+ (StateTestChange.a, _NO_CHANGE), StateTestChange.b
+ )
+ def move_to_b(self):
+ with self._expect_state(StateTestChange.b):
+ self._inner_move_to_c()
+
+ m = Machine()
+ eq_(m._state, _NO_CHANGE)
+ with expect_raises_message(
+ sa_exc.IllegalStateChangeError,
+ r"While method 'move_to_b\(\)' was running, method "
+ r"'_inner_move_to_c\(\)' caused an unexpected state change "
+ "to <StateTestChange.c: 3>",
+ ):
+ m.move_to_b()
+
+ def test_invalid_prereq_state(self):
+ _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE
+
+ class Machine(state_changes._StateChange):
+ @state_changes._StateChange.declare_states(
+ (StateTestChange.a, _NO_CHANGE), StateTestChange.b
+ )
+ def move_to_b(self):
+ self._state = StateTestChange.b
+
+ @state_changes._StateChange.declare_states(
+ (StateTestChange.c,), "d"
+ )
+ def move_to_d(self):
+ self._state = "d"
+
+ m = Machine()
+ eq_(m._state, _NO_CHANGE)
+ m.move_to_b()
+ eq_(m._state, StateTestChange.b)
+ with expect_raises_message(
+ sa_exc.IllegalStateChangeError,
+ r"Can't run operation 'move_to_d\(\)' when "
+ "Session is in state <StateTestChange.b: 2>",
+ ):
+ m.move_to_d()
+
+ def test_declare_only(self):
+ _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE
+
+ class Machine(state_changes._StateChange):
+ @state_changes._StateChange.declare_states(
+ state_changes._StateChangeStates.ANY, StateTestChange.b
+ )
+ def _inner_move_to_b(self):
+ self._state = StateTestChange.b
+
+ def move_to_b(self):
+ with self._expect_state(StateTestChange.b):
+ self._move_to_b()
+
+ m = Machine()
+ eq_(m._state, _NO_CHANGE)
+ with expect_raises_message(
+ AssertionError,
+ "Unexpected call to _expect_state outside of "
+ "state-changing method",
+ ):
+ m.move_to_b()
+
+ def test_sibling_calls_maintain_correct_state(self):
+ _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE
+
+ class Machine(state_changes._StateChange):
+ @state_changes._StateChange.declare_states(
+ state_changes._StateChangeStates.ANY, StateTestChange.c
+ )
+ def move_to_c(self):
+ self._state = StateTestChange.c
+
+ @state_changes._StateChange.declare_states(
+ state_changes._StateChangeStates.ANY, _NO_CHANGE
+ )
+ def do_nothing(self):
+ pass
+
+ m = Machine()
+ m.do_nothing()
+ eq_(m._state, _NO_CHANGE)
+ m.move_to_c()
+ eq_(m._state, StateTestChange.c)
+
+ def test_change_from_sub_method_requires_declaration(self):
+ """A method can't call another state-changing method without using
+ _expect_state() to allow the state change to occur.
+
+ """
+ _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE
+
+ class Machine(state_changes._StateChange):
+ @state_changes._StateChange.declare_states(
+ (StateTestChange.a, _NO_CHANGE), StateTestChange.b
+ )
+ def _inner_move_to_b(self):
+ self._state = StateTestChange.b
+
+ @state_changes._StateChange.declare_states(
+ (StateTestChange.a, _NO_CHANGE), StateTestChange.b
+ )
+ def move_to_b(self):
+ self._inner_move_to_b()
+
+ m = Machine()
+
+ with expect_raises_message(
+ sa_exc.IllegalStateChangeError,
+ r"Method '_inner_move_to_b\(\)' can't be called here; "
+ r"method 'move_to_b\(\)' is already in progress and this would "
+ r"cause an unexpected state change to <StateTestChange.b: 2>",
+ ):
+ m.move_to_b()