]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Issue #22955: attrgetter, itemgetter and methodcaller objects in the operator
authorSerhiy Storchaka <storchaka@gmail.com>
Wed, 20 May 2015 15:29:18 +0000 (18:29 +0300)
committerSerhiy Storchaka <storchaka@gmail.com>
Wed, 20 May 2015 15:29:18 +0000 (18:29 +0300)
module now support pickling.  Added readable and evaluable repr for these
objects.  Based on patch by Josh Rosenberg.

Lib/operator.py
Lib/test/test_operator.py
Misc/NEWS
Modules/_operator.c

index 856036ddf12c35b38991beee477c50fd805fc1c4..0db51c155648c2b2fc1841c46eb4016717ea5c9e 100644 (file)
@@ -231,10 +231,13 @@ class attrgetter:
     After h = attrgetter('name.first', 'name.last'), the call h(r) returns
     (r.name.first, r.name.last).
     """
+    __slots__ = ('_attrs', '_call')
+
     def __init__(self, attr, *attrs):
         if not attrs:
             if not isinstance(attr, str):
                 raise TypeError('attribute name must be a string')
+            self._attrs = (attr,)
             names = attr.split('.')
             def func(obj):
                 for name in names:
@@ -242,7 +245,8 @@ class attrgetter:
                 return obj
             self._call = func
         else:
-            getters = tuple(map(attrgetter, (attr,) + attrs))
+            self._attrs = (attr,) + attrs
+            getters = tuple(map(attrgetter, self._attrs))
             def func(obj):
                 return tuple(getter(obj) for getter in getters)
             self._call = func
@@ -250,19 +254,30 @@ class attrgetter:
     def __call__(self, obj):
         return self._call(obj)
 
+    def __repr__(self):
+        return '%s.%s(%s)' % (self.__class__.__module__,
+                              self.__class__.__qualname__,
+                              ', '.join(map(repr, self._attrs)))
+
+    def __reduce__(self):
+        return self.__class__, self._attrs
+
 class itemgetter:
     """
     Return a callable object that fetches the given item(s) from its operand.
     After f = itemgetter(2), the call f(r) returns r[2].
     After g = itemgetter(2, 5, 3), the call g(r) returns (r[2], r[5], r[3])
     """
+    __slots__ = ('_items', '_call')
+
     def __init__(self, item, *items):
         if not items:
+            self._items = (item,)
             def func(obj):
                 return obj[item]
             self._call = func
         else:
-            items = (item,) + items
+            self._items = items = (item,) + items
             def func(obj):
                 return tuple(obj[i] for i in items)
             self._call = func
@@ -270,6 +285,14 @@ class itemgetter:
     def __call__(self, obj):
         return self._call(obj)
 
+    def __repr__(self):
+        return '%s.%s(%s)' % (self.__class__.__module__,
+                              self.__class__.__name__,
+                              ', '.join(map(repr, self._items)))
+
+    def __reduce__(self):
+        return self.__class__, self._items
+
 class methodcaller:
     """
     Return a callable object that calls the given method on its operand.
@@ -277,6 +300,7 @@ class methodcaller:
     After g = methodcaller('name', 'date', foo=1), the call g(r) returns
     r.name('date', foo=1).
     """
+    __slots__ = ('_name', '_args', '_kwargs')
 
     def __init__(*args, **kwargs):
         if len(args) < 2:
@@ -284,12 +308,30 @@ class methodcaller:
             raise TypeError(msg)
         self = args[0]
         self._name = args[1]
+        if not isinstance(self._name, str):
+            raise TypeError('method name must be a string')
         self._args = args[2:]
         self._kwargs = kwargs
 
     def __call__(self, obj):
         return getattr(obj, self._name)(*self._args, **self._kwargs)
 
+    def __repr__(self):
+        args = [repr(self._name)]
+        args.extend(map(repr, self._args))
+        args.extend('%s=%r' % (k, v) for k, v in self._kwargs.items())
+        return '%s.%s(%s)' % (self.__class__.__module__,
+                              self.__class__.__name__,
+                              ', '.join(args))
+
+    def __reduce__(self):
+        if not self._kwargs:
+            return self.__class__, (self._name,) + self._args
+        else:
+            from functools import partial
+            return partial(self.__class__, self._name, **self._kwargs), self._args
+
+
 # In-place Operations *********************************************************#
 
 def iadd(a, b):
index 1bd0391ee2976afffc96ce4c6e19624d15aac858..ef9cf3e5cd98514f266fb0882e93bb44e65dbaae 100644 (file)
@@ -1,4 +1,6 @@
 import unittest
+import pickle
+import sys
 
 from test import support
 
@@ -35,6 +37,9 @@ class Seq2(object):
 
 
 class OperatorTestCase:
+    def setUp(self):
+        sys.modules['operator'] = self.module
+
     def test_lt(self):
         operator = self.module
         self.assertRaises(TypeError, operator.lt)
@@ -396,6 +401,7 @@ class OperatorTestCase:
     def test_methodcaller(self):
         operator = self.module
         self.assertRaises(TypeError, operator.methodcaller)
+        self.assertRaises(TypeError, operator.methodcaller, 12)
         class A:
             def foo(self, *args, **kwds):
                 return args[0] + args[1]
@@ -491,5 +497,108 @@ class PyOperatorTestCase(OperatorTestCase, unittest.TestCase):
 class COperatorTestCase(OperatorTestCase, unittest.TestCase):
     module = c_operator
 
+
+class OperatorPickleTestCase:
+    def copy(self, obj, proto):
+        with support.swap_item(sys.modules, 'operator', self.module):
+            pickled = pickle.dumps(obj, proto)
+        with support.swap_item(sys.modules, 'operator', self.module2):
+            return pickle.loads(pickled)
+
+    def test_attrgetter(self):
+        attrgetter = self.module.attrgetter
+        attrgetter = self.module.attrgetter
+        class A:
+            pass
+        a = A()
+        a.x = 'X'
+        a.y = 'Y'
+        a.z = 'Z'
+        a.t = A()
+        a.t.u = A()
+        a.t.u.v = 'V'
+        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+            with self.subTest(proto=proto):
+                f = attrgetter('x')
+                f2 = self.copy(f, proto)
+                self.assertEqual(repr(f2), repr(f))
+                self.assertEqual(f2(a), f(a))
+                # multiple gets
+                f = attrgetter('x', 'y', 'z')
+                f2 = self.copy(f, proto)
+                self.assertEqual(repr(f2), repr(f))
+                self.assertEqual(f2(a), f(a))
+                # recursive gets
+                f = attrgetter('t.u.v')
+                f2 = self.copy(f, proto)
+                self.assertEqual(repr(f2), repr(f))
+                self.assertEqual(f2(a), f(a))
+
+    def test_itemgetter(self):
+        itemgetter = self.module.itemgetter
+        a = 'ABCDE'
+        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+            with self.subTest(proto=proto):
+                f = itemgetter(2)
+                f2 = self.copy(f, proto)
+                self.assertEqual(repr(f2), repr(f))
+                self.assertEqual(f2(a), f(a))
+                # multiple gets
+                f = itemgetter(2, 0, 4)
+                f2 = self.copy(f, proto)
+                self.assertEqual(repr(f2), repr(f))
+                self.assertEqual(f2(a), f(a))
+
+    def test_methodcaller(self):
+        methodcaller = self.module.methodcaller
+        class A:
+            def foo(self, *args, **kwds):
+                return args[0] + args[1]
+            def bar(self, f=42):
+                return f
+            def baz(*args, **kwds):
+                return kwds['name'], kwds['self']
+        a = A()
+        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+            with self.subTest(proto=proto):
+                f = methodcaller('bar')
+                f2 = self.copy(f, proto)
+                self.assertEqual(repr(f2), repr(f))
+                self.assertEqual(f2(a), f(a))
+                # positional args
+                f = methodcaller('foo', 1, 2)
+                f2 = self.copy(f, proto)
+                self.assertEqual(repr(f2), repr(f))
+                self.assertEqual(f2(a), f(a))
+                # keyword args
+                f = methodcaller('bar', f=5)
+                f2 = self.copy(f, proto)
+                self.assertEqual(repr(f2), repr(f))
+                self.assertEqual(f2(a), f(a))
+                f = methodcaller('baz', self='eggs', name='spam')
+                f2 = self.copy(f, proto)
+                # Can't test repr consistently with multiple keyword args
+                self.assertEqual(f2(a), f(a))
+
+class PyPyOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase):
+    module = py_operator
+    module2 = py_operator
+
+@unittest.skipUnless(c_operator, 'requires _operator')
+class PyCOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase):
+    module = py_operator
+    module2 = c_operator
+
+@unittest.skipUnless(c_operator, 'requires _operator')
+class CPyOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase):
+    module = c_operator
+    module2 = py_operator
+
+@unittest.skipUnless(c_operator, 'requires _operator')
+class CCOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase):
+    module = c_operator
+    module2 = c_operator
+
+
 if __name__ == "__main__":
     unittest.main()
index 71b35c3eeb5a8d20e7d068d98792437396a43867..15ab1c8f49b54d3324c759c24f366c298029295d 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -52,6 +52,10 @@ Core and Builtins
 Library
 -------
 
+- Issue #22955: attrgetter, itemgetter and methodcaller objects in the operator
+  module now support pickling.  Added readable and evaluable repr for these
+  objects.  Based on patch by Josh Rosenberg.
+
 - Issue #22107: tempfile.gettempdir() and tempfile.mkdtemp() now try again
   when a directory with the chosen name already exists on Windows as well as
   on Unix.  tempfile.mkstemp() now fails early if parent directory is not
index 8f524a6449b8f8e1cea6b66b440f751b242e8825..9e4db58d76779e9bdb4cdad688388980d762511f 100644 (file)
@@ -485,6 +485,41 @@ itemgetter_call(itemgetterobject *ig, PyObject *args, PyObject *kw)
     return result;
 }
 
+static PyObject *
+itemgetter_repr(itemgetterobject *ig)
+{
+    PyObject *repr;
+    const char *reprfmt;
+
+    int status = Py_ReprEnter((PyObject *)ig);
+    if (status != 0) {
+        if (status < 0)
+            return NULL;
+        return PyUnicode_FromFormat("%s(...)", Py_TYPE(ig)->tp_name);
+    }
+
+    reprfmt = ig->nitems == 1 ? "%s(%R)" : "%s%R";
+    repr = PyUnicode_FromFormat(reprfmt, Py_TYPE(ig)->tp_name, ig->item);
+    Py_ReprLeave((PyObject *)ig);
+    return repr;
+}
+
+static PyObject *
+itemgetter_reduce(itemgetterobject *ig)
+{
+    if (ig->nitems == 1)
+        return Py_BuildValue("O(O)", Py_TYPE(ig), ig->item);
+    return PyTuple_Pack(2, Py_TYPE(ig), ig->item);
+}
+
+PyDoc_STRVAR(reduce_doc, "Return state information for pickling");
+
+static PyMethodDef itemgetter_methods[] = {
+    {"__reduce__", (PyCFunction)itemgetter_reduce, METH_NOARGS,
+     reduce_doc},
+    {NULL}
+};
+
 PyDoc_STRVAR(itemgetter_doc,
 "itemgetter(item, ...) --> itemgetter object\n\
 \n\
@@ -503,7 +538,7 @@ static PyTypeObject itemgetter_type = {
     0,                                  /* tp_getattr */
     0,                                  /* tp_setattr */
     0,                                  /* tp_reserved */
-    0,                                  /* tp_repr */
+    (reprfunc)itemgetter_repr,          /* tp_repr */
     0,                                  /* tp_as_number */
     0,                                  /* tp_as_sequence */
     0,                                  /* tp_as_mapping */
@@ -521,7 +556,7 @@ static PyTypeObject itemgetter_type = {
     0,                                  /* tp_weaklistoffset */
     0,                                  /* tp_iter */
     0,                                  /* tp_iternext */
-    0,                                  /* tp_methods */
+    itemgetter_methods,                 /* tp_methods */
     0,                                  /* tp_members */
     0,                                  /* tp_getset */
     0,                                  /* tp_base */
@@ -737,6 +772,91 @@ attrgetter_call(attrgetterobject *ag, PyObject *args, PyObject *kw)
     return result;
 }
 
+static PyObject *
+dotjoinattr(PyObject *attr, PyObject **attrsep)
+{
+    if (PyTuple_CheckExact(attr)) {
+        if (*attrsep == NULL) {
+            *attrsep = PyUnicode_FromString(".");
+            if (*attrsep == NULL)
+                return NULL;
+        }
+        return PyUnicode_Join(*attrsep, attr);
+    } else {
+        Py_INCREF(attr);
+        return attr;
+    }
+}
+
+static PyObject *
+attrgetter_args(attrgetterobject *ag)
+{
+    Py_ssize_t i;
+    PyObject *attrsep = NULL;
+    PyObject *attrstrings = PyTuple_New(ag->nattrs);
+    if (attrstrings == NULL)
+        return NULL;
+
+    for (i = 0; i < ag->nattrs; ++i) {
+        PyObject *attr = PyTuple_GET_ITEM(ag->attr, i);
+        PyObject *attrstr = dotjoinattr(attr, &attrsep);
+        if (attrstr == NULL) {
+            Py_XDECREF(attrsep);
+            Py_DECREF(attrstrings);
+            return NULL;
+        }
+        PyTuple_SET_ITEM(attrstrings, i, attrstr);
+    }
+    Py_XDECREF(attrsep);
+    return attrstrings;
+}
+
+static PyObject *
+attrgetter_repr(attrgetterobject *ag)
+{
+    PyObject *repr = NULL;
+    int status = Py_ReprEnter((PyObject *)ag);
+    if (status != 0) {
+        if (status < 0)
+            return NULL;
+        return PyUnicode_FromFormat("%s(...)", Py_TYPE(ag)->tp_name);
+    }
+
+    if (ag->nattrs == 1) {
+        PyObject *attrsep = NULL;
+        PyObject *attr = dotjoinattr(PyTuple_GET_ITEM(ag->attr, 0), &attrsep);
+        if (attr != NULL)
+            repr = PyUnicode_FromFormat("%s(%R)", Py_TYPE(ag)->tp_name, attr);
+        Py_XDECREF(attrsep);
+    }
+    else {
+        PyObject *attrstrings = attrgetter_args(ag);
+        if (attrstrings != NULL) {
+            repr = PyUnicode_FromFormat("%s%R",
+                                        Py_TYPE(ag)->tp_name, attrstrings);
+            Py_DECREF(attrstrings);
+        }
+    }
+    Py_ReprLeave((PyObject *)ag);
+    return repr;
+}
+
+static PyObject *
+attrgetter_reduce(attrgetterobject *ag)
+{
+    PyObject *attrstrings = attrgetter_args(ag);
+    if (attrstrings == NULL)
+        return NULL;
+
+    return Py_BuildValue("ON", Py_TYPE(ag), attrstrings);
+}
+
+static PyMethodDef attrgetter_methods[] = {
+    {"__reduce__", (PyCFunction)attrgetter_reduce, METH_NOARGS,
+     reduce_doc},
+    {NULL}
+};
+
 PyDoc_STRVAR(attrgetter_doc,
 "attrgetter(attr, ...) --> attrgetter object\n\
 \n\
@@ -757,7 +877,7 @@ static PyTypeObject attrgetter_type = {
     0,                                  /* tp_getattr */
     0,                                  /* tp_setattr */
     0,                                  /* tp_reserved */
-    0,                                  /* tp_repr */
+    (reprfunc)attrgetter_repr,          /* tp_repr */
     0,                                  /* tp_as_number */
     0,                                  /* tp_as_sequence */
     0,                                  /* tp_as_mapping */
@@ -775,7 +895,7 @@ static PyTypeObject attrgetter_type = {
     0,                                  /* tp_weaklistoffset */
     0,                                  /* tp_iter */
     0,                                  /* tp_iternext */
-    0,                                  /* tp_methods */
+    attrgetter_methods,                 /* tp_methods */
     0,                                  /* tp_members */
     0,                                  /* tp_getset */
     0,                                  /* tp_base */
@@ -813,6 +933,13 @@ methodcaller_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
         return NULL;
     }
 
+    name = PyTuple_GET_ITEM(args, 0);
+    if (!PyUnicode_Check(name)) {
+        PyErr_SetString(PyExc_TypeError,
+                        "method name must be a string");
+        return NULL;
+    }
+
     /* create methodcallerobject structure */
     mc = PyObject_GC_New(methodcallerobject, &methodcaller_type);
     if (mc == NULL)
@@ -825,8 +952,8 @@ methodcaller_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
     }
     mc->args = newargs;
 
-    name = PyTuple_GET_ITEM(args, 0);
     Py_INCREF(name);
+    PyUnicode_InternInPlace(&name);
     mc->name = name;
 
     Py_XINCREF(kwds);
@@ -869,6 +996,142 @@ methodcaller_call(methodcallerobject *mc, PyObject *args, PyObject *kw)
     return result;
 }
 
+static PyObject *
+methodcaller_repr(methodcallerobject *mc)
+{
+    PyObject *argreprs, *repr = NULL, *sep, *joinedargreprs;
+    Py_ssize_t numtotalargs, numposargs, numkwdargs, i;
+    int status = Py_ReprEnter((PyObject *)mc);
+    if (status != 0) {
+        if (status < 0)
+            return NULL;
+        return PyUnicode_FromFormat("%s(...)", Py_TYPE(mc)->tp_name);
+    }
+
+    if (mc->kwds != NULL) {
+        numkwdargs = PyDict_Size(mc->kwds);
+        if (numkwdargs < 0) {
+            Py_ReprLeave((PyObject *)mc);
+            return NULL;
+        }
+    } else {
+        numkwdargs = 0;
+    }
+
+    numposargs = PyTuple_GET_SIZE(mc->args);
+    numtotalargs = numposargs + numkwdargs;
+
+    if (numtotalargs == 0) {
+        repr = PyUnicode_FromFormat("%s(%R)", Py_TYPE(mc)->tp_name, mc->name);
+        Py_ReprLeave((PyObject *)mc);
+        return repr;
+    }
+
+    argreprs = PyTuple_New(numtotalargs);
+    if (argreprs == NULL) {
+        Py_ReprLeave((PyObject *)mc);
+        return NULL;
+    }
+
+    for (i = 0; i < numposargs; ++i) {
+        PyObject *onerepr = PyObject_Repr(PyTuple_GET_ITEM(mc->args, i));
+        if (onerepr == NULL)
+            goto done;
+        PyTuple_SET_ITEM(argreprs, i, onerepr);
+    }
+
+    if (numkwdargs != 0) {
+        PyObject *key, *value;
+        Py_ssize_t pos = 0;
+        while (PyDict_Next(mc->kwds, &pos, &key, &value)) {
+            PyObject *onerepr = PyUnicode_FromFormat("%U=%R", key, value);
+            if (onerepr == NULL)
+                goto done;
+            if (i >= numtotalargs) {
+                i = -1;
+                break;
+            }
+            PyTuple_SET_ITEM(argreprs, i, onerepr);
+            ++i;
+        }
+        if (i != numtotalargs) {
+            PyErr_SetString(PyExc_RuntimeError,
+                            "keywords dict changed size during iteration");
+            goto done;
+        }
+    }
+
+    sep = PyUnicode_FromString(", ");
+    if (sep == NULL)
+        goto done;
+
+    joinedargreprs = PyUnicode_Join(sep, argreprs);
+    Py_DECREF(sep);
+    if (joinedargreprs == NULL)
+        goto done;
+
+    repr = PyUnicode_FromFormat("%s(%R, %U)", Py_TYPE(mc)->tp_name,
+                                mc->name, joinedargreprs);
+    Py_DECREF(joinedargreprs);
+
+done:
+    Py_DECREF(argreprs);
+    Py_ReprLeave((PyObject *)mc);
+    return repr;
+}
+
+static PyObject *
+methodcaller_reduce(methodcallerobject *mc)
+{
+    PyObject *newargs;
+    if (!mc->kwds || PyDict_Size(mc->kwds) == 0) {
+        Py_ssize_t i;
+        Py_ssize_t callargcount = PyTuple_GET_SIZE(mc->args);
+        newargs = PyTuple_New(1 + callargcount);
+        if (newargs == NULL)
+            return NULL;
+        Py_INCREF(mc->name);
+        PyTuple_SET_ITEM(newargs, 0, mc->name);
+        for (i = 0; i < callargcount; ++i) {
+            PyObject *arg = PyTuple_GET_ITEM(mc->args, i);
+            Py_INCREF(arg);
+            PyTuple_SET_ITEM(newargs, i + 1, arg);
+        }
+        return Py_BuildValue("ON", Py_TYPE(mc), newargs);
+    }
+    else {
+        PyObject *functools;
+        PyObject *partial;
+        PyObject *constructor;
+        _Py_IDENTIFIER(partial);
+        functools = PyImport_ImportModule("functools");
+        if (!functools)
+            return NULL;
+        partial = _PyObject_GetAttrId(functools, &PyId_partial);
+        Py_DECREF(functools);
+        if (!partial)
+            return NULL;
+        newargs = PyTuple_New(2);
+        if (newargs == NULL) {
+            Py_DECREF(partial);
+            return NULL;
+        }
+        Py_INCREF(Py_TYPE(mc));
+        PyTuple_SET_ITEM(newargs, 0, (PyObject *)Py_TYPE(mc));
+        Py_INCREF(mc->name);
+        PyTuple_SET_ITEM(newargs, 1, mc->name);
+        constructor = PyObject_Call(partial, newargs, mc->kwds);
+        Py_DECREF(newargs);
+        Py_DECREF(partial);
+        return Py_BuildValue("NO", constructor, mc->args);
+    }
+}
+
+static PyMethodDef methodcaller_methods[] = {
+    {"__reduce__", (PyCFunction)methodcaller_reduce, METH_NOARGS,
+     reduce_doc},
+    {NULL}
+};
 PyDoc_STRVAR(methodcaller_doc,
 "methodcaller(name, ...) --> methodcaller object\n\
 \n\
@@ -888,7 +1151,7 @@ static PyTypeObject methodcaller_type = {
     0,                                  /* tp_getattr */
     0,                                  /* tp_setattr */
     0,                                  /* tp_reserved */
-    0,                                  /* tp_repr */
+    (reprfunc)methodcaller_repr,        /* tp_repr */
     0,                                  /* tp_as_number */
     0,                                  /* tp_as_sequence */
     0,                                  /* tp_as_mapping */
@@ -906,7 +1169,7 @@ static PyTypeObject methodcaller_type = {
     0,                                  /* tp_weaklistoffset */
     0,                                  /* tp_iter */
     0,                                  /* tp_iternext */
-    0,                                  /* tp_methods */
+    methodcaller_methods,               /* tp_methods */
     0,                                  /* tp_members */
     0,                                  /* tp_getset */
     0,                                  /* tp_base */