from contextlib import asynccontextmanager
from enum import Enum
from functools import wraps
+from inspect import iscoroutinefunction
import logging
import socket
from ssl import SSLContext
:param required_state: The `Runstate` required to invoke this method.
:raise StateError: When the required `Runstate` is not met.
"""
+ def _check(proto: 'AsyncProtocol[Any]') -> None:
+ name = type(proto).__name__
+ if proto.runstate == required_state:
+ return
+
+ if proto.runstate == Runstate.CONNECTING:
+ emsg = f"{name} is currently connecting."
+ elif proto.runstate == Runstate.DISCONNECTING:
+ emsg = (f"{name} is disconnecting."
+ " Call disconnect() to return to IDLE state.")
+ elif proto.runstate == Runstate.RUNNING:
+ emsg = f"{name} is already connected and running."
+ elif proto.runstate == Runstate.IDLE:
+ emsg = f"{name} is disconnected and idle."
+ else:
+ assert False
+
+ raise StateError(emsg, proto.runstate, required_state)
+
def _decorator(func: F) -> F:
# _decorator is the decorator that is built by calling the
# require() decorator factory; e.g.:
@wraps(func)
def _wrapper(proto: 'AsyncProtocol[Any]',
*args: Any, **kwargs: Any) -> Any:
- # _wrapper is the function that gets executed prior to the
- # decorated method.
-
- name = type(proto).__name__
-
- if proto.runstate != required_state:
- if proto.runstate == Runstate.CONNECTING:
- emsg = f"{name} is currently connecting."
- elif proto.runstate == Runstate.DISCONNECTING:
- emsg = (f"{name} is disconnecting."
- " Call disconnect() to return to IDLE state.")
- elif proto.runstate == Runstate.RUNNING:
- emsg = f"{name} is already connected and running."
- elif proto.runstate == Runstate.IDLE:
- emsg = f"{name} is disconnected and idle."
- else:
- assert False
- raise StateError(emsg, proto.runstate, required_state)
- # No StateError, so call the wrapped method.
+ _check(proto)
return func(proto, *args, **kwargs)
- # Return the decorated method;
- # Transforming Func to Decorated[Func].
+ @wraps(func)
+ async def _async_wrapper(proto: 'AsyncProtocol[Any]',
+ *args: Any, **kwargs: Any) -> Any:
+ _check(proto)
+ return await func(proto, *args, **kwargs)
+
+ # Return the decorated method; F => Decorated[F]
+ # Use an async version when applicable, which
+ # preserves async signature generation in sphinx.
+ if iscoroutinefunction(func):
+ return cast(F, _async_wrapper)
return cast(F, _wrapper)
# Return the decorator instance from the decorator factory. Phew!