From 38650c98c6afadf25951fa128deea3cfefd82ef7 Mon Sep 17 00:00:00 2001 From: Elizabeth Uselton Date: Mon, 5 Aug 2019 00:51:24 -0700 Subject: [PATCH] bpo-37555: Ensure all assert methods using _call_matcher are actually passing calls --- Lib/unittest/mock.py | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py index 47592057e1b1..f9789e5e85a9 100644 --- a/Lib/unittest/mock.py +++ b/Lib/unittest/mock.py @@ -864,9 +864,9 @@ class NonCallableMock(Base): def _error_message(): msg = self._format_mock_failure_message(args, kwargs) return msg - expected = self._call_matcher((args, kwargs)) + expected = self._call_matcher(_Call((args, kwargs))) actual = self._call_matcher(self.call_args) - if expected != actual: + if actual != expected: cause = expected if isinstance(expected, Exception) else None raise AssertionError(_error_message()) from cause @@ -926,10 +926,10 @@ class NonCallableMock(Base): The assert passes if the mock has *ever* been called, unlike `assert_called_with` and `assert_called_once_with` that only pass if the call is the most recent one.""" - expected = self._call_matcher((args, kwargs)) + expected = self._call_matcher(_Call((args, kwargs), two=True)) + cause = expected if isinstance(expected, Exception) else None actual = [self._call_matcher(c) for c in self.call_args_list] - if expected not in actual: - cause = expected if isinstance(expected, Exception) else None + if cause or expected not in _AnyComparer(actual): expected_string = self._format_mock_call_signature(args, kwargs) raise AssertionError( '%s call not found' % expected_string @@ -982,6 +982,22 @@ class NonCallableMock(Base): return f"\n{prefix}: {safe_repr(self.mock_calls)}." +class _AnyComparer(list): + """A list which checks if it contains a call which may have an + argument of ANY, flipping the components of item and self from + their traditional locations so that ANY is guaranteed to be on + the left.""" + def __contains__(self, item): + for _call in self: + if len(item) != len(_call): + continue + if all([ + expected == actual + for expected, actual in zip(item, _call) + ]): + return True + return False + def _try_iter(obj): if obj is None: @@ -2133,9 +2149,9 @@ class AsyncMockMixin(Base): msg = self._format_mock_failure_message(args, kwargs, action='await') return msg - expected = self._call_matcher((args, kwargs)) + expected = self._call_matcher(_Call((args, kwargs), two=True)) actual = self._call_matcher(self.await_args) - if expected != actual: + if actual != expected: cause = expected if isinstance(expected, Exception) else None raise AssertionError(_error_message()) from cause @@ -2154,9 +2170,9 @@ class AsyncMockMixin(Base): """ Assert the mock has ever been awaited with the specified arguments. """ - expected = self._call_matcher((args, kwargs)) + expected = self._call_matcher(_Call((args, kwargs), two=True)) actual = [self._call_matcher(c) for c in self.await_args_list] - if expected not in actual: + if expected not in _AnyComparer(actual): cause = expected if isinstance(expected, Exception) else None expected_string = self._format_mock_call_signature(args, kwargs) raise AssertionError( -- 2.47.3