]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Generalize PySequence_Count() (operator.countOf) to work with iterators.
authorTim Peters <tim.peters@gmail.com>
Sat, 5 May 2001 11:33:43 +0000 (11:33 +0000)
committerTim Peters <tim.peters@gmail.com>
Sat, 5 May 2001 11:33:43 +0000 (11:33 +0000)
Lib/test/test_iter.py
Misc/NEWS
Objects/abstract.c

index bb9b102c360c7728bcadb8d4106ca8b4c98fe4f9..7d15e1cfb8f469a097a499d03c9e69cde32916ed 100644 (file)
@@ -527,4 +527,39 @@ class TestCase(unittest.TestCase):
             except OSError:
                 pass
 
+    # Test iterators with operator.countOf (PySequence_Count).
+    def test_countOf(self):
+        from operator import countOf
+        self.assertEqual(countOf([1,2,2,3,2,5], 2), 3)
+        self.assertEqual(countOf((1,2,2,3,2,5), 2), 3)
+        self.assertEqual(countOf("122325", "2"), 3)
+        self.assertEqual(countOf("122325", "6"), 0)
+
+        self.assertRaises(TypeError, countOf, 42, 1)
+        self.assertRaises(TypeError, countOf, countOf, countOf)
+
+        d = {"one": 3, "two": 3, "three": 3, 1j: 2j}
+        for k in d:
+            self.assertEqual(countOf(d, k), 1)
+        self.assertEqual(countOf(d.itervalues(), 3), 3)
+        self.assertEqual(countOf(d.itervalues(), 2j), 1)
+        self.assertEqual(countOf(d.itervalues(), 1j), 0)
+
+        f = open(TESTFN, "w")
+        try:
+            f.write("a\n" "b\n" "c\n" "b\n")
+        finally:
+            f.close()
+        f = open(TESTFN, "r")
+        try:
+            for letter, count in ("a", 1), ("b", 2), ("c", 1), ("d", 0):
+                f.seek(0, 0)
+                self.assertEqual(countOf(f, letter + "\n"), count)
+        finally:
+            f.close()
+            try:
+                unlink(TESTFN)
+            except OSError:
+                pass
+
 run_unittest(TestCase)
index 468eae69d75190d97b167379c7483af904802b05..aecc5e9114c4254996177265ba97de59cf0094b8 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -23,10 +23,12 @@ Core
     max()
     min()
     reduce()
+    tuple() (PySequence_Tuple() and PySequence_Fast() in C API)
     .join() method of strings
-    tuple()
+    'x in y' and 'x not in y' (PySequence_Contains() in C API)
+    operator.countOf() (PySequence_Count() in C API)
     XXX TODO zip()
-    'x in y' and 'x not in y'
+
 
 What's New in Python 2.1 (final)?
 =================================
index a0a40e89fc378295f8608afd9344868e1758986b..21c1ef1de46df1a171549edacdf7e1fb52996050 100644 (file)
@@ -1333,34 +1333,52 @@ PySequence_Fast(PyObject *v, const char *m)
        return v;
 }
 
+/* Return # of times o appears in s. */
 int
 PySequence_Count(PyObject *s, PyObject *o)
 {
-       int l, i, n, cmp, err;
-       PyObject *item;
+       int n;  /* running count of o hits */
+       PyObject *it;  /* iter(s) */
 
        if (s == NULL || o == NULL) {
                null_error();
                return -1;
        }
-       
-       l = PySequence_Size(s);
-       if (l < 0)
+
+       it = PyObject_GetIter(s);
+       if (it == NULL) {
+               type_error(".count() requires iterable argument");
                return -1;
+       }
 
        n = 0;
-       for (i = 0; i < l; i++) {
-               item = PySequence_GetItem(s, i);
-               if (item == NULL)
-                       return -1;
-               err = PyObject_Cmp(item, o, &cmp);
+       for (;;) {
+               int cmp;
+               PyObject *item = PyIter_Next(it);
+               if (item == NULL) {
+                       if (PyErr_Occurred())
+                               goto Fail;
+                       break;
+               }
+               cmp = PyObject_RichCompareBool(o, item, Py_EQ);
                Py_DECREF(item);
-               if (err < 0)
-                       return err;
-               if (cmp == 0)
+               if (cmp < 0)
+                       goto Fail;
+               if (cmp > 0) {
+                       if (n == INT_MAX) {
+                               PyErr_SetString(PyExc_OverflowError,
+                                               "count exceeds C int size");
+                               goto Fail;
+                       }
                        n++;
+               }
        }
+       Py_DECREF(it);
        return n;
+
+Fail:
+       Py_DECREF(it);
+       return -1;
 }
 
 /* Return -1 if error; 1 if v in w; 0 if v not in w. */