]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
* Beef-up testing of str.__contains__() and str.find().
authorRaymond Hettinger <python@rcn.com>
Sun, 20 Feb 2005 04:07:08 +0000 (04:07 +0000)
committerRaymond Hettinger <python@rcn.com>
Sun, 20 Feb 2005 04:07:08 +0000 (04:07 +0000)
* Speed-up "x in y" where x has more than one character.

The existing code made excessive calls to the expensive memcmp() function.
The new code uses memchr() to rapidly find a start point for memcmp().
In addition to knowing that the first character is a match, the new code
also checks that the last character is a match.  This significantly reduces
the incidence of false starts (saving memcmp() calls and making quadratic
behavior less likely).

Improves the timings on:
    python -m timeit -r7 -s"x='a'*1000" "'ab' in x"
    python -m timeit -r7 -s"x='a'*1000" "'bc' in x"

Once this code has proven itself, then string_find_internal() should refer
to it rather than running its own version.  Also, something similar may
apply to unicode objects.

Lib/test/string_tests.py
Objects/stringobject.c

index c8ed07cf60f38e454c3831b0679639aa2f0b8c85..0ce96189af16e4031e74d4d13959a734e3951346 100644 (file)
@@ -122,6 +122,30 @@ class CommonTest(unittest.TestCase):
         self.checkraises(TypeError, 'hello', 'find')
         self.checkraises(TypeError, 'hello', 'find', 42)
 
+        # For a variety of combinations,
+        #    verify that str.find() matches __contains__
+        #    and that the found substring is really at that location
+        charset = ['', 'a', 'b', 'c']
+        digits = 5
+        base = len(charset)
+        teststrings = set()
+        for i in xrange(base ** digits):
+            entry = []
+            for j in xrange(digits):
+                i, m = divmod(i, base)
+                entry.append(charset[m])
+            teststrings.add(''.join(entry))
+        for i in teststrings:
+            i = self.fixtype(i)
+            for j in teststrings:
+                loc = i.find(j)
+                r1 = (loc != -1)
+                r2 = j in i
+                if r1 != r2:
+                    self.assertEqual(r1, r2)
+                if loc != -1:
+                    self.assertEqual(i[loc:loc+len(j)], j)
+
     def test_rfind(self):
         self.checkequal(9,  'abcdefghiabc', 'rfind', 'abc')
         self.checkequal(12, 'abcdefghiabc', 'rfind', '')
index b90221a6b87ce2e91434837f0b29d0d1a6cee1f2..0cbf4390fc61a208c5318a8bd13ff91b60de7b71 100644 (file)
@@ -1002,8 +1002,12 @@ string_slice(register PyStringObject *a, register int i, register int j)
 static int
 string_contains(PyObject *a, PyObject *el)
 {
-       const char *lhs, *rhs, *end;
-       int size;
+       char *s = PyString_AS_STRING(a);
+       const char *sub = PyString_AS_STRING(el);
+       char *last;
+       int len_sub = PyString_GET_SIZE(el);
+       int shortsub;
+       char firstchar, lastchar;
 
        if (!PyString_CheckExact(el)) {
 #ifdef Py_USING_UNICODE
@@ -1016,20 +1020,29 @@ string_contains(PyObject *a, PyObject *el)
                        return -1;
                }
        }
-       size = PyString_GET_SIZE(el);
-       rhs = PyString_AS_STRING(el);
-       lhs = PyString_AS_STRING(a);
 
-       /* optimize for a single character */
-       if (size == 1)
-               return memchr(lhs, *rhs, PyString_GET_SIZE(a)) != NULL;
-
-       end = lhs + (PyString_GET_SIZE(a) - size);
-       while (lhs <= end) {
-               if (memcmp(lhs++, rhs, size) == 0)
+       if (len_sub == 0)
+               return 1;
+       /* last points to one char beyond the start of the rightmost 
+          substring.  When s<last, there is still room for a possible match
+          and s[0] through s[len_sub-1] will be in bounds.
+          shortsub is len_sub minus the last character which is checked
+          separately just before the memcmp().  That check helps prevent
+          false starts and saves the setup time for memcmp().
+       */
+       firstchar = sub[0];
+       shortsub = len_sub - 1;
+       lastchar = sub[shortsub];
+       last = s + PyString_GET_SIZE(a) - len_sub + 1;
+       while (s < last) {
+               s = memchr(s, firstchar, last-s);
+               if (s == NULL)
+                       return 0;
+               assert(s < last);
+               if (s[shortsub] == lastchar && memcmp(s, sub, shortsub) == 0)
                        return 1;
+               s++;
        }
-
        return 0;
 }