]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-113202: Add a strict option to itertools.batched() (gh-113203)
authorRaymond Hettinger <rhettinger@users.noreply.github.com>
Sat, 16 Dec 2023 15:13:50 +0000 (09:13 -0600)
committerGitHub <noreply@github.com>
Sat, 16 Dec 2023 15:13:50 +0000 (09:13 -0600)
Doc/library/itertools.rst
Lib/test/test_itertools.py
Misc/NEWS.d/next/Library/2023-12-15-18-10-26.gh-issue-113202.xv_Ww8.rst [new file with mode: 0644]
Modules/clinic/itertoolsmodule.c.h
Modules/itertoolsmodule.c

index 6bcda307f256f21e25526f1a4e2a8507926a7d13..c016fb76bfd0a072465a9675c8a9b8fe007bed7f 100644 (file)
@@ -164,11 +164,14 @@ loops that truncate the stream.
        Added the optional *initial* parameter.
 
 
-.. function:: batched(iterable, n)
+.. function:: batched(iterable, n, *, strict=False)
 
    Batch data from the *iterable* into tuples of length *n*. The last
    batch may be shorter than *n*.
 
+   If *strict* is true, will raise a :exc:`ValueError` if the final
+   batch is shorter than *n*.
+
    Loops over the input iterable and accumulates data into tuples up to
    size *n*.  The input is consumed lazily, just enough to fill a batch.
    The result is yielded as soon as the batch is full or when the input
@@ -190,16 +193,21 @@ loops that truncate the stream.
 
    Roughly equivalent to::
 
-      def batched(iterable, n):
+      def batched(iterable, n, *, strict=False):
           # batched('ABCDEFG', 3) --> ABC DEF G
           if n < 1:
               raise ValueError('n must be at least one')
           it = iter(iterable)
           while batch := tuple(islice(it, n)):
+              if strict and len(batch) != n:
+                  raise ValueError('batched(): incomplete batch')
               yield batch
 
    .. versionadded:: 3.12
 
+   .. versionchanged:: 3.13
+      Added the *strict* option.
+
 
 .. function:: chain(*iterables)
 
@@ -1039,7 +1047,7 @@ The following recipes have a more mathematical flavor:
    def reshape(matrix, cols):
        "Reshape a 2-D matrix to have a given number of columns."
        # reshape([(0, 1), (2, 3), (4, 5)], 3) -->  (0, 1, 2), (3, 4, 5)
-       return batched(chain.from_iterable(matrix), cols)
+       return batched(chain.from_iterable(matrix), cols, strict=True)
 
    def transpose(matrix):
        "Swap the rows and columns of a 2-D matrix."
@@ -1270,6 +1278,10 @@ The following recipes have a more mathematical flavor:
     [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10, 11)]
     >>> list(reshape(M, 4))
     [(0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 10, 11)]
+    >>> list(reshape(M, 5))
+    Traceback (most recent call last):
+    ...
+    ValueError: batched(): incomplete batch
     >>> list(reshape(M, 6))
     [(0, 1, 2, 3, 4, 5), (6, 7, 8, 9, 10, 11)]
     >>> list(reshape(M, 12))
index 705e880d98685e5b3523c1daabacecd4cf239e74..9af0730ea9800490e69bee36e8b145d5f4a71dbc 100644 (file)
@@ -187,7 +187,11 @@ class TestBasicOps(unittest.TestCase):
                              [('A', 'B'), ('C', 'D'), ('E', 'F'), ('G',)])
         self.assertEqual(list(batched('ABCDEFG', 1)),
                             [('A',), ('B',), ('C',), ('D',), ('E',), ('F',), ('G',)])
+        self.assertEqual(list(batched('ABCDEF', 2, strict=True)),
+                             [('A', 'B'), ('C', 'D'), ('E', 'F')])
 
+        with self.assertRaises(ValueError):         # Incomplete batch when strict
+            list(batched('ABCDEFG', 3, strict=True))
         with self.assertRaises(TypeError):          # Too few arguments
             list(batched('ABCDEFG'))
         with self.assertRaises(TypeError):
diff --git a/Misc/NEWS.d/next/Library/2023-12-15-18-10-26.gh-issue-113202.xv_Ww8.rst b/Misc/NEWS.d/next/Library/2023-12-15-18-10-26.gh-issue-113202.xv_Ww8.rst
new file mode 100644 (file)
index 0000000..44f26ae
--- /dev/null
@@ -0,0 +1 @@
+Add a ``strict`` option to ``batched()`` in the ``itertools`` module.
index fa2c5e0e9223870a81b119d46b845448cf33d724..3ec479943a83d4d9e86021509b4128e8e46385ef 100644 (file)
@@ -10,7 +10,7 @@ preserve
 #include "pycore_modsupport.h"    // _PyArg_UnpackKeywords()
 
 PyDoc_STRVAR(batched_new__doc__,
-"batched(iterable, n)\n"
+"batched(iterable, n, *, strict=False)\n"
 "--\n"
 "\n"
 "Batch data into tuples of length n. The last batch may be shorter than n.\n"
@@ -25,10 +25,14 @@ PyDoc_STRVAR(batched_new__doc__,
 "    ...\n"
 "    (\'A\', \'B\', \'C\')\n"
 "    (\'D\', \'E\', \'F\')\n"
-"    (\'G\',)");
+"    (\'G\',)\n"
+"\n"
+"If \"strict\" is True, raises a ValueError if the final batch is shorter\n"
+"than n.");
 
 static PyObject *
-batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n);
+batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n,
+                 int strict);
 
 static PyObject *
 batched_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
@@ -36,14 +40,14 @@ batched_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
     PyObject *return_value = NULL;
     #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE)
 
-    #define NUM_KEYWORDS 2
+    #define NUM_KEYWORDS 3
     static struct {
         PyGC_Head _this_is_not_used;
         PyObject_VAR_HEAD
         PyObject *ob_item[NUM_KEYWORDS];
     } _kwtuple = {
         .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS)
-        .ob_item = { &_Py_ID(iterable), &_Py_ID(n), },
+        .ob_item = { &_Py_ID(iterable), &_Py_ID(n), &_Py_ID(strict), },
     };
     #undef NUM_KEYWORDS
     #define KWTUPLE (&_kwtuple.ob_base.ob_base)
@@ -52,18 +56,20 @@ batched_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
     #  define KWTUPLE NULL
     #endif  // !Py_BUILD_CORE
 
-    static const char * const _keywords[] = {"iterable", "n", NULL};
+    static const char * const _keywords[] = {"iterable", "n", "strict", NULL};
     static _PyArg_Parser _parser = {
         .keywords = _keywords,
         .fname = "batched",
         .kwtuple = KWTUPLE,
     };
     #undef KWTUPLE
-    PyObject *argsbuf[2];
+    PyObject *argsbuf[3];
     PyObject * const *fastargs;
     Py_ssize_t nargs = PyTuple_GET_SIZE(args);
+    Py_ssize_t noptargs = nargs + (kwargs ? PyDict_GET_SIZE(kwargs) : 0) - 2;
     PyObject *iterable;
     Py_ssize_t n;
+    int strict = 0;
 
     fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, 2, 2, 0, argsbuf);
     if (!fastargs) {
@@ -82,7 +88,15 @@ batched_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
         }
         n = ival;
     }
-    return_value = batched_new_impl(type, iterable, n);
+    if (!noptargs) {
+        goto skip_optional_kwonly;
+    }
+    strict = PyObject_IsTrue(fastargs[2]);
+    if (strict < 0) {
+        goto exit;
+    }
+skip_optional_kwonly:
+    return_value = batched_new_impl(type, iterable, n, strict);
 
 exit:
     return return_value;
@@ -914,4 +928,4 @@ skip_optional_pos:
 exit:
     return return_value;
 }
-/*[clinic end generated code: output=782fe7e30733779b input=a9049054013a1b77]*/
+/*[clinic end generated code: output=c6a515f765da86b5 input=a9049054013a1b77]*/
index ab99fa4d873bf510eb377bf5615303ec16340c31..164741495c7baf92167cfc42e5e912e79316ce95 100644 (file)
@@ -105,20 +105,11 @@ class itertools.pairwise "pairwiseobject *" "clinic_state()->pairwise_type"
 
 /* batched object ************************************************************/
 
-/* Note:  The built-in zip() function includes a "strict" argument
-   that was needed because that function would silently truncate data,
-   and there was no easy way for a user to detect the data loss.
-   The same reasoning does not apply to batched() which never drops data.
-   Instead, batched() produces a shorter tuple which can be handled
-   as the user sees fit.  If requested, it would be reasonable to add
-   "fillvalue" support which had demonstrated value in zip_longest().
-   For now, the API is kept simple and clean.
- */
-
 typedef struct {
     PyObject_HEAD
     PyObject *it;
     Py_ssize_t batch_size;
+    bool strict;
 } batchedobject;
 
 /*[clinic input]
@@ -126,6 +117,9 @@ typedef struct {
 itertools.batched.__new__ as batched_new
     iterable: object
     n: Py_ssize_t
+    *
+    strict: bool = False
+
 Batch data into tuples of length n. The last batch may be shorter than n.
 
 Loops over the input iterable and accumulates data into tuples
@@ -140,11 +134,15 @@ or when the input iterable is exhausted.
     ('D', 'E', 'F')
     ('G',)
 
+If "strict" is True, raises a ValueError if the final batch is shorter
+than n.
+
 [clinic start generated code]*/
 
 static PyObject *
-batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n)
-/*[clinic end generated code: output=7ebc954d655371b6 input=ffd70726927c5129]*/
+batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n,
+                 int strict)
+/*[clinic end generated code: output=c6de11b061529d3e input=7814b47e222f5467]*/
 {
     PyObject *it;
     batchedobject *bo;
@@ -170,6 +168,7 @@ batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n)
     }
     bo->batch_size = n;
     bo->it = it;
+    bo->strict = (bool) strict;
     return (PyObject *)bo;
 }
 
@@ -233,6 +232,12 @@ batched_next(batchedobject *bo)
         Py_DECREF(result);
         return NULL;
     }
+    if (bo->strict) {
+        Py_CLEAR(bo->it);
+        Py_DECREF(result);
+        PyErr_SetString(PyExc_ValueError, "batched(): incomplete batch");
+        return NULL;
+    }
     _PyTuple_Resize(&result, i);
     return result;
 }