]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-37555: Ensure all assert methods using _call_matcher are actually passing calls
authorElizabeth Uselton <elizabeth.uselton@rover.com>
Mon, 5 Aug 2019 07:51:24 +0000 (00:51 -0700)
committerElizabeth Uselton <elizabeth.uselton@rover.com>
Mon, 5 Aug 2019 07:53:37 +0000 (00:53 -0700)
Lib/unittest/mock.py

index 47592057e1b1c03e4dc1b5b1743350f74c9fb8c1..f9789e5e85a97c80cd11230f1d90ef648ff4b8da 100644 (file)
@@ -864,9 +864,9 @@ class NonCallableMock(Base):
         def _error_message():
             msg = self._format_mock_failure_message(args, kwargs)
             return msg
-        expected = self._call_matcher((args, kwargs))
+        expected = self._call_matcher(_Call((args, kwargs)))
         actual = self._call_matcher(self.call_args)
-        if expected != actual:
+        if actual != expected:
             cause = expected if isinstance(expected, Exception) else None
             raise AssertionError(_error_message()) from cause
 
@@ -926,10 +926,10 @@ class NonCallableMock(Base):
         The assert passes if the mock has *ever* been called, unlike
         `assert_called_with` and `assert_called_once_with` that only pass if
         the call is the most recent one."""
-        expected = self._call_matcher((args, kwargs))
+        expected = self._call_matcher(_Call((args, kwargs), two=True))
+        cause = expected if isinstance(expected, Exception) else None
         actual = [self._call_matcher(c) for c in self.call_args_list]
-        if expected not in actual:
-            cause = expected if isinstance(expected, Exception) else None
+        if cause or expected not in _AnyComparer(actual):
             expected_string = self._format_mock_call_signature(args, kwargs)
             raise AssertionError(
                 '%s call not found' % expected_string
@@ -982,6 +982,22 @@ class NonCallableMock(Base):
         return f"\n{prefix}: {safe_repr(self.mock_calls)}."
 
 
+class _AnyComparer(list):
+    """A list which checks if it contains a call which may have an
+    argument of ANY, flipping the components of item and self from
+    their traditional locations so that ANY is guaranteed to be on
+    the left."""
+    def __contains__(self, item):
+        for _call in self:
+            if len(item) != len(_call):
+                continue
+            if all([
+                expected == actual
+                for expected, actual in zip(item, _call)
+            ]):
+                return True
+        return False
+
 
 def _try_iter(obj):
     if obj is None:
@@ -2133,9 +2149,9 @@ class AsyncMockMixin(Base):
             msg = self._format_mock_failure_message(args, kwargs, action='await')
             return msg
 
-        expected = self._call_matcher((args, kwargs))
+        expected = self._call_matcher(_Call((args, kwargs), two=True))
         actual = self._call_matcher(self.await_args)
-        if expected != actual:
+        if actual != expected:
             cause = expected if isinstance(expected, Exception) else None
             raise AssertionError(_error_message()) from cause
 
@@ -2154,9 +2170,9 @@ class AsyncMockMixin(Base):
         """
         Assert the mock has ever been awaited with the specified arguments.
         """
-        expected = self._call_matcher((args, kwargs))
+        expected = self._call_matcher(_Call((args, kwargs), two=True))
         actual = [self._call_matcher(c) for c in self.await_args_list]
-        if expected not in actual:
+        if expected not in _AnyComparer(actual):
             cause = expected if isinstance(expected, Exception) else None
             expected_string = self._format_mock_call_signature(args, kwargs)
             raise AssertionError(