]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
GH-106135: Add more edge-"cases" to test_patma (GH-106271)
authorNikita Sobolev <mail@sobolevn.me>
Fri, 30 Jun 2023 19:39:50 +0000 (22:39 +0300)
committerGitHub <noreply@github.com>
Fri, 30 Jun 2023 19:39:50 +0000 (19:39 +0000)
Lib/test/test_patma.py

index 3dbd19dfffd31861ac20a3b5e0dcd7d5b974b179..dedbc828784184671707657b81384171690aa13b 100644 (file)
@@ -2460,12 +2460,27 @@ class TestPatma(unittest.TestCase):
             def __eq__(self, other):
                 return True
         x = eq = Eq()
+        # None
         y = None
         match x:
             case None:
                 y = 0
         self.assertIs(x, eq)
         self.assertEqual(y, None)
+        # True
+        y = None
+        match x:
+            case True:
+                y = 0
+        self.assertIs(x, eq)
+        self.assertEqual(y, None)
+        # False
+        y = None
+        match x:
+            case False:
+                y = 0
+        self.assertIs(x, eq)
+        self.assertEqual(y, None)
 
     def test_patma_233(self):
         x = False
@@ -2668,6 +2683,83 @@ class TestPatma(unittest.TestCase):
         setattr(c, "__attr", "spam")  # setattr is needed because we're in a class scope
         self.assertEqual(Outer().f(c), "spam")
 
+    def test_patma_250(self):
+        def f(x):
+            match x:
+                case {"foo": y} if y >= 0:
+                    return True
+                case {"foo": y} if y < 0:
+                    return False
+
+        self.assertIs(f({"foo": 1}), True)
+        self.assertIs(f({"foo": -1}), False)
+
+    def test_patma_251(self):
+        def f(v, x):
+            match v:
+                case x.attr if x.attr >= 0:
+                    return True
+                case x.attr if x.attr < 0:
+                    return False
+                case _:
+                    return None
+
+        class X:
+            def __init__(self, attr):
+                self.attr = attr
+
+        self.assertIs(f(1, X(1)), True)
+        self.assertIs(f(-1, X(-1)), False)
+        self.assertIs(f(1, X(-1)), None)
+
+    def test_patma_252(self):
+        # Side effects must be possible in guards:
+        effects = []
+        def lt(x, y):
+            effects.append((x, y))
+            return x < y
+
+        res = None
+        match {"foo": 1}:
+            case {"foo": x} if lt(x, 0):
+                res = 0
+            case {"foo": x} if lt(x, 1):
+                res = 1
+            case {"foo": x} if lt(x, 2):
+                res = 2
+
+        self.assertEqual(res, 2)
+        self.assertEqual(effects, [(1, 0), (1, 1), (1, 2)])
+
+    def test_patma_253(self):
+        def f(v):
+            match v:
+                case [x] | x:
+                    return x
+
+        self.assertEqual(f(1), 1)
+        self.assertEqual(f([1]), 1)
+
+    def test_patma_254(self):
+        def f(v):
+            match v:
+                case {"x": x} | x:
+                    return x
+
+        self.assertEqual(f(1), 1)
+        self.assertEqual(f({"x": 1}), 1)
+
+    def test_patma_255(self):
+        x = []
+        match x:
+            case [] as z if z.append(None):
+                y = 0
+            case [None]:
+                y = 1
+        self.assertEqual(x, [None])
+        self.assertEqual(y, 1)
+        self.assertIs(z, x)
+
 
 class TestSyntaxErrors(unittest.TestCase):
 
@@ -2885,6 +2977,37 @@ class TestSyntaxErrors(unittest.TestCase):
                 pass
         """)
 
+    def test_real_number_multiple_ops(self):
+        self.assert_syntax_error("""
+        match ...:
+            case 0 + 0j + 0:
+                pass
+        """)
+
+    def test_real_number_wrong_ops(self):
+        for op in ["*", "/", "@", "**", "%", "//"]:
+            with self.subTest(op=op):
+                self.assert_syntax_error(f"""
+                match ...:
+                    case 0 {op} 0j:
+                        pass
+                """)
+                self.assert_syntax_error(f"""
+                match ...:
+                    case 0j {op} 0:
+                        pass
+                """)
+                self.assert_syntax_error(f"""
+                match ...:
+                    case -0j {op} 0:
+                        pass
+                """)
+                self.assert_syntax_error(f"""
+                match ...:
+                    case 0j {op} -0:
+                        pass
+                """)
+
     def test_wildcard_makes_remaining_patterns_unreachable_0(self):
         self.assert_syntax_error("""
         match ...:
@@ -3067,6 +3190,14 @@ class TestTypeErrors(unittest.TestCase):
         self.assertIs(y, None)
         self.assertIs(z, None)
 
+    def test_class_pattern_not_type(self):
+        w = None
+        with self.assertRaises(TypeError):
+            match 1:
+                case max(0, 1):
+                    w = 0
+        self.assertIsNone(w)
+
 
 class TestValueErrors(unittest.TestCase):