]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Extend temporary hashability to remove() and discard().
authorRaymond Hettinger <python@rcn.com>
Sat, 22 Nov 2003 03:55:23 +0000 (03:55 +0000)
committerRaymond Hettinger <python@rcn.com>
Sat, 22 Nov 2003 03:55:23 +0000 (03:55 +0000)
Brings the functionality back in line with sets.py.

Lib/test/test_set.py
Objects/setobject.c

index 3203d516938c1f5c55bd135aee567084c43d58cc..85f87f7d1222918946fe6b6dff9575e272e77937 100644 (file)
@@ -182,12 +182,22 @@ class TestSet(TestJointOps):
         self.assert_('a' not in self.s)
         self.assertRaises(KeyError, self.s.remove, 'Q')
         self.assertRaises(TypeError, self.s.remove, [])
+        s = self.thetype([frozenset(self.word)])
+        self.assert_(self.thetype(self.word) in s)
+        s.remove(self.thetype(self.word))
+        self.assert_(self.thetype(self.word) not in s)
+        self.assertRaises(KeyError, self.s.remove, self.thetype(self.word))
 
     def test_discard(self):
         self.s.discard('a')
         self.assert_('a' not in self.s)
         self.s.discard('Q')
         self.assertRaises(TypeError, self.s.discard, [])
+        s = self.thetype([frozenset(self.word)])
+        self.assert_(self.thetype(self.word) in s)
+        s.discard(self.thetype(self.word))
+        self.assert_(self.thetype(self.word) not in s)
+        s.discard(self.thetype(self.word))
 
     def test_pop(self):
         for i in xrange(len(self.s)):
index 2d77c7485a3b72f418c35b46bd81858299c25b4f..be73954b45466a1b4050d3a1b8a40e4ae20f7f42 100644 (file)
@@ -73,6 +73,21 @@ set_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
        return make_new_set(type, NULL);
 }
 
+static PyObject *
+frozenset_dict_wrapper(PyObject *d)
+{
+       PySetObject *w;
+
+       assert(PyDict_Check(d));
+       w = (PySetObject *)make_new_set(&PyFrozenSet_Type, NULL);
+       if (w == NULL)
+               return NULL;
+       Py_DECREF(w->data);
+       Py_INCREF(d);
+       w->data = d;
+       return (PyObject *)w;
+}
+
 static void
 set_dealloc(PySetObject *so)
 {
@@ -104,20 +119,16 @@ set_len(PySetObject *so)
 static int
 set_contains(PySetObject *so, PyObject *key)
 {
-       PyObject *olddict;
-       PySetObject *tmp;
+       PyObject *tmp;
        int result;
 
        result = PySequence_Contains(so->data, key);
        if (result == -1 && PyType_IsSubtype(key->ob_type, &PySet_Type)) {
                PyErr_Clear();
-               tmp = (PySetObject *)make_new_set(&PyFrozenSet_Type, NULL);
+               tmp = frozenset_dict_wrapper(((PySetObject *)(key))->data);
                if (tmp == NULL)
                        return -1;
-               olddict = tmp->data;
-               tmp->data = ((PySetObject *)(key))->data;
-               result = PySequence_Contains(so->data, (PyObject *)tmp);
-               tmp->data = olddict;
+               result = PySequence_Contains(so->data, tmp);
                Py_DECREF(tmp);
        }
        return result;
@@ -820,8 +831,21 @@ This has no effect if the element is already present.");
 static PyObject *
 set_remove(PySetObject *so, PyObject *item)
 {
-       if (PyDict_DelItem(so->data, item) == -1)
-               return NULL;
+       PyObject *tmp;
+
+       if (PyDict_DelItem(so->data, item) == -1) {
+               if (!PyType_IsSubtype(item->ob_type, &PySet_Type)) 
+                       return NULL;
+               PyErr_Clear();
+               tmp = frozenset_dict_wrapper(((PySetObject *)(item))->data);
+               if (tmp == NULL)
+                       return NULL;
+               if (PyDict_DelItem(so->data, tmp) == -1) {
+                       Py_DECREF(tmp);
+                       return NULL;
+               }
+               Py_DECREF(tmp);
+       }
        Py_INCREF(Py_None);
        return Py_None;
 }
@@ -834,11 +858,28 @@ If the element is not a member, raise a KeyError.");
 static PyObject *
 set_discard(PySetObject *so, PyObject *item)
 {
+       PyObject *tmp;
+
        if (PyDict_DelItem(so->data, item) == -1) {
                if  (PyErr_ExceptionMatches(PyExc_KeyError))
                        PyErr_Clear();
-               else
-                       return NULL;
+               else {
+                       if (!PyType_IsSubtype(item->ob_type, &PySet_Type)) 
+                               return NULL;
+                       PyErr_Clear();
+                       tmp = frozenset_dict_wrapper(((PySetObject *)(item))->data);
+                       if (tmp == NULL)
+                               return NULL;
+                       if (PyDict_DelItem(so->data, tmp) == -1) {
+                               if  (PyErr_ExceptionMatches(PyExc_KeyError))
+                                       PyErr_Clear();
+                               else {
+                                       Py_DECREF(tmp);
+                                       return NULL;
+                               }
+                       }
+                       Py_DECREF(tmp);
+               }
        }
        Py_INCREF(Py_None);
        return Py_None;