]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Various fixups (most suggested by Armin Rigo).
authorRaymond Hettinger <python@rcn.com>
Mon, 17 Nov 2003 16:42:33 +0000 (16:42 +0000)
committerRaymond Hettinger <python@rcn.com>
Mon, 17 Nov 2003 16:42:33 +0000 (16:42 +0000)
Include/setobject.h
Lib/test/test_set.py
Objects/setobject.c

index eeffa8ad46089a30a62e758aaa6aca49d3304fcf..6289f9c6d48974ad96393349d9c985eedb1551bd 100644 (file)
@@ -20,6 +20,12 @@ typedef struct {
 PyAPI_DATA(PyTypeObject) PySet_Type;
 PyAPI_DATA(PyTypeObject) PyFrozenSet_Type;
 
+
+#define PyAnySet_Check(ob) \
+       ((ob)->ob_type == &PySet_Type || (ob)->ob_type == &PyFrozenSet_Type || \
+         PyType_IsSubtype((ob)->ob_type, &PySet_Type) || \
+         PyType_IsSubtype((ob)->ob_type, &PyFrozenSet_Type))
+
 #ifdef __cplusplus
 }
 #endif
index 1edb2dd999381e08a0c796aca4816d00d0bcd59c..8329fd16312ea1203b75b20f2a5e4f5c12a8791c 100644 (file)
@@ -152,6 +152,13 @@ class TestJointOps(unittest.TestCase):
 class TestSet(TestJointOps):
     thetype = set
 
+    def test_init(self):
+        s = set()
+        s.__init__(self.word)
+        self.assertEqual(s, set(self.word))
+        s.__init__(self.otherword)
+        self.assertEqual(s, set(self.otherword))
+
     def test_hash(self):
         self.assertRaises(TypeError, hash, self.s)
 
@@ -252,10 +259,20 @@ class TestSet(TestJointOps):
             else:
                 self.assert_(c not in self.s)
 
+class SetSubclass(set):
+    pass
+
+class TestSetSubclass(TestSet):
+    thetype = SetSubclass
 
 class TestFrozenSet(TestJointOps):
     thetype = frozenset
 
+    def test_init(self):
+        s = frozenset()
+        s.__init__(self.word)
+        self.assertEqual(s, frozenset())
+
     def test_hash(self):
         self.assertEqual(hash(frozenset('abcdeb')), hash(frozenset('ebecda')))
 
@@ -273,6 +290,12 @@ class TestFrozenSet(TestJointOps):
         f = frozenset('abcdcda')
         self.assertEqual(hash(f), hash(f))
 
+class FrozenSetSubclass(frozenset):
+    pass
+
+class TestFrozenSetSubclass(TestFrozenSet):
+    thetype = FrozenSetSubclass
+
 # Tests taken from test_sets.py =============================================
 
 empty_set = set()
@@ -1137,7 +1160,9 @@ def test_main(verbose=None):
     from test import test_sets
     test_classes = (
         TestSet,
+        TestSetSubclass,
         TestFrozenSet,
+        TestFrozenSetSubclass,
         TestSetOfSets,
         TestExceptionPropagation,
         TestBasicOpsEmpty,
index 61ba8539b664f08090a9e97b5fd360c8a9820175..7ad8af06fb0628d7d1fb86f57af12ace88645ede 100644 (file)
@@ -12,7 +12,6 @@
 /* Fast access macros */ 
 
 #define DICT_CONTAINS(d, k)  (d->ob_type->tp_as_sequence->sq_contains(d, k))
-#define IS_SET(so)     (so->ob_type == &PySet_Type || so->ob_type == &PyFrozenSet_Type)
 
 /* set object **********************************************************/
 
@@ -42,8 +41,6 @@ make_new_set(PyTypeObject *type, PyObject *iterable)
                        Py_DECREF(it);
                        Py_DECREF(data);
                        Py_DECREF(item);
-                       PyErr_SetString(PyExc_TypeError,
-                                       "all set entries must be immutable");
                        return NULL;
                 } 
                Py_DECREF(item);
@@ -67,7 +64,7 @@ make_new_set(PyTypeObject *type, PyObject *iterable)
 }
 
 static PyObject *
-set_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+frozenset_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
 {
        PyObject *iterable = NULL;
 
@@ -76,6 +73,14 @@ set_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
        return make_new_set(type, iterable);
 }
 
+static PyObject *
+set_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+       PyObject *iterable = NULL;
+
+       return make_new_set(type, NULL);
+}
+
 static void
 set_dealloc(PySetObject *so)
 {
@@ -139,6 +144,8 @@ set_union(PySetObject *so, PyObject *other)
        PyObject *item, *data, *it;
 
        result = (PySetObject *)set_copy(so);
+       if (result == NULL)
+               return NULL;
        it = PyObject_GetIter(other);
        if (it == NULL) {
                Py_DECREF(result);
@@ -150,8 +157,6 @@ set_union(PySetObject *so, PyObject *other)
                        Py_DECREF(it);
                        Py_DECREF(result);
                        Py_DECREF(item);
-                       PyErr_SetString(PyExc_TypeError,
-                                       "all set entries must be immutable");
                        return NULL;
                 } 
                Py_DECREF(item);
@@ -183,8 +188,6 @@ set_union_update(PySetObject *so, PyObject *other)
                 if (PyDict_SetItem(data, item, Py_True) == -1) {
                        Py_DECREF(it);
                        Py_DECREF(item);
-                       PyErr_SetString(PyExc_TypeError,
-                                       "all set entries must be immutable");
                        return NULL;
                 } 
                Py_DECREF(item);
@@ -201,7 +204,7 @@ PyDoc_STRVAR(union_update_doc,
 static PyObject *
 set_or(PySetObject *so, PyObject *other)
 {
-       if (!IS_SET(so) || !IS_SET(other)) {
+       if (!PyAnySet_Check(so) || !PyAnySet_Check(other)) {
                Py_INCREF(Py_NotImplemented);
                return Py_NotImplemented;
        }
@@ -213,7 +216,7 @@ set_ior(PySetObject *so, PyObject *other)
 {
        PyObject *result;
 
-       if (!IS_SET(other)) {
+       if (!PyAnySet_Check(other)) {
                Py_INCREF(Py_NotImplemented);
                return Py_NotImplemented;
        }
@@ -249,8 +252,6 @@ set_intersection(PySetObject *so, PyObject *other)
                                Py_DECREF(it);
                                Py_DECREF(result);
                                Py_DECREF(item);
-                               PyErr_SetString(PyExc_TypeError,
-                                               "all set entries must be immutable");
                                return NULL;
                        }
                }
@@ -291,8 +292,6 @@ set_intersection_update(PySetObject *so, PyObject *other)
                                Py_DECREF(newdict);
                                Py_DECREF(it);
                                Py_DECREF(item);
-                               PyErr_SetString(PyExc_TypeError,
-                                               "all set entries must be immutable");
                                return NULL;
                        }
                }
@@ -315,7 +314,7 @@ PyDoc_STRVAR(intersection_update_doc,
 static PyObject *
 set_and(PySetObject *so, PyObject *other)
 {
-       if (!IS_SET(so) || !IS_SET(other)) {
+       if (!PyAnySet_Check(so) || !PyAnySet_Check(other)) {
                Py_INCREF(Py_NotImplemented);
                return Py_NotImplemented;
        }
@@ -327,7 +326,7 @@ set_iand(PySetObject *so, PyObject *other)
 {
        PyObject *result;
 
-       if (!IS_SET(other)) {
+       if (!PyAnySet_Check(other)) {
                Py_INCREF(Py_NotImplemented);
                return Py_NotImplemented;
        }
@@ -416,7 +415,7 @@ PyDoc_STRVAR(difference_update_doc,
 static PyObject *
 set_sub(PySetObject *so, PyObject *other)
 {
-       if (!IS_SET(so) || !IS_SET(other)) {
+       if (!PyAnySet_Check(so) || !PyAnySet_Check(other)) {
                Py_INCREF(Py_NotImplemented);
                return Py_NotImplemented;
        }
@@ -428,7 +427,7 @@ set_isub(PySetObject *so, PyObject *other)
 {
        PyObject *result;
 
-       if (!IS_SET(other)) {
+       if (!PyAnySet_Check(other)) {
                Py_INCREF(Py_NotImplemented);
                return Py_NotImplemented;
        }
@@ -475,8 +474,6 @@ set_symmetric_difference(PySetObject *so, PyObject *other)
                                Py_DECREF(otherset);
                                Py_DECREF(result);
                                Py_DECREF(item);
-                               PyErr_SetString(PyExc_TypeError,
-                                               "all set entries must be immutable");
                                return NULL;
                        } 
                }
@@ -506,7 +503,7 @@ set_symmetric_difference_update(PySetObject *so, PyObject *other)
 
        if (PyDict_Check(other))
                otherdata = other;
-       else if (IS_SET(other))
+       else if (PyAnySet_Check(other))
                otherdata = ((PySetObject *)other)->data;
        else {
                otherset = (PySetObject *)make_new_set(so->ob_type, other);
@@ -525,8 +522,6 @@ set_symmetric_difference_update(PySetObject *so, PyObject *other)
                                Py_XDECREF(otherset);
                                Py_DECREF(it);
                                Py_DECREF(item);
-                               PyErr_SetString(PyExc_TypeError,
-                                               "all set entries must be immutable");
                                return NULL;
                        }
                } else {
@@ -534,8 +529,6 @@ set_symmetric_difference_update(PySetObject *so, PyObject *other)
                                Py_XDECREF(otherset);
                                Py_DECREF(it);
                                Py_DECREF(item);
-                               PyErr_SetString(PyExc_TypeError,
-                                               "all set entries must be immutable");
                                return NULL;
                        }
                }
@@ -554,7 +547,7 @@ PyDoc_STRVAR(symmetric_difference_update_doc,
 static PyObject *
 set_xor(PySetObject *so, PyObject *other)
 {
-       if (!IS_SET(so) || !IS_SET(other)) {
+       if (!PyAnySet_Check(so) || !PyAnySet_Check(other)) {
                Py_INCREF(Py_NotImplemented);
                return Py_NotImplemented;
        }
@@ -566,7 +559,7 @@ set_ixor(PySetObject *so, PyObject *other)
 {
        PyObject *result;
 
-       if (!IS_SET(other)) {
+       if (!PyAnySet_Check(other)) {
                Py_INCREF(Py_NotImplemented);
                return Py_NotImplemented;
        }
@@ -583,7 +576,7 @@ set_issubset(PySetObject *so, PyObject *other)
 {
        PyObject *otherdata, *it, *item;
 
-       if (!IS_SET(other)) {
+       if (!PyAnySet_Check(other)) {
                PyErr_SetString(PyExc_TypeError, "can only compare to a set");
                return NULL;
        }
@@ -604,6 +597,8 @@ set_issubset(PySetObject *so, PyObject *other)
                Py_DECREF(item);
        }
        Py_DECREF(it);
+       if (PyErr_Occurred()) 
+               return NULL;
        Py_RETURN_TRUE;
 }
 
@@ -612,7 +607,7 @@ PyDoc_STRVAR(issubset_doc, "Report whether another set contains this set.");
 static PyObject *
 set_issuperset(PySetObject *so, PyObject *other)
 {
-       if (!IS_SET(other)) {
+       if (!PyAnySet_Check(other)) {
                PyErr_SetString(PyExc_TypeError, "can only compare to a set");
                return NULL;
        }
@@ -653,20 +648,21 @@ frozenset_hash(PyObject *self)
                hash ^= PyObject_Hash(item);
                Py_DECREF(item);
        }
-       so->hash = hash;
        Py_DECREF(it);
+       if (PyErr_Occurred()) 
+               return -1;
+       so->hash = hash;
        return hash;
 }
 
 static PyObject *
 set_richcompare(PySetObject *v, PyObject *w, int op)
 {
-       /* XXX factor out is_set test */
-       if (op == Py_EQ && !IS_SET(w))
-               Py_RETURN_FALSE;
-       else if (op == Py_NE && !IS_SET(w))
-               Py_RETURN_TRUE;
-       if (!IS_SET(w)) {
+       if(!PyAnySet_Check(w)) {
+               if (op == Py_EQ)
+                       Py_RETURN_FALSE;
+               if (op == Py_NE)
+                       Py_RETURN_TRUE;
                PyErr_SetString(PyExc_TypeError, "can only compare to a set");
                return NULL;
        }
@@ -698,8 +694,12 @@ set_repr(PySetObject *so)
        PyObject *keys, *result, *listrepr;
 
        keys = PyDict_Keys(so->data);
+       if (keys == NULL)
+               return NULL;
        listrepr = PyObject_Repr(keys);
        Py_DECREF(keys);
+       if (listrepr == NULL)
+               return NULL;
 
        result = PyString_FromFormat("%s(%s)", so->ob_type->tp_name,
                PyString_AS_STRING(listrepr));
@@ -732,6 +732,8 @@ set_tp_print(PySetObject *so, FILE *fp, int flags)
        }
        Py_DECREF(it);
        fprintf(fp, "])");
+       if (PyErr_Occurred()) 
+               return -1;
        return 0;
 }
 
@@ -810,8 +812,10 @@ set_pop(PySetObject *so)
                return NULL;
        }
        Py_INCREF(key);
-       if (PyDict_DelItem(so->data, key) == -1)
-               PyErr_Clear();
+       if (PyDict_DelItem(so->data, key) == -1) {
+               Py_DECREF(key);
+               return NULL;
+       }
        return key;
 }
 
@@ -837,6 +841,28 @@ done:
 
 PyDoc_STRVAR(reduce_doc, "Return state information for pickling.");
 
+static int
+set_init(PySetObject *self, PyObject *args, PyObject *kwds)
+{
+       PyObject *iterable = NULL;
+       PyObject *result;
+
+       if (!PyAnySet_Check(self))
+               return -1;
+       if (!PyArg_UnpackTuple(args, self->ob_type->tp_name, 0, 1, &iterable))
+               return -1;
+       PyDict_Clear(self->data);
+       self->hash = -1;
+       if (iterable == NULL)
+               return 0;
+       result = set_union_update(self, iterable);
+       if (result != NULL) {
+               Py_DECREF(result);
+               return 0;
+       }
+       return -1;
+}
+
 static PySequenceMethods set_as_sequence = {
        (inquiry)set_len,               /* sq_length */
        0,                              /* sq_concat */
@@ -971,7 +997,7 @@ PyTypeObject PySet_Type = {
        0,                              /* tp_descr_get */
        0,                              /* tp_descr_set */
        0,                              /* tp_dictoffset */
-       0,                              /* tp_init */
+       (initproc)set_init,             /* tp_init */
        PyType_GenericAlloc,            /* tp_alloc */
        set_new,                        /* tp_new */
        PyObject_GC_Del,                /* tp_free */
@@ -1068,6 +1094,6 @@ PyTypeObject PyFrozenSet_Type = {
        0,                              /* tp_dictoffset */
        0,                              /* tp_init */
        PyType_GenericAlloc,            /* tp_alloc */
-       set_new,                        /* tp_new */
+       frozenset_new,                  /* tp_new */
        PyObject_GC_Del,                /* tp_free */
 };