]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Reimplement PySequence_Contains() and instance_contains(), so they work
authorTim Peters <tim.peters@gmail.com>
Sat, 5 May 2001 21:05:01 +0000 (21:05 +0000)
committerTim Peters <tim.peters@gmail.com>
Sat, 5 May 2001 21:05:01 +0000 (21:05 +0000)
safely together and don't duplicate logic (the common logic was factored
out into new private API function _PySequence_IterContains()).
Visible change:
    some_complex_number  in  some_instance
no longer blows up if some_instance has __getitem__ but neither
__contains__ nor __iter__.  test_iter changed to ensure that remains true.

Include/abstract.h
Lib/test/test_iter.py
Objects/abstract.c
Objects/classobject.c

index d5f4a9978d4dcb95a484d97de26d2f5ece64392c..9082edb0b83fd849f283ae346cd781b562976608 100644 (file)
@@ -932,7 +932,17 @@ xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx*/
         expression: o.count(value).
        */
 
-     DL_IMPORT(int) PySequence_Contains(PyObject *o, PyObject *value);
+     DL_IMPORT(int) PySequence_Contains(PyObject *seq, PyObject *ob);
+       /*
+         Return -1 if error; 1 if ob in seq; 0 if ob not in seq.
+         Use __contains__ if possible, else _PySequence_IterContains().
+       */
+
+     DL_IMPORT(int) _PySequence_IterContains(PyObject *seq, PyObject *ob);
+       /*
+         Return -1 if error; 1 if ob in seq; 0 if ob not in seq.
+         Always uses the iteration protocol, and only Py_EQ comparisons.
+       */
 
 /* For DLL-level backwards compatibility */
 #undef PySequence_In
index 7d15e1cfb8f469a097a499d03c9e69cde32916ed..22a7c4460d4b102c66045eef1ee8649fa6902ff9 100644 (file)
@@ -474,24 +474,12 @@ class TestCase(unittest.TestCase):
 
     # Test iterators with 'x in y' and 'x not in y'.
     def test_in_and_not_in(self):
-        sc5 = IteratingSequenceClass(5)
-        for i in range(5):
-            self.assert_(i in sc5)
-        # CAUTION:  This test fails on 3-12j if sc5 is SequenceClass(5)
-        # instead, with:
-        #     TypeError: cannot compare complex numbers using <, <=, >, >=
-        # The trail leads back to instance_contains() in classobject.c,
-        # under comment:
-        #     /* fall back to previous behavior */
-        # IteratingSequenceClass(5) avoids the same problem only because
-        # it lacks __getitem__:  instance_contains *tries* to do a wrong
-        # thing with it too, but aborts with an AttributeError the first
-        # time it calls instance_item(); PySequence_Contains() then catches
-        # that and clears it, and tries the iterator-based "contains"
-        # instead.  But this is hanging together by a thread.
-        for i in "abc", -1, 5, 42.42, (3, 4), [], {1: 1}, 3-12j, sc5:
-            self.assert_(i not in sc5)
-        del sc5
+        for sc5 in IteratingSequenceClass(5), SequenceClass(5):
+            for i in range(5):
+                self.assert_(i in sc5)
+            for i in "abc", -1, 5, 42.42, (3, 4), [], {1: 1}, 3-12j, sc5:
+                self.assert_(i not in sc5)
+            del sc5
 
         self.assertRaises(TypeError, lambda: 3 in 12)
         self.assertRaises(TypeError, lambda: 3 not in map)
index 21c1ef1de46df1a171549edacdf7e1fb52996050..c1d7789747819f76a42dc9a7fcef73b82cbf2091 100644 (file)
@@ -1381,29 +1381,14 @@ Fail:
        return -1;
 }
 
-/* Return -1 if error; 1 if v in w; 0 if v not in w. */
+/* Return -1 if error; 1 if ob in seq; 0 if ob not in seq.
+ * Always uses the iteration protocol, and only Py_EQ comparison.
+ */
 int
-PySequence_Contains(PyObject *w, PyObject *v) /* v in w */
+_PySequence_IterContains(PyObject *seq, PyObject *ob)
 {
-       PyObject *it;  /* iter(w) */
        int result;
-
-       if (PyType_HasFeature(w->ob_type, Py_TPFLAGS_HAVE_SEQUENCE_IN)) {
-               PySequenceMethods *sq = w->ob_type->tp_as_sequence;
-               if (sq != NULL && sq->sq_contains != NULL) {
-                       result = (*sq->sq_contains)(w, v);
-                       if (result >= 0)
-                               return result;
-                       assert(PyErr_Occurred());
-                       if (PyErr_ExceptionMatches(PyExc_AttributeError))
-                               PyErr_Clear();
-                       else
-                               return result;
-               }
-       }
-       
-       /* Try exhaustive iteration. */
-       it = PyObject_GetIter(w);
+       PyObject *it = PyObject_GetIter(seq);
        if (it == NULL) {
                PyErr_SetString(PyExc_TypeError,
                        "'in' or 'not in' needs iterable right argument");
@@ -1417,7 +1402,7 @@ PySequence_Contains(PyObject *w, PyObject *v) /* v in w */
                        result = PyErr_Occurred() ? -1 : 0;
                        break;
                }
-               cmp = PyObject_RichCompareBool(v, item, Py_EQ);
+               cmp = PyObject_RichCompareBool(ob, item, Py_EQ);
                Py_DECREF(item);
                if (cmp == 0)
                        continue;
@@ -1428,6 +1413,20 @@ PySequence_Contains(PyObject *w, PyObject *v) /* v in w */
        return result;
 }
 
+/* Return -1 if error; 1 if ob in seq; 0 if ob not in seq.
+ * Use sq_contains if possible, else defer to _PySequence_IterContains().
+ */
+int
+PySequence_Contains(PyObject *seq, PyObject *ob)
+{
+       if (PyType_HasFeature(seq->ob_type, Py_TPFLAGS_HAVE_SEQUENCE_IN)) {
+               PySequenceMethods *sqm = seq->ob_type->tp_as_sequence;
+               if (sqm != NULL && sqm->sq_contains != NULL)
+                       return (*sqm->sq_contains)(seq, ob);
+       }
+       return _PySequence_IterContains(seq, ob);
+}
+
 /* Backwards compatibility */
 #undef PySequence_In
 int
index 2babbfbd63fb02a504e3c264fdca3f71b6ad454d..67732ca228913eafd9a09fbf9c8c6beab0f2fbff 100644 (file)
@@ -1131,11 +1131,15 @@ instance_ass_slice(PyInstanceObject *inst, int i, int j, PyObject *value)
        return 0;
 }
 
-static int instance_contains(PyInstanceObject *inst, PyObject *member)
+static int
+instance_contains(PyInstanceObject *inst, PyObject *member)
 {
        static PyObject *__contains__;
-       PyObject *func, *arg, *res;
-       int ret;
+       PyObject *func;
+
+       /* Try __contains__ first.
+        * If that can't be done, try iterator-based searching.
+        */
 
        if(__contains__ == NULL) {
                __contains__ = PyString_InternFromString("__contains__");
@@ -1143,45 +1147,34 @@ static int instance_contains(PyInstanceObject *inst, PyObject *member)
                        return -1;
        }
        func = instance_getattr(inst, __contains__);
-       if(func == NULL) {
-               /* fall back to previous behavior */
-               int i, cmp_res;
-
-               if(!PyErr_ExceptionMatches(PyExc_AttributeError))
+       if (func) {
+               PyObject *res;
+               int ret;
+               PyObject *arg = Py_BuildValue("(O)", member);
+               if(arg == NULL) {
+                       Py_DECREF(func);
                        return -1;
-               PyErr_Clear();
-               for(i=0;;i++) {
-                       PyObject *obj = instance_item(inst, i);
-                       int ret = 0;
-
-                       if(obj == NULL) {
-                               if(!PyErr_ExceptionMatches(PyExc_IndexError))
-                                       return -1;
-                               PyErr_Clear();
-                               return 0;
-                       }
-                       if(PyObject_Cmp(obj, member, &cmp_res) == -1)
-                               ret = -1;
-                       if(cmp_res == 0) 
-                               ret = 1;
-                       Py_DECREF(obj);
-                       if(ret)
-                               return ret;
                }
-       }
-       arg = Py_BuildValue("(O)", member);
-       if(arg == NULL) {
+               res = PyEval_CallObject(func, arg);
                Py_DECREF(func);
-               return -1;
+               Py_DECREF(arg);
+               if(res == NULL) 
+                       return -1;
+               ret = PyObject_IsTrue(res);
+               Py_DECREF(res);
+               return ret;
        }
-       res = PyEval_CallObject(func, arg);
-       Py_DECREF(func);
-       Py_DECREF(arg);
-       if(res == NULL) 
+
+       /* Couldn't find __contains__. */
+       if (PyErr_ExceptionMatches(PyExc_AttributeError)) {
+               /* Assume the failure was simply due to that there is no
+                * __contains__ attribute, and try iterating instead.
+                */
+               PyErr_Clear();
+               return _PySequence_IterContains((PyObject *)inst, member);
+       }
+       else
                return -1;
-       ret = PyObject_IsTrue(res);
-       Py_DECREF(res);
-       return ret;
 }
 
 static PySequenceMethods