def __call__(self, f):
if isinstance(f, type):
return self.decorate_class(f)
+ if inspect.iscoroutinefunction(f):
+ return self.decorate_async_callable(f)
+ return self.decorate_callable(f)
+
+
+ def decorate_callable(self, f):
@wraps(f)
def _inner(*args, **kw):
self._patch_dict()
return _inner
+ def decorate_async_callable(self, f):
+ @wraps(f)
+ async def _inner(*args, **kw):
+ self._patch_dict()
+ try:
+ return await f(*args, **kw)
+ finally:
+ self._unpatch_dict()
+
+ return _inner
+
+
def decorate_class(self, klass):
for attr in dir(klass):
attr_value = getattr(klass, attr)
run(test_async())
+ def test_patch_dict_async_def(self):
+ foo = {'a': 'a'}
+ @patch.dict(foo, {'a': 'b'})
+ async def test_async():
+ self.assertEqual(foo['a'], 'b')
+
+ self.assertTrue(iscoroutinefunction(test_async))
+ run(test_async())
+
+ def test_patch_dict_async_def_context(self):
+ foo = {'a': 'a'}
+ async def test_async():
+ with patch.dict(foo, {'a': 'b'}):
+ self.assertEqual(foo['a'], 'b')
+
+ run(test_async())
+
class AsyncMockTest(unittest.TestCase):
def test_iscoroutinefunction_default(self):