the left."""
def __contains__(self, item):
for _call in self:
- if len(item) != len(_call):
- continue
+ assert len(item) == len(_call)
if all([
expected == actual
for expected, actual in zip(item, _call)
def __exit__(self, *args):
"""Unpatch the dict."""
- self._unpatch_dict()
+ if self._original is not None:
+ self._unpatch_dict()
return False
self.__dict__['__code__'] = code_mock
async def _execute_mock_call(self, /, *args, **kwargs):
- # This is nearly just like super(), except for sepcial handling
+ # This is nearly just like super(), except for special handling
# of coroutines
_call = self.call_args
return tuple.__getattribute__(self, attr)
- def count(self, /, *args, **kwargs):
- return self.__getattr__('count')(*args, **kwargs)
-
- def index(self, /, *args, **kwargs):
- return self.__getattr__('index')(*args, **kwargs)
-
def _get_call_arguments(self):
if len(self) == 2:
args, kwargs = self
code_mock.co_flags = inspect.CO_ITERABLE_COROUTINE
self.__dict__['__code__'] = code_mock
- def __aiter__(self):
- return self
-
async def __anext__(self):
try:
return next(self.iterator)
class AsyncClass:
- def __init__(self):
- pass
- async def async_method(self):
- pass
- def normal_method(self):
- pass
+ def __init__(self): pass
+ async def async_method(self): pass
+ def normal_method(self): pass
@classmethod
- async def async_class_method(cls):
- pass
+ async def async_class_method(cls): pass
@staticmethod
- async def async_static_method():
- pass
+ async def async_static_method(): pass
class AwaitableClass:
- def __await__(self):
- yield
+ def __await__(self): yield
-async def async_func():
- pass
+async def async_func(): pass
-async def async_func_args(a, b, *, c):
- pass
+async def async_func_args(a, b, *, c): pass
-def normal_func():
- pass
+def normal_func(): pass
class NormalClass(object):
- def a(self):
- pass
+ def a(self): pass
async_foo_name = f'{__name__}.AsyncClass'
class AsyncArguments(IsolatedAsyncioTestCase):
async def test_add_return_value(self):
- async def addition(self, var):
- return var + 1
+ async def addition(self, var): pass
mock = AsyncMock(addition, return_value=10)
output = await mock(5)
self.assertEqual(output, 10)
async def test_add_side_effect_exception(self):
- async def addition(var):
- return var + 1
+ async def addition(var): pass
mock = AsyncMock(addition, side_effect=Exception('err'))
with self.assertRaises(Exception):
await mock(5)
class AsyncContextManagerTest(unittest.TestCase):
class WithAsyncContextManager:
- async def __aenter__(self, *args, **kwargs):
- return self
+ async def __aenter__(self, *args, **kwargs): pass
- async def __aexit__(self, *args, **kwargs):
- pass
+ async def __aexit__(self, *args, **kwargs): pass
class WithSyncContextManager:
- def __enter__(self, *args, **kwargs):
- return self
+ def __enter__(self, *args, **kwargs): pass
- def __exit__(self, *args, **kwargs):
- pass
+ def __exit__(self, *args, **kwargs): pass
class ProductionCode:
# Example real-world(ish) code
def __init__(self):
self.items = ["foo", "NormalFoo", "baz"]
- def __aiter__(self):
- return self
-
- async def __anext__(self):
- try:
- return self.items.pop()
- except IndexError:
- pass
+ def __aiter__(self): pass
- raise StopAsyncIteration
+ async def __anext__(self): pass
def test_aiter_set_return_value(self):
mock_iter = AsyncMock(name="tester")