]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Issue #18594: Make the C code more closely match the pure python code.
authorRaymond Hettinger <python@rcn.com>
Fri, 4 Oct 2013 23:51:02 +0000 (16:51 -0700)
committerRaymond Hettinger <python@rcn.com>
Fri, 4 Oct 2013 23:51:02 +0000 (16:51 -0700)
Lib/test/test_collections.py
Modules/_collectionsmodule.c

index af27d22b4ed4b264dac7028aa45a7af8c190a32b..ff52755354b187f0c0218e581e3808e59d988447 100644 (file)
@@ -818,6 +818,24 @@ class TestCollectionABCs(ABCTestCase):
 ### Counter
 ################################################################################
 
+class CounterSubclassWithSetItem(Counter):
+    # Test a counter subclass that overrides __setitem__
+    def __init__(self, *args, **kwds):
+        self.called = False
+        Counter.__init__(self, *args, **kwds)
+    def __setitem__(self, key, value):
+        self.called = True
+        Counter.__setitem__(self, key, value)
+
+class CounterSubclassWithGet(Counter):
+    # Test a counter subclass that overrides get()
+    def __init__(self, *args, **kwds):
+        self.called = False
+        Counter.__init__(self, *args, **kwds)
+    def get(self, key, default):
+        self.called = True
+        return Counter.get(self, key, default)
+
 class TestCounter(unittest.TestCase):
 
     def test_basics(self):
@@ -1022,6 +1040,12 @@ class TestCounter(unittest.TestCase):
         self.assertEqual(m,
              OrderedDict([('a', 5), ('b', 2), ('r', 2), ('c', 1), ('d', 1)]))
 
+        # test fidelity to the pure python version
+        c = CounterSubclassWithSetItem('abracadabra')
+        self.assertTrue(c.called)
+        c = CounterSubclassWithGet('abracadabra')
+        self.assertTrue(c.called)
+
 
 ################################################################################
 ### OrderedDict
index b244667474730ea6185a6aedb31653b9c0209df7..c6c79836002237c0b86ce38fefc1a1b45e9a0b70 100644 (file)
@@ -1689,17 +1689,17 @@ Count elements in the iterable, updating the mappping");
 static PyObject *
 _count_elements(PyObject *self, PyObject *args)
 {
-    _Py_IDENTIFIER(__getitem__);
+    _Py_IDENTIFIER(get);
     _Py_IDENTIFIER(__setitem__);
     PyObject *it, *iterable, *mapping, *oldval;
     PyObject *newval = NULL;
     PyObject *key = NULL;
     PyObject *zero = NULL;
     PyObject *one = NULL;
-    PyObject *mapping_get = NULL;
-    PyObject *mapping_getitem;
+    PyObject *bound_get = NULL;
+    PyObject *mapping_get;
+    PyObject *dict_get;
     PyObject *mapping_setitem;
-    PyObject *dict_getitem;
     PyObject *dict_setitem;
 
     if (!PyArg_UnpackTuple(args, "_count_elements", 2, 2, &mapping, &iterable))
@@ -1713,15 +1713,16 @@ _count_elements(PyObject *self, PyObject *args)
     if (one == NULL)
         goto done;
 
-    mapping_getitem = _PyType_LookupId(Py_TYPE(mapping), &PyId___getitem__);
-    dict_getitem = _PyType_LookupId(&PyDict_Type, &PyId___getitem__);
+    /* Only take the fast path when get() and __setitem__()
+     * have not been overridden.
+     */
+    mapping_get = _PyType_LookupId(Py_TYPE(mapping), &PyId_get);
+    dict_get = _PyType_LookupId(&PyDict_Type, &PyId_get);
     mapping_setitem = _PyType_LookupId(Py_TYPE(mapping), &PyId___setitem__);
     dict_setitem = _PyType_LookupId(&PyDict_Type, &PyId___setitem__);
 
-    if (mapping_getitem != NULL &&
-        mapping_getitem == dict_getitem &&
-        mapping_setitem != NULL &&
-        mapping_setitem == dict_setitem) {
+    if (mapping_get != NULL && mapping_get == dict_get &&
+        mapping_setitem != NULL && mapping_setitem == dict_setitem) {
         while (1) {
             key = PyIter_Next(it);
             if (key == NULL)
@@ -1741,8 +1742,8 @@ _count_elements(PyObject *self, PyObject *args)
             Py_DECREF(key);
         }
     } else {
-        mapping_get = PyObject_GetAttrString(mapping, "get");
-        if (mapping_get == NULL)
+        bound_get = PyObject_GetAttrString(mapping, "get");
+        if (bound_get == NULL)
             goto done;
 
         zero = PyLong_FromLong(0);
@@ -1753,7 +1754,7 @@ _count_elements(PyObject *self, PyObject *args)
             key = PyIter_Next(it);
             if (key == NULL)
                 break;
-            oldval = PyObject_CallFunctionObjArgs(mapping_get, key, zero, NULL);
+            oldval = PyObject_CallFunctionObjArgs(bound_get, key, zero, NULL);
             if (oldval == NULL)
                 break;
             newval = PyNumber_Add(oldval, one);
@@ -1771,7 +1772,7 @@ done:
     Py_DECREF(it);
     Py_XDECREF(key);
     Py_XDECREF(newval);
-    Py_XDECREF(mapping_get);
+    Py_XDECREF(bound_get);
     Py_XDECREF(zero);
     Py_XDECREF(one);
     if (PyErr_Occurred())