]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
copy.py fixed to first lookup __copy__ from the instance being copied,
authorAnthony Baxter <anthonybaxter@gmail.com>
Tue, 25 Jan 2005 12:52:18 +0000 (12:52 +0000)
committerAnthony Baxter <anthonybaxter@gmail.com>
Tue, 25 Jan 2005 12:52:18 +0000 (12:52 +0000)
rather than only looking at the type - this was broken in 2.3.

Lib/copy.py
Lib/test/test_copy.py
Modules/_testcapimodule.c

index 31adfd331723afc8a18a0096d6fb53e21092f9f4..95e30f5231cfec672955e29dbf7ce43646d68de3 100644 (file)
@@ -62,6 +62,15 @@ except ImportError:
 
 __all__ = ["Error", "copy", "deepcopy"]
 
+def _getspecial(cls, name):
+    for basecls in cls.__mro__:
+        try:
+            return basecls.__dict__[name]
+        except:
+            pass
+    else:
+        return None
+
 def copy(x):
     """Shallow copy operation on arbitrary Python objects.
 
@@ -74,7 +83,7 @@ def copy(x):
     if copier:
         return copier(x)
 
-    copier = getattr(cls, "__copy__", None)
+    copier = _getspecial(cls, "__copy__")
     if copier:
         return copier(x)
 
@@ -90,6 +99,9 @@ def copy(x):
             if reductor:
                 rv = reductor()
             else:
+                copier = getattr(x, "__copy__", None)
+                if copier:
+                    return copier()
                 raise Error("un(shallow)copyable object of type %s" % cls)
 
     return _reconstruct(x, rv, 0)
@@ -185,9 +197,9 @@ def deepcopy(x, memo=None, _nil=[]):
         if issc:
             y = _deepcopy_atomic(x, memo)
         else:
-            copier = getattr(x, "__deepcopy__", None)
+            copier = _getspecial(cls, "__deepcopy__")
             if copier:
-                y = copier(memo)
+                y = copier(x, memo)
             else:
                 reductor = dispatch_table.get(cls)
                 if reductor:
@@ -201,6 +213,9 @@ def deepcopy(x, memo=None, _nil=[]):
                         if reductor:
                             rv = reductor()
                         else:
+                            copier = getattr(x, "__deepcopy__", None)
+                            if copier:
+                                return copier(memo)
                             raise Error(
                                 "un(deep)copyable object of type %s" % cls)
                 y = _reconstruct(x, rv, 1, memo)
index 3d44304db9cd57776dff2ac692c414e862c521ba..3484fa77dd5e0c263318c05dcfeb7e443f72d233 100644 (file)
@@ -166,8 +166,64 @@ class TestCopy(unittest.TestCase):
         x = C(42)
         self.assertEqual(copy.copy(x), x)
 
-    # The deepcopy() method
+    # tests for copying extension types, iff module trycopy is installed
+    def test_copy_classictype(self):
+        from _testcapi import make_copyable
+        x = make_copyable([23])
+        y = copy.copy(x)
+        self.assertEqual(x, y)
+        self.assertEqual(x.tag, y.tag)
+        self.assert_(x is not y)
+        self.assert_(x.tag is y.tag)
+
+    def test_deepcopy_classictype(self):
+        from _testcapi import make_copyable
+        x = make_copyable([23])
+        y = copy.deepcopy(x)
+        self.assertEqual(x, y)
+        self.assertEqual(x.tag, y.tag)
+        self.assert_(x is not y)
+        self.assert_(x.tag is not y.tag)
+
+    # regression tests for metaclass-confusion
+    def test_copy_metaclassconfusion(self):
+        class MyOwnError(copy.Error):
+            pass
+        class Meta(type):
+            def __copy__(cls):
+                raise MyOwnError("can't copy classes w/this metaclass")
+        class C:
+            __metaclass__ = Meta
+            def __init__(self, tag):
+                self.tag = tag
+            def __cmp__(self, other):
+                return -cmp(other, self.tag)
+        # the metaclass can forbid shallow copying of its classes
+        self.assertRaises(MyOwnError, copy.copy, C)
+        # check that there is no interference with instances
+        x = C(23)
+        self.assertEqual(copy.copy(x), x)
+
+    def test_deepcopy_metaclassconfusion(self):
+        class MyOwnError(copy.Error):
+            pass
+        class Meta(type):
+            def __deepcopy__(cls, memo):
+                raise MyOwnError("can't deepcopy classes w/this metaclass")
+        class C:
+            __metaclass__ = Meta
+            def __init__(self, tag):
+                self.tag = tag
+            def __cmp__(self, other):
+                return -cmp(other, self.tag)
+        # types are ALWAYS deepcopied atomically, no matter what
+        self.assertEqual(copy.deepcopy(C), C)
+        # check that there is no interference with instances
+        x = C(23)
+        self.assertEqual(copy.deepcopy(x), x)
+
 
+    # The deepcopy() method
     def test_deepcopy_basic(self):
         x = 42
         y = copy.deepcopy(x)
index fd16d5fa40fbbcf6bd43743f3b53a0b08b442848..5e8f7ad0acfcb48e4292b6a41afd0c80905b423c 100644 (file)
@@ -587,6 +587,169 @@ test_thread_state(PyObject *self, PyObject *args)
 }
 #endif
 
+/* a classic-type with copyable instances */
+
+typedef struct {
+       PyObject_HEAD
+       /* instance tag (a string). */
+       PyObject* tag;
+} CopyableObject;
+
+staticforward PyTypeObject Copyable_Type;
+
+#define Copyable_CheckExact(op) ((op)->ob_type == &Copyable_Type)
+
+/* -------------------------------------------------------------------- */
+
+/* copyable constructor and destructor */
+static PyObject*
+copyable_new(PyObject* tag)
+{
+       CopyableObject* self;
+
+       self = PyObject_New(CopyableObject, &Copyable_Type);
+       if (self == NULL)
+               return NULL;
+       Py_INCREF(tag);
+       self->tag = tag;
+       return (PyObject*) self;
+}
+
+static PyObject*
+copyable(PyObject* self, PyObject* args, PyObject* kw)
+{
+       PyObject* elem;
+       PyObject* tag;
+       if (!PyArg_ParseTuple(args, "O:Copyable", &tag))
+               return NULL;
+       elem = copyable_new(tag);
+       return elem;
+}
+
+static void
+copyable_dealloc(CopyableObject* self)
+{
+       /* discard attributes */
+       Py_DECREF(self->tag);
+       PyObject_Del(self);
+}
+
+/* copyable methods */
+
+static PyObject*
+copyable_copy(CopyableObject* self, PyObject* args)
+{
+       CopyableObject* copyable;
+       if (!PyArg_ParseTuple(args, ":__copy__"))
+               return NULL;
+       copyable = (CopyableObject*)copyable_new(self->tag);
+       if (!copyable)
+               return NULL;
+       return (PyObject*) copyable;
+}
+
+PyObject* _copy_deepcopy;
+
+static PyObject*
+copyable_deepcopy(CopyableObject* self, PyObject* args)
+{
+       CopyableObject* copyable = 0;
+       PyObject* memo;
+       PyObject* tag_copy;
+       if (!PyArg_ParseTuple(args, "O:__deepcopy__", &memo))
+               return NULL;
+
+       tag_copy = PyObject_CallFunctionObjArgs(_copy_deepcopy, self->tag, memo, NULL);
+
+       if(tag_copy) {
+               copyable = (CopyableObject*)copyable_new(tag_copy);
+               Py_DECREF(tag_copy);
+       }
+       return (PyObject*) copyable;
+}
+
+static PyObject*
+copyable_repr(CopyableObject* self)
+{
+       PyObject* repr;
+       char buffer[100];
+       
+       repr = PyString_FromString("<Copyable {");
+
+       PyString_ConcatAndDel(&repr, PyObject_Repr(self->tag));
+
+       sprintf(buffer, "} at %p>", self);
+       PyString_ConcatAndDel(&repr, PyString_FromString(buffer));
+
+       return repr;
+}
+
+static int
+copyable_compare(CopyableObject* obj1, CopyableObject* obj2)
+{
+       return PyObject_Compare(obj1->tag, obj2->tag);
+}
+
+static PyMethodDef copyable_methods[] = {
+       {"__copy__", (PyCFunction) copyable_copy, METH_VARARGS},
+       {"__deepcopy__", (PyCFunction) copyable_deepcopy, METH_VARARGS},
+       {NULL, NULL}
+};
+
+static PyObject*  
+copyable_getattr(CopyableObject* self, char* name)
+{
+       PyObject* res;
+       res = Py_FindMethod(copyable_methods, (PyObject*) self, name);
+       if (res)
+       return res;
+       PyErr_Clear();
+       if (strcmp(name, "tag") == 0) {
+       res = self->tag;
+       } else {
+               PyErr_SetString(PyExc_AttributeError, name);
+               return NULL;
+       }
+       if (!res)
+               return NULL;
+       Py_INCREF(res);
+       return res;
+}
+
+static int
+copyable_setattr(CopyableObject* self, const char* name, PyObject* value)
+{
+       if (value == NULL) {
+               PyErr_SetString(
+                       PyExc_AttributeError,
+                       "can't delete copyable attributes"
+                       );
+               return -1;
+       }
+       if (strcmp(name, "tag") == 0) {
+               Py_DECREF(self->tag);
+               self->tag = value;
+               Py_INCREF(self->tag);
+       } else {
+               PyErr_SetString(PyExc_AttributeError, name);
+               return -1;
+       }
+       return 0;
+}
+
+statichere PyTypeObject Copyable_Type = {
+       PyObject_HEAD_INIT(NULL)
+       0, "Copyable", sizeof(CopyableObject), 0,
+       /* methods */
+       (destructor)copyable_dealloc, /* tp_dealloc */
+       0, /* tp_print */
+       (getattrfunc)copyable_getattr, /* tp_getattr */
+       (setattrfunc)copyable_setattr, /* tp_setattr */
+       (cmpfunc)copyable_compare, /* tp_compare */
+       (reprfunc)copyable_repr, /* tp_repr */
+       0, /* tp_as_number */
+};
+
 static PyMethodDef TestMethods[] = {
        {"raise_exception",     raise_exception,                 METH_VARARGS},
        {"test_config",         (PyCFunction)test_config,        METH_NOARGS},
@@ -613,9 +776,11 @@ static PyMethodDef TestMethods[] = {
        {"test_u_code",         (PyCFunction)test_u_code,        METH_NOARGS},
 #endif
 #ifdef WITH_THREAD
-       {"_test_thread_state", (PyCFunction)test_thread_state, METH_VARARGS},
+       {"_test_thread_state",  (PyCFunction)test_thread_state, METH_VARARGS},
 #endif
+       {"make_copyable",       (PyCFunction) copyable,         METH_VARARGS},
        {NULL, NULL} /* sentinel */
+
 };
 
 #define AddSym(d, n, f, v) {PyObject *o = f(v); PyDict_SetItemString(d, n, o); Py_DECREF(o);}
@@ -624,6 +789,15 @@ PyMODINIT_FUNC
 init_testcapi(void)
 {
        PyObject *m;
+       PyObject *copy_module;
+
+
+       copy_module = PyImport_ImportModule("copy");
+       if(!copy_module)
+               return;
+       _copy_deepcopy = PyObject_GetAttrString(copy_module, "deepcopy");
+       Py_DECREF(copy_module);
+       Copyable_Type.ob_type = &PyType_Type;
 
        m = Py_InitModule("_testcapi", TestMethods);