]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-110686: Test pattern matching with `runtime_checkable` protocols (#110687)
authorNikita Sobolev <mail@sobolevn.me>
Sun, 10 Dec 2023 15:21:20 +0000 (18:21 +0300)
committerGitHub <noreply@github.com>
Sun, 10 Dec 2023 15:21:20 +0000 (07:21 -0800)
Lib/test/test_patma.py

index dedbc828784184671707657b81384171690aa13b..298e78ccee387575f9eb71b475176fc217f22519 100644 (file)
@@ -2760,6 +2760,132 @@ class TestPatma(unittest.TestCase):
         self.assertEqual(y, 1)
         self.assertIs(z, x)
 
+    def test_patma_runtime_checkable_protocol(self):
+        # Runtime-checkable protocol
+        from typing import Protocol, runtime_checkable
+
+        @runtime_checkable
+        class P(Protocol):
+            x: int
+            y: int
+
+        class A:
+            def __init__(self, x: int, y: int):
+                self.x = x
+                self.y = y
+
+        class B(A): ...
+
+        for cls in (A, B):
+            with self.subTest(cls=cls.__name__):
+                inst = cls(1, 2)
+                w = 0
+                match inst:
+                    case P() as p:
+                        self.assertIsInstance(p, cls)
+                        self.assertEqual(p.x, 1)
+                        self.assertEqual(p.y, 2)
+                        w = 1
+                self.assertEqual(w, 1)
+
+                q = 0
+                match inst:
+                    case P(x=x, y=y):
+                        self.assertEqual(x, 1)
+                        self.assertEqual(y, 2)
+                        q = 1
+                self.assertEqual(q, 1)
+
+
+    def test_patma_generic_protocol(self):
+        # Runtime-checkable generic protocol
+        from typing import Generic, TypeVar, Protocol, runtime_checkable
+
+        T = TypeVar('T')  # not using PEP695 to be able to backport changes
+
+        @runtime_checkable
+        class P(Protocol[T]):
+            a: T
+            b: T
+
+        class A:
+            def __init__(self, x: int, y: int):
+                self.x = x
+                self.y = y
+
+        class G(Generic[T]):
+            def __init__(self, x: T, y: T):
+                self.x = x
+                self.y = y
+
+        for cls in (A, G):
+            with self.subTest(cls=cls.__name__):
+                inst = cls(1, 2)
+                w = 0
+                match inst:
+                    case P():
+                        w = 1
+                self.assertEqual(w, 0)
+
+    def test_patma_protocol_with_match_args(self):
+        # Runtime-checkable protocol with `__match_args__`
+        from typing import Protocol, runtime_checkable
+
+        # Used to fail before
+        # https://github.com/python/cpython/issues/110682
+        @runtime_checkable
+        class P(Protocol):
+            __match_args__ = ('x', 'y')
+            x: int
+            y: int
+
+        class A:
+            def __init__(self, x: int, y: int):
+                self.x = x
+                self.y = y
+
+        class B(A): ...
+
+        for cls in (A, B):
+            with self.subTest(cls=cls.__name__):
+                inst = cls(1, 2)
+                w = 0
+                match inst:
+                    case P() as p:
+                        self.assertIsInstance(p, cls)
+                        self.assertEqual(p.x, 1)
+                        self.assertEqual(p.y, 2)
+                        w = 1
+                self.assertEqual(w, 1)
+
+                q = 0
+                match inst:
+                    case P(x=x, y=y):
+                        self.assertEqual(x, 1)
+                        self.assertEqual(y, 2)
+                        q = 1
+                self.assertEqual(q, 1)
+
+                j = 0
+                match inst:
+                    case P(x=1, y=2):
+                        j = 1
+                self.assertEqual(j, 1)
+
+                g = 0
+                match inst:
+                    case P(x, y):
+                        self.assertEqual(x, 1)
+                        self.assertEqual(y, 2)
+                        g = 1
+                self.assertEqual(g, 1)
+
+                h = 0
+                match inst:
+                    case P(1, 2):
+                        h = 1
+                self.assertEqual(h, 1)
+
 
 class TestSyntaxErrors(unittest.TestCase):
 
@@ -3198,6 +3324,35 @@ class TestTypeErrors(unittest.TestCase):
                     w = 0
         self.assertIsNone(w)
 
+    def test_regular_protocol(self):
+        from typing import Protocol
+        class P(Protocol): ...
+        msg = (
+            'Instance and class checks can only be used '
+            'with @runtime_checkable protocols'
+        )
+        w = None
+        with self.assertRaisesRegex(TypeError, msg):
+            match 1:
+                case P():
+                    w = 0
+        self.assertIsNone(w)
+
+    def test_positional_patterns_with_regular_protocol(self):
+        from typing import Protocol
+        class P(Protocol):
+            x: int  # no `__match_args__`
+            y: int
+        class A:
+            x = 1
+            y = 2
+        w = None
+        with self.assertRaises(TypeError):
+            match A():
+                case P(x, y):
+                    w = 0
+        self.assertIsNone(w)
+
 
 class TestValueErrors(unittest.TestCase):