git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Generalize filter(f, seq) to work with iterators. This also generalizes
author: Tim Peters <tim.peters@gmail.com>
Wed, 2 May 2001 07:39:38 +0000 (07:39 +0000)
committer: Tim Peters <tim.peters@gmail.com>
Wed, 2 May 2001 07:39:38 +0000 (07:39 +0000)
filter() to no longer insist that len(seq) be defined.
NEEDS DOC CHANGES.

Lib/test/test_iter.py
Misc/NEWS
Python/bltinmodule.c

index 5b9bf652c441bef35d2fd7644659589a208719ed..952ab663522c1d27a8be4c2a210ef5c553189a93 100644 (file)
@@ -275,4 +275,48 @@ class TestCase(unittest.TestCase):
             except OSError:
                 pass
 
+    # Test filter()'s use of iterators.
+    def test_builtin_filter(self):
+        self.assertEqual(filter(None, SequenceClass(5)), range(1, 5))
+        self.assertEqual(filter(None, SequenceClass(0)), [])
+        self.assertEqual(filter(None, ()), ())
+        self.assertEqual(filter(None, "abc"), "abc")
+
+        d = {"one": 1, "two": 2, "three": 3}
+        self.assertEqual(filter(None, d), d.keys())
+
+        self.assertRaises(TypeError, filter, None, list)
+        self.assertRaises(TypeError, filter, None, 42)
+
+        class Boolean:
+            def __init__(self, truth):
+                self.truth = truth
+            def __nonzero__(self):
+                return self.truth
+        True = Boolean(1)
+        False = Boolean(0)
+
+        class Seq:
+            def __init__(self, *args):
+                self.vals = args
+            def __iter__(self):
+                class SeqIter:
+                    def __init__(self, vals):
+                        self.vals = vals
+                        self.i = 0
+                    def __iter__(self):
+                        return self
+                    def next(self):
+                        i = self.i
+                        self.i = i + 1
+                        if i < len(self.vals):
+                            return self.vals[i]
+                        else:
+                            raise StopIteration
+                return SeqIter(self.vals)
+
+        seq = Seq(*([True, False] * 25))
+        self.assertEqual(filter(lambda x: not x, seq), [False]*25)
+        self.assertEqual(filter(lambda x: not x, iter(seq)), [False]*25)
+
 run_unittest(TestCase)
index f121bac9c0e260a33698672bbc4a9a65402608aa..bbd2ac3e01494ce5cfb9ae6d94bc44d8ab2a4718 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -17,7 +17,8 @@ Core
 
 - The following functions were generalized to work nicely with iterator
   arguments:
-  list()
+    filter()
+    list()
 
 
 What's New in Python 2.1 (final)?
@@ -236,7 +237,7 @@ Tools
 - IDLE: syntax warnings in interactive mode are changed into errors.
 
 - Some improvements to Tools/webchecker (ignore some more URL types,
-  follow some more links). 
+  follow some more links).
 
 - Brought the Tools/compiler package up to date.
 
@@ -324,23 +325,23 @@ Python/C API
   in Flags and take an extra argument, a PyCompilerFlags *; examples:
   PyRun_AnyFileExFlags(), PyRun_InteractiveLoopFlags().  These
   variants may be removed in Python 2.2, when nested scopes are
-  mandatory. 
+  mandatory.
 
 Distutils
 
 - the sdist command now writes a PKG-INFO file, as described in PEP 241,
   into the release tree.
 
-- several enhancements to the bdist_wininst command from Thomas Heller 
+- several enhancements to the bdist_wininst command from Thomas Heller
   (an uninstaller, more customization of the installer's display)
 
 - from Jack Jansen: added Mac-specific code to generate a dialog for
   users to specify the command-line (because providing a command-line with
-  MacPython is awkward).  Jack also made various fixes for the Mac 
+  MacPython is awkward).  Jack also made various fixes for the Mac
   and the Metrowerks compiler.
-  
-- added 'platforms' and 'keywords' to the set of metadata that can be 
-  specified for a distribution.  
+
+- added 'platforms' and 'keywords' to the set of metadata that can be
+  specified for a distribution.
 
 - applied patches from Jason Tishler to make the compiler class work with
   Cygwin.
index 8572cc8183ed226fb06dd79e396a931b372cacbd..1051374fcf2f93fd1df599a128b83c98cc3fa36f 100644 (file)
@@ -162,53 +162,65 @@ Note that classes are callable, as are instances with a __call__() method.";
 static PyObject *
 builtin_filter(PyObject *self, PyObject *args)
 {
-       PyObject *func, *seq, *result;
-       PySequenceMethods *sqf;
-       int len;
+       PyObject *func, *seq, *result, *it;
+       int len;   /* guess for result list size */
        register int i, j;
 
        if (!PyArg_ParseTuple(args, "OO:filter", &func, &seq))
                return NULL;
 
-       if (PyString_Check(seq)) {
-               PyObject *r = filterstring(func, seq);
-               return r;
-       }
+       /* Strings and tuples return a result of the same type. */
+       if (PyString_Check(seq))
+               return filterstring(func, seq);
+       if (PyTuple_Check(seq))
+               return filtertuple(func, seq);
 
-       if (PyTuple_Check(seq)) {
-               PyObject *r = filtertuple(func, seq);
-               return r;
-       }
+       /* Get iterator. */
+       it = PyObject_GetIter(seq);
+       if (it == NULL)
+               return NULL;
 
-       sqf = seq->ob_type->tp_as_sequence;
-       if (sqf == NULL || sqf->sq_length == NULL || sqf->sq_item == NULL) {
-               PyErr_SetString(PyExc_TypeError,
-                          "filter() arg 2 must be a sequence");
-               goto Fail_2;
+       /* Guess a result list size. */
+       len = -1;   /* unknown */
+       if (PySequence_Check(seq) &&
+           seq->ob_type->tp_as_sequence->sq_length) {
+               len = PySequence_Size(seq);
+               if (len < 0)
+                       PyErr_Clear();
        }
+       if (len < 0)
+               len = 8;  /* arbitrary */
 
-       if ((len = (*sqf->sq_length)(seq)) < 0)
-               goto Fail_2;
-
+       /* Get a result list. */
        if (PyList_Check(seq) && seq->ob_refcnt == 1) {
+               /* Eww - can modify the list in-place. */
                Py_INCREF(seq);
                result = seq;
        }
        else {
-               if ((result = PyList_New(len)) == NULL)
-                       goto Fail_2;
+               result = PyList_New(len);
+               if (result == NULL)
+                       goto Fail_it;
        }
 
+       /* Build the result list. */
        for (i = j = 0; ; ++i) {
                PyObject *item, *good;
                int ok;
 
-               if ((item = (*sqf->sq_item)(seq, i)) == NULL) {
-                       if (PyErr_ExceptionMatches(PyExc_IndexError)) {
-                               PyErr_Clear();
-                               break;
+               item = PyIter_Next(it);
+               if (item == NULL) {
+                       /* We're out of here in any case, but if this is a
+                        * StopIteration exception it's expected, but if
+                        * any other kind of exception it's an error.
+                        */
+                       if (PyErr_Occurred()) {
+                               if (PyErr_ExceptionMatches(PyExc_StopIteration))
+                                       PyErr_Clear();
+                               else
+                                       goto Fail_result_it;
                        }
-                       goto Fail_1;
+                       break;
                }
 
                if (func == Py_None) {
@@ -217,43 +229,45 @@ builtin_filter(PyObject *self, PyObject *args)
                }
                else {
                        PyObject *arg = Py_BuildValue("(O)", item);
-                       if (arg == NULL)
-                               goto Fail_1;
+                       if (arg == NULL) {
+                               Py_DECREF(item);
+                               goto Fail_result_it;
+                       }
                        good = PyEval_CallObject(func, arg);
                        Py_DECREF(arg);
                        if (good == NULL) {
                                Py_DECREF(item);
-                               goto Fail_1;
+                               goto Fail_result_it;
                        }
                }
                ok = PyObject_IsTrue(good);
                Py_DECREF(good);
                if (ok) {
-                       if (j < len) {
-                               if (PyList_SetItem(result, j++, item) < 0)
-                                       goto Fail_1;
-                       }
+                       if (j < len)
+                               PyList_SET_ITEM(result, j, item);
                        else {
                                int status = PyList_Append(result, item);
-                               j++;
                                Py_DECREF(item);
                                if (status < 0)
-                                       goto Fail_1;
+                                       goto Fail_result_it;
                        }
-               } else {
-                       Py_DECREF(item);
+                       ++j;
                }
+               else
+                       Py_DECREF(item);
        }
 
 
+       /* Cut back result list if len is too big. */
        if (j < len && PyList_SetSlice(result, j, len, NULL) < 0)
-               goto Fail_1;
+               goto Fail_result_it;
 
        return result;
 
-Fail_1:
+Fail_result_it:
        Py_DECREF(result);
-Fail_2:
+Fail_it:
+       Py_DECREF(it);
        return NULL;
 }