]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-36144: Implement defaultdict union (GH-18729)
authorBrandt Bucher <brandtbucher@gmail.com>
Fri, 6 Mar 2020 17:24:08 +0000 (09:24 -0800)
committerGitHub <noreply@github.com>
Fri, 6 Mar 2020 17:24:08 +0000 (09:24 -0800)
For PEP 585 (this isn't in the PEP but is an obvious follow-up).

Doc/library/collections.rst
Lib/test/test_defaultdict.py
Misc/NEWS.d/next/Library/2020-02-29-15-54-08.bpo-36144.4GgTZs.rst [new file with mode: 0644]
Modules/_collectionsmodule.c

index 8dcf9451d72bfe48e469e2df31d7887945b8f980..f4a383c8ea57de51c674cda5c9dc6cef0859779d 100644 (file)
@@ -729,6 +729,10 @@ stack manipulations such as ``dup``, ``drop``, ``swap``, ``over``, ``pick``,
         initialized from the first argument to the constructor, if present, or to
         ``None``, if absent.
 
+    .. versionchanged:: 3.9
+       Added merge (``|``) and update (``|=``) operators, specified in
+       :pep:`584`.
+
 
 :class:`defaultdict` Examples
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
index b9f1fb9f23d39d27b92e88df16c3ab439ca60db0..b48c649fce6ba1cb5e4df20d48cd3b39de459c86 100644 (file)
@@ -183,5 +183,43 @@ class TestDefaultDict(unittest.TestCase):
             o = pickle.loads(s)
             self.assertEqual(d, o)
 
+    def test_union(self):
+        i = defaultdict(int, {1: 1, 2: 2})
+        s = defaultdict(str, {0: "zero", 1: "one"})
+
+        i_s = i | s
+        self.assertIs(i_s.default_factory, int)
+        self.assertDictEqual(i_s, {1: "one", 2: 2, 0: "zero"})
+        self.assertEqual(list(i_s), [1, 2, 0])
+
+        s_i = s | i
+        self.assertIs(s_i.default_factory, str)
+        self.assertDictEqual(s_i, {0: "zero", 1: 1, 2: 2})
+        self.assertEqual(list(s_i), [0, 1, 2])
+
+        i_ds = i | dict(s)
+        self.assertIs(i_ds.default_factory, int)
+        self.assertDictEqual(i_ds, {1: "one", 2: 2, 0: "zero"})
+        self.assertEqual(list(i_ds), [1, 2, 0])
+
+        ds_i = dict(s) | i
+        self.assertIs(ds_i.default_factory, int)
+        self.assertDictEqual(ds_i, {0: "zero", 1: 1, 2: 2})
+        self.assertEqual(list(ds_i), [0, 1, 2])
+
+        with self.assertRaises(TypeError):
+            i | list(s.items())
+        with self.assertRaises(TypeError):
+            list(s.items()) | i
+
+        # We inherit a fine |= from dict, so just a few sanity checks here:
+        i |= list(s.items())
+        self.assertIs(i.default_factory, int)
+        self.assertDictEqual(i, {1: "one", 2: 2, 0: "zero"})
+        self.assertEqual(list(i), [1, 2, 0])
+
+        with self.assertRaises(TypeError):
+            i |= None
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/Misc/NEWS.d/next/Library/2020-02-29-15-54-08.bpo-36144.4GgTZs.rst b/Misc/NEWS.d/next/Library/2020-02-29-15-54-08.bpo-36144.4GgTZs.rst
new file mode 100644 (file)
index 0000000..416d5ac
--- /dev/null
@@ -0,0 +1 @@
+:class:`collections.defaultdict` now implements ``|`` (:pep:`584`).
index 4d5d874b44da1669b194269c200ddcc8d89e6194..d0a381deabf5d8c941fa382e2048db4eae2d9759 100644 (file)
@@ -1990,6 +1990,13 @@ defdict_missing(defdictobject *dd, PyObject *key)
     return value;
 }
 
+static inline PyObject*
+new_defdict(defdictobject *dd, PyObject *arg)
+{
+    return PyObject_CallFunctionObjArgs((PyObject*)Py_TYPE(dd),
+        dd->default_factory ? dd->default_factory : Py_None, arg, NULL);
+}
+
 PyDoc_STRVAR(defdict_copy_doc, "D.copy() -> a shallow copy of D.");
 
 static PyObject *
@@ -1999,11 +2006,7 @@ defdict_copy(defdictobject *dd, PyObject *Py_UNUSED(ignored))
        whose class constructor has the same signature.  Subclasses that
        define a different constructor signature must override copy().
     */
-
-    if (dd->default_factory == NULL)
-        return PyObject_CallFunctionObjArgs((PyObject*)Py_TYPE(dd), Py_None, dd, NULL);
-    return PyObject_CallFunctionObjArgs((PyObject*)Py_TYPE(dd),
-                                        dd->default_factory, dd, NULL);
+    return new_defdict(dd, (PyObject*)dd);
 }
 
 static PyObject *
@@ -2127,6 +2130,42 @@ defdict_repr(defdictobject *dd)
     return result;
 }
 
+static PyObject*
+defdict_or(PyObject* left, PyObject* right)
+{
+    int left_is_self = PyObject_IsInstance(left, (PyObject*)&defdict_type);
+    if (left_is_self < 0) {
+        return NULL;
+    }
+    PyObject *self, *other;
+    if (left_is_self) {
+        self = left;
+        other = right;
+    }
+    else {
+        self = right;
+        other = left;
+    }
+    if (!PyDict_Check(other)) {
+        Py_RETURN_NOTIMPLEMENTED;
+    }
+    // Like copy(), this calls the object's class.
+    // Override __or__/__ror__ for subclasses with different constructors.
+    PyObject *new = new_defdict((defdictobject*)self, left);
+    if (!new) {
+        return NULL;
+    }
+    if (PyDict_Update(new, right)) {
+        Py_DECREF(new);
+        return NULL;
+    }
+    return new;
+}
+
+static PyNumberMethods defdict_as_number = {
+    .nb_or = defdict_or,
+};
+
 static int
 defdict_traverse(PyObject *self, visitproc visit, void *arg)
 {
@@ -2198,7 +2237,7 @@ static PyTypeObject defdict_type = {
     0,                                  /* tp_setattr */
     0,                                  /* tp_as_async */
     (reprfunc)defdict_repr,             /* tp_repr */
-    0,                                  /* tp_as_number */
+    &defdict_as_number,                 /* tp_as_number */
     0,                                  /* tp_as_sequence */
     0,                                  /* tp_as_mapping */
     0,                                  /* tp_hash */