]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-119793: Add optional length-checking to `map()` (GH-120471)
authorNice Zombies <nineteendo19d0@gmail.com>
Mon, 4 Nov 2024 14:00:19 +0000 (15:00 +0100)
committerGitHub <noreply@github.com>
Mon, 4 Nov 2024 14:00:19 +0000 (15:00 +0100)
Co-authored-by: Bénédikt Tran <10796600+picnixz@users.noreply.github.com>
Co-authored-by: Pieter Eendebak <pieter.eendebak@gmail.com>
Co-authored-by: Erlend E. Aasland <erlend.aasland@protonmail.com>
Co-authored-by: Raymond Hettinger <rhettinger@users.noreply.github.com>
Doc/library/functions.rst
Doc/whatsnew/3.14.rst
Lib/test/test_builtin.py
Lib/test/test_itertools.py
Misc/NEWS.d/next/Core_and_Builtins/2024-06-13-19-12-49.gh-issue-119793.FDVCDk.rst [new file with mode: 0644]
Python/bltinmodule.c

index 03fc41fa793977785b6ab40a98a7f8ce4853d501..a7549b9bce76e2026e9d1c9cf9cb6d9b3b2b5af2 100644 (file)
@@ -1205,14 +1205,19 @@ are always available.  They are listed here in alphabetical order.
       unchanged from previous versions.
 
 
-.. function:: map(function, iterable, *iterables)
+.. function:: map(function, iterable, /, *iterables, strict=False)
 
    Return an iterator that applies *function* to every item of *iterable*,
    yielding the results.  If additional *iterables* arguments are passed,
    *function* must take that many arguments and is applied to the items from all
    iterables in parallel.  With multiple iterables, the iterator stops when the
-   shortest iterable is exhausted.  For cases where the function inputs are
-   already arranged into argument tuples, see :func:`itertools.starmap`\.
+   shortest iterable is exhausted.  If *strict* is ``True`` and one of the
+   iterables is exhausted before the others, a :exc:`ValueError` is raised. For
+   cases where the function inputs are already arranged into argument tuples,
+   see :func:`itertools.starmap`.
+
+   .. versionchanged:: 3.14
+      Added the *strict* parameter.
 
 
 .. function:: max(iterable, *, key=None)
index deee683d7b87b1cf942cef1cb0d6f578b15eda1b..80c1a93b95a6afa782b87221da7649348bf42627 100644 (file)
@@ -175,6 +175,10 @@ Improved error messages
 Other language changes
 ======================
 
+* The :func:`map` built-in now has an optional keyword-only *strict* flag
+  like :func:`zip` to check that all the iterables are of equal length.
+  (Contributed by Wannes Boeykens in :gh:`119793`.)
+
 * Incorrect usage of :keyword:`await` and asynchronous comprehensions
   is now detected even if the code is optimized away by the :option:`-O`
   command-line option. For example, ``python -O -c 'assert await 1'``
index eb5906f8944c8efea7f2ddd48e0da44ffbf25624..f8e6f05cd607c847e5c95ff78b5a67572ad7c731 100644 (file)
@@ -148,6 +148,9 @@ def filter_char(arg):
 def map_char(arg):
     return chr(ord(arg)+1)
 
+def pack(*args):
+    return args
+
 class BuiltinTest(unittest.TestCase):
     # Helper to check picklability
     def check_iter_pickle(self, it, seq, proto):
@@ -1269,6 +1272,108 @@ class BuiltinTest(unittest.TestCase):
             m2 = map(map_char, "Is this the real life?")
             self.check_iter_pickle(m1, list(m2), proto)
 
+    # strict map tests based on strict zip tests
+
+    def test_map_pickle_strict(self):
+        a = (1, 2, 3)
+        b = (4, 5, 6)
+        t = [(1, 4), (2, 5), (3, 6)]
+        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+            m1 = map(pack, a, b, strict=True)
+            self.check_iter_pickle(m1, t, proto)
+
+    def test_map_pickle_strict_fail(self):
+        a = (1, 2, 3)
+        b = (4, 5, 6, 7)
+        t = [(1, 4), (2, 5), (3, 6)]
+        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+            m1 = map(pack, a, b, strict=True)
+            m2 = pickle.loads(pickle.dumps(m1, proto))
+            self.assertEqual(self.iter_error(m1, ValueError), t)
+            self.assertEqual(self.iter_error(m2, ValueError), t)
+
+    def test_map_strict(self):
+        self.assertEqual(tuple(map(pack, (1, 2, 3), 'abc', strict=True)),
+                         ((1, 'a'), (2, 'b'), (3, 'c')))
+        self.assertRaises(ValueError, tuple,
+                          map(pack, (1, 2, 3, 4), 'abc', strict=True))
+        self.assertRaises(ValueError, tuple,
+                          map(pack, (1, 2), 'abc', strict=True))
+        self.assertRaises(ValueError, tuple,
+                          map(pack, (1, 2), (1, 2), 'abc', strict=True))
+
+    def test_map_strict_iterators(self):
+        x = iter(range(5))
+        y = [0]
+        z = iter(range(5))
+        self.assertRaises(ValueError, list,
+                          (map(pack, x, y, z, strict=True)))
+        self.assertEqual(next(x), 2)
+        self.assertEqual(next(z), 1)
+
+    def test_map_strict_error_handling(self):
+
+        class Error(Exception):
+            pass
+
+        class Iter:
+            def __init__(self, size):
+                self.size = size
+            def __iter__(self):
+                return self
+            def __next__(self):
+                self.size -= 1
+                if self.size < 0:
+                    raise Error
+                return self.size
+
+        l1 = self.iter_error(map(pack, "AB", Iter(1), strict=True), Error)
+        self.assertEqual(l1, [("A", 0)])
+        l2 = self.iter_error(map(pack, "AB", Iter(2), "A", strict=True), ValueError)
+        self.assertEqual(l2, [("A", 1, "A")])
+        l3 = self.iter_error(map(pack, "AB", Iter(2), "ABC", strict=True), Error)
+        self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")])
+        l4 = self.iter_error(map(pack, "AB", Iter(3), strict=True), ValueError)
+        self.assertEqual(l4, [("A", 2), ("B", 1)])
+        l5 = self.iter_error(map(pack, Iter(1), "AB", strict=True), Error)
+        self.assertEqual(l5, [(0, "A")])
+        l6 = self.iter_error(map(pack, Iter(2), "A", strict=True), ValueError)
+        self.assertEqual(l6, [(1, "A")])
+        l7 = self.iter_error(map(pack, Iter(2), "ABC", strict=True), Error)
+        self.assertEqual(l7, [(1, "A"), (0, "B")])
+        l8 = self.iter_error(map(pack, Iter(3), "AB", strict=True), ValueError)
+        self.assertEqual(l8, [(2, "A"), (1, "B")])
+
+    def test_map_strict_error_handling_stopiteration(self):
+
+        class Iter:
+            def __init__(self, size):
+                self.size = size
+            def __iter__(self):
+                return self
+            def __next__(self):
+                self.size -= 1
+                if self.size < 0:
+                    raise StopIteration
+                return self.size
+
+        l1 = self.iter_error(map(pack, "AB", Iter(1), strict=True), ValueError)
+        self.assertEqual(l1, [("A", 0)])
+        l2 = self.iter_error(map(pack, "AB", Iter(2), "A", strict=True), ValueError)
+        self.assertEqual(l2, [("A", 1, "A")])
+        l3 = self.iter_error(map(pack, "AB", Iter(2), "ABC", strict=True), ValueError)
+        self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")])
+        l4 = self.iter_error(map(pack, "AB", Iter(3), strict=True), ValueError)
+        self.assertEqual(l4, [("A", 2), ("B", 1)])
+        l5 = self.iter_error(map(pack, Iter(1), "AB", strict=True), ValueError)
+        self.assertEqual(l5, [(0, "A")])
+        l6 = self.iter_error(map(pack, Iter(2), "A", strict=True), ValueError)
+        self.assertEqual(l6, [(1, "A")])
+        l7 = self.iter_error(map(pack, Iter(2), "ABC", strict=True), ValueError)
+        self.assertEqual(l7, [(1, "A"), (0, "B")])
+        l8 = self.iter_error(map(pack, Iter(3), "AB", strict=True), ValueError)
+        self.assertEqual(l8, [(2, "A"), (1, "B")])
+
     def test_max(self):
         self.assertEqual(max('123123'), '3')
         self.assertEqual(max(1, 2, 3), 3)
index 8469de998ba014efc81323183a8a71ebbec9fd66..a52e1d3fa142d98fc8bbaee27c4f279a5541a224 100644 (file)
@@ -2433,10 +2433,10 @@ class SubclassWithKwargsTest(unittest.TestCase):
                     subclass(*args, newarg=3)
 
         for cls, args, result in testcases:
-            # Constructors of repeat, zip, compress accept keyword arguments.
+            # Constructors of repeat, zip, map, compress accept keyword arguments.
             # Their subclasses need overriding __new__ to support new
             # keyword arguments.
-            if cls in [repeat, zip, compress]:
+            if cls in [repeat, zip, map, compress]:
                 continue
             with self.subTest(cls):
                 class subclass_with_init(cls):
diff --git a/Misc/NEWS.d/next/Core_and_Builtins/2024-06-13-19-12-49.gh-issue-119793.FDVCDk.rst b/Misc/NEWS.d/next/Core_and_Builtins/2024-06-13-19-12-49.gh-issue-119793.FDVCDk.rst
new file mode 100644 (file)
index 0000000..976d671
--- /dev/null
@@ -0,0 +1,3 @@
+The :func:`map` built-in now has an optional keyword-only *strict* flag
+like :func:`zip` to check that all the iterables are of equal length.
+Patch by Wannes Boeykens.
index 12f065d4b4f13871551dc14e7dfee689181c47d3..85a28de2bb9d13ae41228c18b8092919d4060c49 100644 (file)
@@ -1311,6 +1311,7 @@ typedef struct {
     PyObject_HEAD
     PyObject *iters;
     PyObject *func;
+    int strict;
 } mapobject;
 
 static PyObject *
@@ -1319,10 +1320,21 @@ map_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
     PyObject *it, *iters, *func;
     mapobject *lz;
     Py_ssize_t numargs, i;
+    int strict = 0;
 
-    if ((type == &PyMap_Type || type->tp_init == PyMap_Type.tp_init) &&
-        !_PyArg_NoKeywords("map", kwds))
-        return NULL;
+    if (kwds) {
+        PyObject *empty = PyTuple_New(0);
+        if (empty == NULL) {
+            return NULL;
+        }
+        static char *kwlist[] = {"strict", NULL};
+        int parsed = PyArg_ParseTupleAndKeywords(
+                empty, kwds, "|$p:map", kwlist, &strict);
+        Py_DECREF(empty);
+        if (!parsed) {
+            return NULL;
+        }
+    }
 
     numargs = PyTuple_Size(args);
     if (numargs < 2) {
@@ -1354,6 +1366,7 @@ map_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
     lz->iters = iters;
     func = PyTuple_GET_ITEM(args, 0);
     lz->func = Py_NewRef(func);
+    lz->strict = strict;
 
     return (PyObject *)lz;
 }
@@ -1363,11 +1376,14 @@ map_vectorcall(PyObject *type, PyObject * const*args,
                 size_t nargsf, PyObject *kwnames)
 {
     PyTypeObject *tp = _PyType_CAST(type);
-    if (tp == &PyMap_Type && !_PyArg_NoKwnames("map", kwnames)) {
-        return NULL;
-    }
 
     Py_ssize_t nargs = PyVectorcall_NARGS(nargsf);
+    if (kwnames != NULL && PyTuple_GET_SIZE(kwnames) != 0) {
+        // Fallback to map_new()
+        PyThreadState *tstate = _PyThreadState_GET();
+        return _PyObject_MakeTpCall(tstate, type, args, nargs, kwnames);
+    }
+
     if (nargs < 2) {
         PyErr_SetString(PyExc_TypeError,
            "map() must have at least two arguments.");
@@ -1395,6 +1411,7 @@ map_vectorcall(PyObject *type, PyObject * const*args,
     }
     lz->iters = iters;
     lz->func = Py_NewRef(args[0]);
+    lz->strict = 0;
 
     return (PyObject *)lz;
 }
@@ -1419,6 +1436,7 @@ map_traverse(mapobject *lz, visitproc visit, void *arg)
 static PyObject *
 map_next(mapobject *lz)
 {
+    Py_ssize_t i;
     PyObject *small_stack[_PY_FASTCALL_SMALL_STACK];
     PyObject **stack;
     PyObject *result = NULL;
@@ -1437,10 +1455,13 @@ map_next(mapobject *lz)
     }
 
     Py_ssize_t nargs = 0;
-    for (Py_ssize_t i=0; i < niters; i++) {
+    for (i=0; i < niters; i++) {
         PyObject *it = PyTuple_GET_ITEM(lz->iters, i);
         PyObject *val = Py_TYPE(it)->tp_iternext(it);
         if (val == NULL) {
+            if (lz->strict) {
+                goto check;
+            }
             goto exit;
         }
         stack[i] = val;
@@ -1450,13 +1471,50 @@ map_next(mapobject *lz)
     result = _PyObject_VectorcallTstate(tstate, lz->func, stack, nargs, NULL);
 
 exit:
-    for (Py_ssize_t i=0; i < nargs; i++) {
+    for (i=0; i < nargs; i++) {
         Py_DECREF(stack[i]);
     }
     if (stack != small_stack) {
         PyMem_Free(stack);
     }
     return result;
+check:
+    if (PyErr_Occurred()) {
+        if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
+            // next() on argument i raised an exception (not StopIteration)
+            return NULL;
+        }
+        PyErr_Clear();
+    }
+    if (i) {
+        // ValueError: map() argument 2 is shorter than argument 1
+        // ValueError: map() argument 3 is shorter than arguments 1-2
+        const char* plural = i == 1 ? " " : "s 1-";
+        return PyErr_Format(PyExc_ValueError,
+                            "map() argument %d is shorter than argument%s%d",
+                            i + 1, plural, i);
+    }
+    for (i = 1; i < niters; i++) {
+        PyObject *it = PyTuple_GET_ITEM(lz->iters, i);
+        PyObject *val = (*Py_TYPE(it)->tp_iternext)(it);
+        if (val) {
+            Py_DECREF(val);
+            const char* plural = i == 1 ? " " : "s 1-";
+            return PyErr_Format(PyExc_ValueError,
+                                "map() argument %d is longer than argument%s%d",
+                                i + 1, plural, i);
+        }
+        if (PyErr_Occurred()) {
+            if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
+                // next() on argument i raised an exception (not StopIteration)
+                return NULL;
+            }
+            PyErr_Clear();
+        }
+        // Argument i is exhausted. So far so good...
+    }
+    // All arguments are exhausted. Success!
+    goto exit;
 }
 
 static PyObject *
@@ -1473,21 +1531,41 @@ map_reduce(mapobject *lz, PyObject *Py_UNUSED(ignored))
         PyTuple_SET_ITEM(args, i+1, Py_NewRef(it));
     }
 
+    if (lz->strict) {
+        return Py_BuildValue("ONO", Py_TYPE(lz), args, Py_True);
+    }
     return Py_BuildValue("ON", Py_TYPE(lz), args);
 }
 
+PyDoc_STRVAR(setstate_doc, "Set state information for unpickling.");
+
+static PyObject *
+map_setstate(mapobject *lz, PyObject *state)
+{
+    int strict = PyObject_IsTrue(state);
+    if (strict < 0) {
+        return NULL;
+    }
+    lz->strict = strict;
+    Py_RETURN_NONE;
+}
+
 static PyMethodDef map_methods[] = {
     {"__reduce__", _PyCFunction_CAST(map_reduce), METH_NOARGS, reduce_doc},
+    {"__setstate__", _PyCFunction_CAST(map_setstate), METH_O, setstate_doc},
     {NULL,           NULL}           /* sentinel */
 };
 
 
 PyDoc_STRVAR(map_doc,
-"map(function, iterable, /, *iterables)\n\
+"map(function, iterable, /, *iterables, strict=False)\n\
 --\n\
 \n\
 Make an iterator that computes the function using arguments from\n\
-each of the iterables.  Stops when the shortest iterable is exhausted.");
+each of the iterables.  Stops when the shortest iterable is exhausted.\n\
+\n\
+If strict is true and one of the arguments is exhausted before the others,\n\
+raise a ValueError.");
 
 PyTypeObject PyMap_Type = {
     PyVarObject_HEAD_INIT(&PyType_Type, 0)
@@ -3068,8 +3146,6 @@ zip_reduce(zipobject *lz, PyObject *Py_UNUSED(ignored))
     return PyTuple_Pack(2, Py_TYPE(lz), lz->ittuple);
 }
 
-PyDoc_STRVAR(setstate_doc, "Set state information for unpickling.");
-
 static PyObject *
 zip_setstate(zipobject *lz, PyObject *state)
 {