]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
SF patch #421922: Implement rich comparison for dicts.
authorTim Peters <tim.peters@gmail.com>
Tue, 8 May 2001 04:38:29 +0000 (04:38 +0000)
committerTim Peters <tim.peters@gmail.com>
Tue, 8 May 2001 04:38:29 +0000 (04:38 +0000)
d1 == d2 and d1 != d2 now work even if the keys and values in d1 and d2
don't support comparisons other than ==, and testing dicts for equality
is faster now (especially when inequality obtains).

Lib/test/test_richcmp.py
Misc/NEWS
Objects/dictobject.c

index 7884c7ece8aef30987f5169a51df107761863554..4e7d45925d220431dac67ae5204c5906751d3fbb 100644 (file)
@@ -221,6 +221,33 @@ def recursion():
     check('not a==b')
     if verbose: print "recursion tests ok"
 
+def dicts():
+    # Verify that __eq__ and __ne__ work for dicts even if the keys and
+    # values don't support anything other than __eq__ and __ne__.  Complex
+    # numbers are a fine example of that.
+    import random
+    imag1a = {}
+    for i in range(50):
+        imag1a[random.randrange(100)*1j] = random.randrange(100)*1j
+    items = imag1a.items()
+    random.shuffle(items)
+    imag1b = {}
+    for k, v in items:
+        imag1b[k] = v
+    imag2 = imag1b.copy()
+    imag2[k] = v + 1.0
+    verify(imag1a == imag1a, "imag1a == imag1a should have worked")
+    verify(imag1a == imag1b, "imag1a == imag1b should have worked")
+    verify(imag2 == imag2, "imag2 == imag2 should have worked")
+    verify(imag1a != imag2, "imag1a != imag2 should have worked")
+    for op in "<", "<=", ">", ">=":
+        try:
+            eval("imag1a %s imag2" % op)
+        except TypeError:
+            pass
+        else:
+            raise TestFailed("expected TypeError from imag1a %s imag2" % op)
+
 def main():
     basic()
     tabulate()
@@ -229,5 +256,6 @@ def main():
     testvector()
     misbehavin()
     recursion()
+    dicts()
 
 main()
index 1f971cd40c94554dad02d2d604e59c161ce8bf6e..f2150d5c1d828f588b35c56acdd28cf694fe3a97 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -17,14 +17,16 @@ Core
 
 - The following functions were generalized to work nicely with iterator
   arguments:
-    map(), filter(), reduce()
+    map(), filter(), reduce(), zip()
     list(), tuple() (PySequence_Tuple() and PySequence_Fast() in C API)
     max(), min()
-    zip()
     .join() method of strings
     'x in y' and 'x not in y' (PySequence_Contains() in C API)
     operator.countOf() (PySequence_Count() in C API)
 
+- Comparing dictionary objects via == and != is faster, and now works even
+  if the keys and values don't support comparisons other than ==.
+
 
 What's New in Python 2.1 (final)?
 =================================
index 96d779d77741b2a6df4a886bb15fc09164e533a8..56cc08f35c47a044d8a35cd68ce3644e8c3e7dfb 100644 (file)
@@ -1047,6 +1047,76 @@ dict_compare(dictobject *a, dictobject *b)
        return res;
 }
 
+/* Return 1 if dicts equal, 0 if not, -1 if error.
+ * Gets out as soon as any difference is detected.
+ * Uses only Py_EQ comparison.
+ */
+static int
+dict_equal(dictobject *a, dictobject *b)
+{
+       int i;
+
+       if (a->ma_used != b->ma_used)
+               /* can't be equal if # of entries differ */
+               return 0;
+  
+       /* Same # of entries -- check all of 'em.  Exit early on any diff. */
+       for (i = 0; i < a->ma_size; i++) {
+               PyObject *aval = a->ma_table[i].me_value;
+               if (aval != NULL) {
+                       int cmp;
+                       PyObject *bval;
+                       PyObject *key = a->ma_table[i].me_key;
+                       /* temporarily bump aval's refcount to ensure it stays
+                          alive until we're done with it */
+                       Py_INCREF(aval);
+                       bval = PyDict_GetItem((PyObject *)b, key);
+                       if (bval == NULL) {
+                               Py_DECREF(aval);
+                               return 0;
+                       }
+                       cmp = PyObject_RichCompareBool(aval, bval, Py_EQ);
+                       Py_DECREF(aval);
+                       if (cmp <= 0)  /* error or not equal */
+                               return cmp;
+               }
+       }
+       return 1;
+ }
+
+static PyObject *
+dict_richcompare(PyObject *v, PyObject *w, int op)
+{
+       int cmp;
+       PyObject *res;
+
+       if (!PyDict_Check(v) || !PyDict_Check(w)) {
+               res = Py_NotImplemented;
+       }
+       else if (op == Py_EQ || op == Py_NE) {
+               cmp = dict_equal((dictobject *)v, (dictobject *)w);
+               if (cmp < 0)
+                       return NULL;
+               res = (cmp == (op == Py_EQ)) ? Py_True : Py_False;
+       }
+       else {
+               cmp = dict_compare((dictobject *)v, (dictobject *)w);
+               if (cmp < 0 && PyErr_Occurred())
+                       return NULL;
+               switch (op) {
+                       case Py_LT: cmp = cmp <  0; break;
+                       case Py_LE: cmp = cmp <= 0; break;
+                       case Py_GT: cmp = cmp >  0; break;
+                       case Py_GE: cmp = cmp >= 0; break;
+                       default:
+                               assert(!"op unexpected");
+               }
+               res = cmp ? Py_True : Py_False;
+       }
+       Py_INCREF(res);
+       return res;
+ }
+
 static PyObject *
 dict_has_key(register dictobject *mp, PyObject *args)
 {
@@ -1410,7 +1480,7 @@ PyTypeObject PyDict_Type = {
        (printfunc)dict_print,                  /* tp_print */
        (getattrfunc)dict_getattr,              /* tp_getattr */
        0,                                      /* tp_setattr */
-       (cmpfunc)dict_compare,                  /* tp_compare */
+       0,                                      /* tp_compare */
        (reprfunc)dict_repr,                    /* tp_repr */
        0,                                      /* tp_as_number */
        &dict_as_sequence,                      /* tp_as_sequence */
@@ -1425,7 +1495,7 @@ PyTypeObject PyDict_Type = {
        0,                                      /* tp_doc */
        (traverseproc)dict_traverse,            /* tp_traverse */
        (inquiry)dict_tp_clear,                 /* tp_clear */
-       0,                                      /* tp_richcompare */
+       dict_richcompare,                       /* tp_richcompare */
        0,                                      /* tp_weaklistoffset */
        (getiterfunc)dict_iter,                 /* tp_iter */
        0,                                      /* tp_iternext */