]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-45045: Optimize mapping patterns of structural pattern matching (GH-28043)
authorDong-hee Na <donghee.na@python.org>
Mon, 30 Aug 2021 10:02:32 +0000 (10:02 +0000)
committerGitHub <noreply@github.com>
Mon, 30 Aug 2021 10:02:32 +0000 (19:02 +0900)
Lib/test/test_patma.py
Python/ceval.c

index aa18e29e22548fa6414fdd45ca931c08678085b7..57d3b1ec701ca4c1072b827fb90cf16d33c951de 100644 (file)
@@ -2641,6 +2641,19 @@ class TestPatma(unittest.TestCase):
         self.assertEqual(f((False, range(-1, -11, -1), True)), alts[3])
         self.assertEqual(f((False, range(10, 20), True)), alts[4])
 
+    def test_patma_248(self):
+        class C(dict):
+            @staticmethod
+            def get(key, default=None):
+                return 'bar'
+
+        x = C({'foo': 'bar'})
+        match x:
+            case {'foo': bar}:
+                y = bar
+
+        self.assertEqual(y, 'bar')
+
 
 class TestSyntaxErrors(unittest.TestCase):
 
index 8aaa83b1b74bf4a23150d3d8a8c8071e3a886cae..bf95d50b6295823d9a30d69cabf217f949250d79 100644 (file)
@@ -841,12 +841,18 @@ match_keys(PyThreadState *tstate, PyObject *map, PyObject *keys)
     PyObject *seen = NULL;
     PyObject *dummy = NULL;
     PyObject *values = NULL;
+    PyObject *get_name = NULL;
+    PyObject *get = NULL;
     // We use the two argument form of map.get(key, default) for two reasons:
     // - Atomically check for a key and get its value without error handling.
     // - Don't cause key creation or resizing in dict subclasses like
     //   collections.defaultdict that define __missing__ (or similar).
     _Py_IDENTIFIER(get);
-    PyObject *get = _PyObject_GetAttrId(map, &PyId_get);
+    get_name = _PyUnicode_FromId(&PyId_get); // borrowed
+    if (get_name == NULL) {
+        return NULL;
+    }
+    int meth_found = _PyObject_GetMethod(map, get_name, &get);
     if (get == NULL) {
         goto fail;
     }
@@ -859,7 +865,7 @@ match_keys(PyThreadState *tstate, PyObject *map, PyObject *keys)
     if (dummy == NULL) {
         goto fail;
     }
-    values = PyList_New(0);
+    values = PyTuple_New(nkeys);
     if (values == NULL) {
         goto fail;
     }
@@ -873,7 +879,14 @@ match_keys(PyThreadState *tstate, PyObject *map, PyObject *keys)
             }
             goto fail;
         }
-        PyObject *value = PyObject_CallFunctionObjArgs(get, key, dummy, NULL);
+        PyObject *args[] = { map, key, dummy };
+        PyObject *value = NULL;
+        if (meth_found) {
+            value = PyObject_Vectorcall(get, args, 3, NULL);
+        }
+        else {
+            value = PyObject_Vectorcall(get, &args[1], 2, NULL);
+        }
         if (value == NULL) {
             goto fail;
         }
@@ -886,10 +899,8 @@ match_keys(PyThreadState *tstate, PyObject *map, PyObject *keys)
             values = Py_None;
             goto done;
         }
-        PyList_Append(values, value);
-        Py_DECREF(value);
+        PyTuple_SET_ITEM(values, i, value);
     }
-    Py_SETREF(values, PyList_AsTuple(values));
     // Success:
 done:
     Py_DECREF(get);