From: Mike Bayer Date: Thu, 2 Dec 2021 14:18:11 +0000 (-0500) Subject: propose concurrency check for SessionTransaction X-Git-Tag: rel_2_0_0b1~575^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=cc46ea711df77540d5d658e9c7b3ab1e88288929;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git propose concurrency check for SessionTransaction the discussion at #7387 refers to a condition that seems to happen in the wild also, such as [1] [2] [3], it's not entirely clear why this specific spot is how this occurs, however it's maybe that when the connection is being acquired from the pool, under load there might be a wait on the connection pool, leading to more time for another errant thread to be calling .close(), just a theory. in this patch we propose using decorators and context managers along with declarative state declarations to block reentrant or concurrent calls to methods that conflict with expected state changes. The :class:`_orm.Session` (and by extension :class:`.AsyncSession`) now has new state-tracking functionality that will proactively trap any unexpected state changes which occur as a particular transactional method proceeds. This is to allow situations where the :class:`_orm.Session` is being used in a thread-unsafe manner, where event hooks or similar may be calling unexpected methods within operations, as well as potentially under other concurrency situations such as asyncio or gevent to raise an informative message when the illegal access first occurs, rather than passing silently leading to secondary failures due to the :class:`_orm.Session` being in an invalid state. [1] https://stackoverflow.com/questions/25768428/sqlalchemy-connection-errors [2] https://groups.google.com/g/sqlalchemy/c/n5oVX3v4WOw [3] https://github.com/cosmicpython/code/issues/23 Fixes: #7433 Change-Id: I699b935c0ec4e5a63f12cf878af6f7a92a30a3aa --- diff --git a/doc/build/changelog/migration_20.rst b/doc/build/changelog/migration_20.rst index c7ab7f0752..72530142e8 100644 --- a/doc/build/changelog/migration_20.rst +++ b/doc/build/changelog/migration_20.rst @@ -211,6 +211,75 @@ in order to provide backend-agnostic floor division. :ticket:`4926` +.. _change_7433: + +Session raises proactively when illegal concurrent or reentrant access is detected +---------------------------------------------------------------------------------- + +The :class:`_orm.Session` can now trap more errors related to illegal concurrent +state changes within multithreaded or other concurrent scenarios as well as for +event hooks which perform unexpected state changes. + +One error that's been known to occur when a :class:`_orm.Session` is used in +multiple threads simultaneously is +``AttributeError: 'NoneType' object has no attribute 'twophase'``, which is +completely cryptic. This error occurs when a thread calls +:meth:`_orm.Session.commit` which internally invokes the +:meth:`_orm.SessionTransaction.close` method to end the transactional context, +at the same time that another thread is in progress running a query +as from :meth:`_orm.Session.execute`. Within :meth:`_orm.Session.execute`, +the internal method that acquires a database connection for the current +transaction first begins by asserting that the session is "active", but +after this assertion passes, the concurrent call to :meth:`_orm.Session.close` +interferes with this state which leads to the undefined condition above. + +The change applies guards to all state-changing methods surrounding the +:class:`_orm.SessionTransaction` object so that in the above case, the +:meth:`_orm.Session.commit` method will instead fail as it will seek to change +the state to one that is disallowed for the duration of the already-in-progress +method that wants to get the current connection to run a database query. + +Using the test script illustrated at :ticket:`7433`, the previous +error case looks like:: + + Traceback (most recent call last): + File "/home/classic/dev/sqlalchemy/test3.py", line 30, in worker + sess.execute(select(A)).all() + File "/home/classic/tmp/sqlalchemy/lib/sqlalchemy/orm/session.py", line 1691, in execute + conn = self._connection_for_bind(bind) + File "/home/classic/tmp/sqlalchemy/lib/sqlalchemy/orm/session.py", line 1532, in _connection_for_bind + return self._transaction._connection_for_bind( + File "/home/classic/tmp/sqlalchemy/lib/sqlalchemy/orm/session.py", line 754, in _connection_for_bind + if self.session.twophase and self._parent is None: + AttributeError: 'NoneType' object has no attribute 'twophase' + +Where the ``_connection_for_bind()`` method isn't able to continue since +concurrent access placed it into an invalid state. Using the new approach, the +originator of the state change throws the error instead:: + + File "/home/classic/dev/sqlalchemy/lib/sqlalchemy/orm/session.py", line 1785, in close + self._close_impl(invalidate=False) + File "/home/classic/dev/sqlalchemy/lib/sqlalchemy/orm/session.py", line 1827, in _close_impl + transaction.close(invalidate) + File "", line 2, in close + File "/home/classic/dev/sqlalchemy/lib/sqlalchemy/orm/session.py", line 506, in _go + raise sa_exc.InvalidRequestError( + sqlalchemy.exc.InvalidRequestError: Method 'close()' can't be called here; + method '_connection_for_bind()' is already in progress and this would cause + an unexpected state change to symbol('CLOSED') + +The state transition checks intentionally don't use explicit locks to detect +concurrent thread activity, instead relying upon simple attribute set / value +test operations that inherently fail when unexpected concurrent changes occur. +The rationale is that the approach can detect illegal state changes that occur +entirely within a single thread, such as an event handler that runs on session +transaction events calls a state-changing method that's not expected, or under +asyncio if a particular :class:`_orm.Session` were shared among multiple +asyncio tasks, as well as when using patching-style concurrency approaches +such as gevent. + +:ticket:`7433` + .. _migration_20_overview: diff --git a/doc/build/changelog/unreleased_20/7433.rst b/doc/build/changelog/unreleased_20/7433.rst new file mode 100644 index 0000000000..5de470e594 --- /dev/null +++ b/doc/build/changelog/unreleased_20/7433.rst @@ -0,0 +1,18 @@ +.. change:: + :tags: feature, orm + :tickets: 7433 + + The :class:`_orm.Session` (and by extension :class:`.AsyncSession`) now has + new state-tracking functionality that will proactively trap any unexpected + state changes which occur as a particular transactional method proceeds. + This is to allow situations where the :class:`_orm.Session` is being used + in a thread-unsafe manner, where event hooks or similar may be calling + unexpected methods within operations, as well as potentially under other + concurrency situations such as asyncio or gevent to raise an informative + message when the illegal access first occurs, rather than passing silently + leading to secondary failures due to the :class:`_orm.Session` being in an + invalid state. + + .. seealso:: + + :ref:`change_7433` \ No newline at end of file diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index e35c41836b..e51214fd9b 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -249,6 +249,15 @@ class InvalidRequestError(SQLAlchemyError): """ +class IllegalStateChangeError(InvalidRequestError): + """An object that tracks state encountered an illegal state change + of some kind. + + .. versionadded:: 2.0 + + """ + + class NoInspectionAvailable(InvalidRequestError): """A subject passed to :func:`sqlalchemy.inspection.inspect` produced no context for inspection.""" diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index e921bb8f0e..13fc7f22e7 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -26,6 +26,9 @@ from .base import instance_str from .base import object_mapper from .base import object_state from .base import state_str +from .state_changes import _StateChange +from .state_changes import _StateChangeState +from .state_changes import _StateChangeStates from .unitofwork import UOWTransaction from .. import engine from .. import exc as sa_exc @@ -101,11 +104,16 @@ class _SessionClassMethods: return object_session(instance) -ACTIVE = util.symbol("ACTIVE") -PREPARED = util.symbol("PREPARED") -COMMITTED = util.symbol("COMMITTED") -DEACTIVE = util.symbol("DEACTIVE") -CLOSED = util.symbol("CLOSED") +class SessionTransactionState(_StateChangeState): + ACTIVE = 1 + PREPARED = 2 + COMMITTED = 3 + DEACTIVE = 4 + CLOSED = 5 + + +# backwards compatibility +ACTIVE, PREPARED, COMMITTED, DEACTIVE, CLOSED = tuple(SessionTransactionState) class ORMExecuteState(util.MemoizedSlots): @@ -476,7 +484,7 @@ class ORMExecuteState(util.MemoizedSlots): ] -class SessionTransaction(TransactionalContext): +class SessionTransaction(_StateChange, TransactionalContext): """A :class:`.Session`-level transaction. :class:`.SessionTransaction` is produced from the @@ -532,7 +540,7 @@ class SessionTransaction(TransactionalContext): self.nested = nested if nested: self._previous_nested_transaction = session._nested_transaction - self._state = ACTIVE + self._state = SessionTransactionState.ACTIVE if not parent and nested: raise sa_exc.InvalidRequestError( "Can't start a SAVEPOINT transaction when no existing " @@ -547,6 +555,31 @@ class SessionTransaction(TransactionalContext): self.session.dispatch.after_transaction_create(self.session, self) + def _raise_for_prerequisite_state(self, operation_name, state): + if state is SessionTransactionState.DEACTIVE: + if self._rollback_exception: + raise sa_exc.PendingRollbackError( + "This Session's transaction has been rolled back " + "due to a previous exception during flush." + " To begin a new transaction with this Session, " + "first issue Session.rollback()." + f" Original exception was: {self._rollback_exception}", + code="7s2a", + ) + else: + raise sa_exc.InvalidRequestError( + "This session is in 'inactive' state, due to the " + "SQL transaction being rolled back; no further SQL " + "can be emitted within this transaction." + ) + elif state is SessionTransactionState.CLOSED: + raise sa_exc.ResourceClosedError("This transaction is closed") + else: + raise sa_exc.InvalidRequestError( + f"This session is in '{state.name.lower()}' state; no " + "further SQL can be emitted within this transaction." + ) + @property def parent(self): """The parent :class:`.SessionTransaction` of this @@ -576,58 +609,26 @@ class SessionTransaction(TransactionalContext): @property def is_active(self): - return self.session is not None and self._state is ACTIVE - - def _assert_active( - self, - prepared_ok=False, - rollback_ok=False, - deactive_ok=False, - closed_msg="This transaction is closed", - ): - if self._state is COMMITTED: - raise sa_exc.InvalidRequestError( - "This session is in 'committed' state; no further " - "SQL can be emitted within this transaction." - ) - elif self._state is PREPARED: - if not prepared_ok: - raise sa_exc.InvalidRequestError( - "This session is in 'prepared' state; no further " - "SQL can be emitted within this transaction." - ) - elif self._state is DEACTIVE: - if not deactive_ok and not rollback_ok: - if self._rollback_exception: - raise sa_exc.PendingRollbackError( - "This Session's transaction has been rolled back " - "due to a previous exception during flush." - " To begin a new transaction with this Session, " - "first issue Session.rollback()." - " Original exception was: %s" - % self._rollback_exception, - code="7s2a", - ) - elif not deactive_ok: - raise sa_exc.InvalidRequestError( - "This session is in 'inactive' state, due to the " - "SQL transaction being rolled back; no further " - "SQL can be emitted within this transaction." - ) - elif self._state is CLOSED: - raise sa_exc.ResourceClosedError(closed_msg) + return ( + self.session is not None + and self._state is SessionTransactionState.ACTIVE + ) @property def _is_transaction_boundary(self): return self.nested or not self._parent + @_StateChange.declare_states( + (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE + ) def connection(self, bindkey, execution_options=None, **kwargs): - self._assert_active() bind = self.session.get_bind(bindkey, **kwargs) return self._connection_for_bind(bind, execution_options) + @_StateChange.declare_states( + (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE + ) def _begin(self, nested=False): - self._assert_active() return SessionTransaction(self.session, self, nested=nested) def _iterate_self_and_parents(self, upto=None): @@ -718,8 +719,10 @@ class SessionTransaction(TransactionalContext): self._parent._deleted.update(self._deleted) self._parent._key_switches.update(self._key_switches) + @_StateChange.declare_states( + (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE + ) def _connection_for_bind(self, bind, execution_options): - self._assert_active() if bind in self._connections: if execution_options: @@ -792,8 +795,11 @@ class SessionTransaction(TransactionalContext): ) self._prepare_impl() + @_StateChange.declare_states( + (SessionTransactionState.ACTIVE,), SessionTransactionState.PREPARED + ) def _prepare_impl(self): - self._assert_active() + if self._parent is None or self.nested: self.session.dispatch.before_commit(self.session) @@ -822,12 +828,16 @@ class SessionTransaction(TransactionalContext): with util.safe_reraise(): self.rollback() - self._state = PREPARED + self._state = SessionTransactionState.PREPARED + @_StateChange.declare_states( + (SessionTransactionState.ACTIVE, SessionTransactionState.PREPARED), + SessionTransactionState.CLOSED, + ) def commit(self, _to_root=False): - self._assert_active(prepared_ok=True) - if self._state is not PREPARED: - self._prepare_impl() + if self._state is not SessionTransactionState.PREPARED: + with self._expect_state(SessionTransactionState.PREPARED): + self._prepare_impl() if self._parent is None or self.nested: for conn, trans, should_commit, autoclose in set( @@ -836,20 +846,28 @@ class SessionTransaction(TransactionalContext): if should_commit: trans.commit() - self._state = COMMITTED + self._state = SessionTransactionState.COMMITTED self.session.dispatch.after_commit(self.session) self._remove_snapshot() - self.close() + with self._expect_state(SessionTransactionState.CLOSED): + self.close() if _to_root and self._parent: return self._parent.commit(_to_root=True) return self._parent + @_StateChange.declare_states( + ( + SessionTransactionState.ACTIVE, + SessionTransactionState.DEACTIVE, + SessionTransactionState.PREPARED, + ), + SessionTransactionState.CLOSED, + ) def rollback(self, _capture_exception=False, _to_root=False): - self._assert_active(prepared_ok=True, rollback_ok=True) stx = self.session._transaction if stx is not self: @@ -858,26 +876,29 @@ class SessionTransaction(TransactionalContext): boundary = self rollback_err = None - if self._state in (ACTIVE, PREPARED): + if self._state in ( + SessionTransactionState.ACTIVE, + SessionTransactionState.PREPARED, + ): for transaction in self._iterate_self_and_parents(): if transaction._parent is None or transaction.nested: try: for t in set(transaction._connections.values()): t[1].rollback() - transaction._state = DEACTIVE + transaction._state = SessionTransactionState.DEACTIVE self.session.dispatch.after_rollback(self.session) except: rollback_err = sys.exc_info() finally: - transaction._state = DEACTIVE + transaction._state = SessionTransactionState.DEACTIVE transaction._restore_snapshot( dirty_only=transaction.nested ) boundary = transaction break else: - transaction._state = DEACTIVE + transaction._state = SessionTransactionState.DEACTIVE sess = self.session @@ -892,7 +913,8 @@ class SessionTransaction(TransactionalContext): ) boundary._restore_snapshot(dirty_only=boundary.nested) - self.close() + with self._expect_state(SessionTransactionState.CLOSED): + self.close() if self._parent and _capture_exception: self._parent._rollback_exception = sys.exc_info()[1] @@ -906,6 +928,9 @@ class SessionTransaction(TransactionalContext): return self._parent.rollback(_to_root=True) return self._parent + @_StateChange.declare_states( + _StateChangeStates.ANY, SessionTransactionState.CLOSED + ) def close(self, invalidate=False): if self.nested: self.session._nested_transaction = ( @@ -925,20 +950,22 @@ class SessionTransaction(TransactionalContext): if autoclose: connection.close() - self._state = CLOSED - self.session.dispatch.after_transaction_end(self.session, self) + self._state = SessionTransactionState.CLOSED + sess = self.session self.session = None self._connections = None + sess.dispatch.after_transaction_end(sess, self) + def _get_subject(self): return self.session def _transaction_is_active(self): - return self._state is ACTIVE + return self._state is SessionTransactionState.ACTIVE def _transaction_is_closed(self): - return self._state is CLOSED + return self._state is SessionTransactionState.CLOSED def _rollback_can_be_called(self): return self._state not in (COMMITTED, CLOSED) diff --git a/lib/sqlalchemy/orm/state_changes.py b/lib/sqlalchemy/orm/state_changes.py new file mode 100644 index 0000000000..7d2c3e0566 --- /dev/null +++ b/lib/sqlalchemy/orm/state_changes.py @@ -0,0 +1,179 @@ +# orm/state_changes.py +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +"""State tracking utilities used by :class:`_orm.Session`. + +""" + +import contextlib +from enum import Enum +from typing import Any +from typing import Callable +from typing import Optional +from typing import Tuple +from typing import Union + +from .. import exc as sa_exc +from .. import util +from ..util.typing import Literal + + +class _StateChangeState(Enum): + pass + + +class _StateChangeStates(_StateChangeState): + ANY = 1 + NO_CHANGE = 2 + CHANGE_IN_PROGRESS = 3 + + +class _StateChange: + """Supplies state assertion decorators. + + The current use case is for the :class:`_orm.SessionTransaction` class. The + :class:`_StateChange` class itself is agnostic of the + :class:`_orm.SessionTransaction` class so could in theory be generalized + for other systems as well. + + """ + + _next_state: _StateChangeState = _StateChangeStates.ANY + _state: _StateChangeState = _StateChangeStates.NO_CHANGE + _current_fn: Optional[Callable] = None + + def _raise_for_prerequisite_state(self, operation_name, state): + raise sa_exc.IllegalStateChangeError( + f"Can't run operation '{operation_name}()' when Session " + f"is in state {state!r}" + ) + + @classmethod + def declare_states( + cls, + prerequisite_states: Union[ + Literal[_StateChangeStates.ANY], Tuple[_StateChangeState, ...] + ], + moves_to: _StateChangeState, + ) -> Callable[..., Any]: + """Method decorator declaring valid states. + + :param prerequisite_states: sequence of acceptable prerequisite + states. Can be the single constant _State.ANY to indicate no + prerequisite state + + :param moves_to: the expected state at the end of the method, assuming + no exceptions raised. Can be the constant _State.NO_CHANGE to + indicate state should not change at the end of the method. + + """ + assert prerequisite_states, "no prequisite states sent" + has_prerequisite_states = ( + prerequisite_states is not _StateChangeStates.ANY + ) + + expect_state_change = moves_to is not _StateChangeStates.NO_CHANGE + + @util.decorator + def _go(fn, self, *arg, **kw): + + current_state = self._state + + if ( + has_prerequisite_states + and current_state not in prerequisite_states + ): + self._raise_for_prerequisite_state(fn.__name__, current_state) + + next_state = self._next_state + existing_fn = self._current_fn + expect_state = moves_to if expect_state_change else current_state + + if ( + # destination states are restricted + next_state is not _StateChangeStates.ANY + # method seeks to change state + and expect_state_change + # destination state incorrect + and next_state is not expect_state + ): + if existing_fn and next_state in ( + _StateChangeStates.NO_CHANGE, + _StateChangeStates.CHANGE_IN_PROGRESS, + ): + raise sa_exc.IllegalStateChangeError( + f"Method '{fn.__name__}()' can't be called here; " + f"method '{existing_fn.__name__}()' is already " + f"in progress and this would cause an unexpected " + f"state change to {moves_to!r}" + ) + else: + raise sa_exc.IllegalStateChangeError( + f"Cant run operation '{fn.__name__}()' here; " + f"will move to state {moves_to!r} where we are " + f"expecting {next_state!r}" + ) + + self._current_fn = fn + self._next_state = _StateChangeStates.CHANGE_IN_PROGRESS + try: + ret_value = fn(self, *arg, **kw) + except: + raise + else: + if self._state is expect_state: + return ret_value + + if self._state is current_state: + raise sa_exc.IllegalStateChangeError( + f"Method '{fn.__name__}()' failed to " + "change state " + f"to {moves_to!r} as expected" + ) + elif existing_fn: + raise sa_exc.IllegalStateChangeError( + f"While method '{existing_fn.__name__}()' was " + "running, " + f"method '{fn.__name__}()' caused an " + "unexpected " + f"state change to {self._state!r}" + ) + else: + raise sa_exc.IllegalStateChangeError( + f"Method '{fn.__name__}()' caused an unexpected " + f"state change to {self._state!r}" + ) + + finally: + self._next_state = next_state + self._current_fn = existing_fn + + return _go + + @contextlib.contextmanager + def _expect_state(self, expected: _StateChangeState): + """called within a method that changes states. + + method must also use the ``@declare_states()`` decorator. + + """ + assert self._next_state is _StateChangeStates.CHANGE_IN_PROGRESS, ( + "Unexpected call to _expect_state outside of " + "state-changing method" + ) + + self._next_state = expected + try: + yield + except: + raise + else: + if self._state is not expected: + raise sa_exc.IllegalStateChangeError( + f"Unexpected state change to {self._state!r}" + ) + finally: + self._next_state = _StateChangeStates.CHANGE_IN_PROGRESS diff --git a/test/base/test_except.py b/test/base/test_except.py index 0bde988b79..e73160cd85 100644 --- a/test/base/test_except.py +++ b/test/base/test_except.py @@ -439,6 +439,7 @@ ALL_EXC = [ sa_exceptions.InvalidatePoolError, sa_exceptions.TimeoutError, sa_exceptions.InvalidRequestError, + sa_exceptions.IllegalStateChangeError, sa_exceptions.NoInspectionAvailable, sa_exceptions.PendingRollbackError, sa_exceptions.ResourceClosedError, diff --git a/test/orm/test_session_state_change.py b/test/orm/test_session_state_change.py new file mode 100644 index 0000000000..e2635abc22 --- /dev/null +++ b/test/orm/test_session_state_change.py @@ -0,0 +1,346 @@ +from sqlalchemy import exc as sa_exc +from sqlalchemy.orm import state_changes +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message +from sqlalchemy.testing import fixtures + + +class StateTestChange(state_changes._StateChangeState): + a = 1 + b = 2 + c = 3 + + +class StateMachineTest(fixtures.TestBase): + def test_single_change(self): + """test single method that declares and invokes a state change""" + _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE + + class Machine(state_changes._StateChange): + @state_changes._StateChange.declare_states( + (StateTestChange.a, _NO_CHANGE), StateTestChange.b + ) + def move_to_b(self): + self._state = StateTestChange.b + + m = Machine() + eq_(m._state, _NO_CHANGE) + m.move_to_b() + eq_(m._state, StateTestChange.b) + + def test_single_incorrect_change(self): + """test single method that declares a state change but changes to the + wrong state.""" + _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE + + class Machine(state_changes._StateChange): + @state_changes._StateChange.declare_states( + (StateTestChange.a, _NO_CHANGE), StateTestChange.b + ) + def move_to_b(self): + self._state = StateTestChange.c + + m = Machine() + eq_(m._state, _NO_CHANGE) + with expect_raises_message( + sa_exc.IllegalStateChangeError, + r"Method 'move_to_b\(\)' " + r"caused an unexpected state change to ", + ): + m.move_to_b() + + def test_single_failed_to_change(self): + """test single method that declares a state change but didn't do + the change.""" + _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE + + class Machine(state_changes._StateChange): + @state_changes._StateChange.declare_states( + (StateTestChange.a, _NO_CHANGE), StateTestChange.b + ) + def move_to_b(self): + pass + + m = Machine() + eq_(m._state, _NO_CHANGE) + with expect_raises_message( + sa_exc.IllegalStateChangeError, + r"Method 'move_to_b\(\)' failed to change state " + "to as " + "expected", + ): + m.move_to_b() + + def test_change_from_sub_method_with_declaration(self): + """test successful state change by one method calling another that + does the change. + + """ + _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE + + class Machine(state_changes._StateChange): + @state_changes._StateChange.declare_states( + (StateTestChange.a, _NO_CHANGE), StateTestChange.b + ) + def _inner_move_to_b(self): + self._state = StateTestChange.b + + @state_changes._StateChange.declare_states( + (StateTestChange.a, _NO_CHANGE), StateTestChange.b + ) + def move_to_b(self): + with self._expect_state(StateTestChange.b): + self._inner_move_to_b() + + m = Machine() + eq_(m._state, _NO_CHANGE) + m.move_to_b() + eq_(m._state, StateTestChange.b) + + def test_method_and_sub_method_no_change(self): + """test methods that declare the state should not change""" + _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE + + class Machine(state_changes._StateChange): + @state_changes._StateChange.declare_states( + (StateTestChange.a,), _NO_CHANGE + ) + def _inner_do_nothing(self): + pass + + @state_changes._StateChange.declare_states( + (StateTestChange.a,), _NO_CHANGE + ) + def do_nothing(self): + self._inner_do_nothing() + + m = Machine() + eq_(m._state, _NO_CHANGE) + m._state = StateTestChange.a + m.do_nothing() + eq_(m._state, StateTestChange.a) + + def test_method_w_no_change_illegal_inner_change(self): + _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE + + class Machine(state_changes._StateChange): + @state_changes._StateChange.declare_states( + (StateTestChange.a, _NO_CHANGE), StateTestChange.c + ) + def _inner_move_to_c(self): + self._state = StateTestChange.c + + @state_changes._StateChange.declare_states( + (StateTestChange.a,), _NO_CHANGE + ) + def do_nothing(self): + self._inner_move_to_c() + + m = Machine() + eq_(m._state, _NO_CHANGE) + m._state = StateTestChange.a + + with expect_raises_message( + sa_exc.IllegalStateChangeError, + r"Method '_inner_move_to_c\(\)' can't be called here; " + r"method 'do_nothing\(\)' is already in progress and this " + r"would cause an unexpected state change to " + "", + ): + m.do_nothing() + eq_(m._state, StateTestChange.a) + + def test_change_from_method_sub_w_no_change(self): + """test methods that declare the state should not change""" + _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE + + class Machine(state_changes._StateChange): + @state_changes._StateChange.declare_states( + (StateTestChange.a,), _NO_CHANGE + ) + def _inner_do_nothing(self): + pass + + @state_changes._StateChange.declare_states( + (StateTestChange.a,), StateTestChange.b + ) + def move_to_b(self): + self._inner_do_nothing() + self._state = StateTestChange.b + + m = Machine() + eq_(m._state, _NO_CHANGE) + m._state = StateTestChange.a + m.move_to_b() + eq_(m._state, StateTestChange.b) + + def test_invalid_change_from_declared_sub_method_with_declaration(self): + """A method uses _expect_state() to call a sub-method, which must + declare that state as its destination if no exceptions are raised. + + """ + _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE + + class Machine(state_changes._StateChange): + # method declares StateTestChange.c so can't be called under + # expect_state(StateTestChange.b) + @state_changes._StateChange.declare_states( + (StateTestChange.a, _NO_CHANGE), StateTestChange.c + ) + def _inner_move_to_c(self): + self._state = StateTestChange.c + + @state_changes._StateChange.declare_states( + (StateTestChange.a, _NO_CHANGE), StateTestChange.b + ) + def move_to_b(self): + with self._expect_state(StateTestChange.b): + self._inner_move_to_c() + + m = Machine() + eq_(m._state, _NO_CHANGE) + with expect_raises_message( + sa_exc.IllegalStateChangeError, + r"Cant run operation '_inner_move_to_c\(\)' here; will move " + r"to state where we are " + "expecting ", + ): + m.move_to_b() + + def test_invalid_change_from_invalid_sub_method_with_declaration(self): + """A method uses _expect_state() to call a sub-method, which must + declare that state as its destination if no exceptions are raised. + + Test an error is raised if the sub-method doesn't change to the + correct state. + + """ + _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE + + class Machine(state_changes._StateChange): + # method declares StateTestChange.b, but is doing the wrong + # change, so should fail under expect_state(StateTestChange.b) + @state_changes._StateChange.declare_states( + (StateTestChange.a, _NO_CHANGE), StateTestChange.b + ) + def _inner_move_to_c(self): + self._state = StateTestChange.c + + @state_changes._StateChange.declare_states( + (StateTestChange.a, _NO_CHANGE), StateTestChange.b + ) + def move_to_b(self): + with self._expect_state(StateTestChange.b): + self._inner_move_to_c() + + m = Machine() + eq_(m._state, _NO_CHANGE) + with expect_raises_message( + sa_exc.IllegalStateChangeError, + r"While method 'move_to_b\(\)' was running, method " + r"'_inner_move_to_c\(\)' caused an unexpected state change " + "to ", + ): + m.move_to_b() + + def test_invalid_prereq_state(self): + _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE + + class Machine(state_changes._StateChange): + @state_changes._StateChange.declare_states( + (StateTestChange.a, _NO_CHANGE), StateTestChange.b + ) + def move_to_b(self): + self._state = StateTestChange.b + + @state_changes._StateChange.declare_states( + (StateTestChange.c,), "d" + ) + def move_to_d(self): + self._state = "d" + + m = Machine() + eq_(m._state, _NO_CHANGE) + m.move_to_b() + eq_(m._state, StateTestChange.b) + with expect_raises_message( + sa_exc.IllegalStateChangeError, + r"Can't run operation 'move_to_d\(\)' when " + "Session is in state ", + ): + m.move_to_d() + + def test_declare_only(self): + _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE + + class Machine(state_changes._StateChange): + @state_changes._StateChange.declare_states( + state_changes._StateChangeStates.ANY, StateTestChange.b + ) + def _inner_move_to_b(self): + self._state = StateTestChange.b + + def move_to_b(self): + with self._expect_state(StateTestChange.b): + self._move_to_b() + + m = Machine() + eq_(m._state, _NO_CHANGE) + with expect_raises_message( + AssertionError, + "Unexpected call to _expect_state outside of " + "state-changing method", + ): + m.move_to_b() + + def test_sibling_calls_maintain_correct_state(self): + _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE + + class Machine(state_changes._StateChange): + @state_changes._StateChange.declare_states( + state_changes._StateChangeStates.ANY, StateTestChange.c + ) + def move_to_c(self): + self._state = StateTestChange.c + + @state_changes._StateChange.declare_states( + state_changes._StateChangeStates.ANY, _NO_CHANGE + ) + def do_nothing(self): + pass + + m = Machine() + m.do_nothing() + eq_(m._state, _NO_CHANGE) + m.move_to_c() + eq_(m._state, StateTestChange.c) + + def test_change_from_sub_method_requires_declaration(self): + """A method can't call another state-changing method without using + _expect_state() to allow the state change to occur. + + """ + _NO_CHANGE = state_changes._StateChangeStates.NO_CHANGE + + class Machine(state_changes._StateChange): + @state_changes._StateChange.declare_states( + (StateTestChange.a, _NO_CHANGE), StateTestChange.b + ) + def _inner_move_to_b(self): + self._state = StateTestChange.b + + @state_changes._StateChange.declare_states( + (StateTestChange.a, _NO_CHANGE), StateTestChange.b + ) + def move_to_b(self): + self._inner_move_to_b() + + m = Machine() + + with expect_raises_message( + sa_exc.IllegalStateChangeError, + r"Method '_inner_move_to_b\(\)' can't be called here; " + r"method 'move_to_b\(\)' is already in progress and this would " + r"cause an unexpected state change to ", + ): + m.move_to_b()