git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
filterstring() and filterunicode() in Python/bltinmodule.c
author: Walter Dörwald <walter@livinglogic.de>
Tue, 4 Feb 2003 16:28:00 +0000 (16:28 +0000)
committer: Walter Dörwald <walter@livinglogic.de>
Tue, 4 Feb 2003 16:28:00 +0000 (16:28 +0000)
blindly assumed that tp_as_sequence->sq_item always returns
a str or unicode object. This might fail with str or unicode
subclasses.

This patch checks whether the object returned from __getitem__
is a str/unicode object and raises a TypeError if not (and
the filter function returned true).

Furthermore, the result of __getitem__ can be more than one
character long, so checks for enough memory have to be done.

Lib/test/test_builtin.py
Python/bltinmodule.c

index 55ea8d24faa050ff65da81e43fdf5c4002b1f7c0..6e13050e53cccf5551fdb5964ee31da3f6b0ab17 100644 (file)
@@ -367,6 +367,16 @@ class BuiltinTest(unittest.TestCase):
                 raise ValueError
         self.assertRaises(ValueError, filter, lambda x: x >="3", badstr("1234"))
 
+        class badstr2(str):
+            def __getitem__(self, index):
+                return 42
+        self.assertRaises(TypeError, filter, lambda x: x >=42, badstr2("1234"))
+
+        class weirdstr(str):
+            def __getitem__(self, index):
+                return weirdstr(2*str.__getitem__(self, index))
+        self.assertEqual(filter(lambda x: x>="33", weirdstr("1234")), "3344")
+
         if have_unicode:
             # test bltinmodule.c::filterunicode()
             self.assertEqual(filter(None, unicode("12")), unicode("12"))
@@ -374,6 +384,17 @@ class BuiltinTest(unittest.TestCase):
             self.assertRaises(TypeError, filter, 42, unicode("12"))
             self.assertRaises(ValueError, filter, lambda x: x >="3", badstr(unicode("1234")))
 
+            class badunicode(unicode):
+                def __getitem__(self, index):
+                    return 42
+            self.assertRaises(TypeError, filter, lambda x: x >=42, badunicode("1234"))
+
+            class weirdunicode(unicode):
+                def __getitem__(self, index):
+                    return weirdunicode(2*unicode.__getitem__(self, index))
+            self.assertEqual(
+                filter(lambda x: x>=unicode("33"), weirdunicode("1234")), unicode("3344"))
+
     def test_float(self):
         self.assertEqual(float(3.14), 3.14)
         self.assertEqual(float(314), 314.0)
index 466fab9a3c26884dbe0583cb9473b081fc47c92d..c273012aa11c3a0b29073a7596c03ce4c2d17cf2 100644 (file)
@@ -1892,6 +1892,7 @@ filterstring(PyObject *func, PyObject *strobj)
        PyObject *result;
        register int i, j;
        int len = PyString_Size(strobj);
+       int outlen = len;
 
        if (func == Py_None) {
                /* No character is ever false -- share input string */
@@ -1921,13 +1922,43 @@ filterstring(PyObject *func, PyObject *strobj)
                }
                ok = PyObject_IsTrue(good);
                Py_DECREF(good);
-               if (ok)
-                       PyString_AS_STRING((PyStringObject *)result)[j++] =
-                               PyString_AS_STRING((PyStringObject *)item)[0];
+               if (ok) {
+                       int reslen;
+                       if (!PyString_Check(item)) {
+                               PyErr_SetString(PyExc_TypeError, "can't filter str to str:"
+                                       " __getitem__ returned different type");
+                               Py_DECREF(item);
+                               goto Fail_1;
+                       }
+                       reslen = PyString_GET_SIZE(item);
+                       if (reslen == 1) {
+                               PyString_AS_STRING(result)[j++] =
+                                       PyString_AS_STRING(item)[0];
+                       } else {
+                               /* do we need more space? */
+                               int need = j + reslen + len-i-1;
+                               if (need > outlen) {
+                                       /* overallocate, to avoid reallocations */
+                                       if (need<2*outlen)
+                                               need = 2*outlen;
+                                       if (_PyString_Resize(&result, need)) {
+                                               Py_DECREF(item);
+                                               return NULL;
+                                       }
+                                       outlen = need;
+                               }
+                               memcpy(
+                                       PyString_AS_STRING(result) + j,
+                                       PyString_AS_STRING(item),
+                                       reslen
+                               );
+                               j += reslen;
+                       }
+               }
                Py_DECREF(item);
        }
 
-       if (j < len)
+       if (j < outlen)
                _PyString_Resize(&result, j);
 
        return result;
@@ -1946,6 +1977,7 @@ filterunicode(PyObject *func, PyObject *strobj)
        PyObject *result;
        register int i, j;
        int len = PyUnicode_GetSize(strobj);
+       int outlen = len;
 
        if (func == Py_None) {
                /* No character is ever false -- share input string */
@@ -1975,13 +2007,43 @@ filterunicode(PyObject *func, PyObject *strobj)
                }
                ok = PyObject_IsTrue(good);
                Py_DECREF(good);
-               if (ok)
-                       PyUnicode_AS_UNICODE((PyStringObject *)result)[j++] =
-                               PyUnicode_AS_UNICODE((PyStringObject *)item)[0];
+               if (ok) {
+                       int reslen;
+                       if (!PyUnicode_Check(item)) {
+                               PyErr_SetString(PyExc_TypeError, "can't filter unicode to unicode:"
+                                       " __getitem__ returned different type");
+                               Py_DECREF(item);
+                               goto Fail_1;
+                       }
+                       reslen = PyUnicode_GET_SIZE(item);
+                       if (reslen == 1) {
+                               PyUnicode_AS_UNICODE(result)[j++] =
+                                       PyUnicode_AS_UNICODE(item)[0];
+                       } else {
+                               /* do we need more space? */
+                               int need = j + reslen + len-i-1;
+                               if (need > outlen) {
+                                       /* overallocate, to avoid reallocations */
+                                       if (need<2*outlen)
+                                               need = 2*outlen;
+                                       if (PyUnicode_Resize(&result, need)) {
+                                               Py_DECREF(item);
+                                               return NULL;
+                                       }
+                                       outlen = need;
+                               }
+                               memcpy(
+                                       PyUnicode_AS_UNICODE(result) + j,
+                                       PyUnicode_AS_UNICODE(item),
+                                       reslen*sizeof(Py_UNICODE)
+                               );
+                               j += reslen;
+                       }
+               }
                Py_DECREF(item);
        }
 
-       if (j < len)
+       if (j < outlen)
                PyUnicode_Resize(&result, j);
 
        return result;