]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
PyUnicode_Join(): Two primary aims:
authorTim Peters <tim.peters@gmail.com>
Fri, 27 Aug 2004 01:49:32 +0000 (01:49 +0000)
committerTim Peters <tim.peters@gmail.com>
Fri, 27 Aug 2004 01:49:32 +0000 (01:49 +0000)
1. u1.join([u2]) is u2
2. Be more careful about C-level int overflow.

Since PySequence_Fast() isn't needed to achieve #1, it's not used -- but
the code could sure be simpler if it were.

Objects/unicodeobject.c

index 45fb966281a4f4560a59756b4ede30126c6ce17f..19b8c283a5331af69e22be2c0d4b45038daaf1d3 100644 (file)
@@ -3975,49 +3975,110 @@ int fixtitle(PyUnicodeObject *self)
     return 1;
 }
 
-PyObject *PyUnicode_Join(PyObject *separator,
-                        PyObject *seq)
+PyObject *
+PyUnicode_Join(PyObject *separator, PyObject *seq)
 {
+    PyObject *internal_separator = NULL;
     Py_UNICODE *sep;
-    int seplen;
+    size_t seplen;
     PyUnicodeObject *res = NULL;
-    int reslen = 0;
-    Py_UNICODE *p;
-    int sz = 100;
+    size_t sz;      /* # allocated bytes for string in res */
+    size_t reslen;  /* # used bytes */
+    Py_UNICODE *p;  /* pointer to free byte in res's string area */
+    PyObject *it;   /* iterator */
+    PyObject *item;
     int i;
-    PyObject *it;
+    PyObject *temp;
 
     it = PyObject_GetIter(seq);
     if (it == NULL)
         return NULL;
 
+    item = PyIter_Next(it);
+    if (item == NULL) {
+        if (PyErr_Occurred())
+            goto onError;
+        /* empty sequence; return u"" */
+        res = _PyUnicode_New(0);
+        goto Done;
+    }
+
+    /* If this is the only item, maybe we can get out cheap. */
+    res = (PyUnicodeObject *)item;
+    item = PyIter_Next(it);
+    if (item == NULL) {
+        if (PyErr_Occurred())
+            goto onError;
+        /* There's only one item in the sequence. */
+        if (PyUnicode_CheckExact(res)) /* whatever.join([u]) -> u */
+            goto Done;
+    }
+
+    /* There are at least two to join (item != NULL), or there's only
+     * one but it's not an exact Unicode (item == NULL).  res needs
+     * conversion to Unicode in either case.
+     * Caution:  we may need to ensure a copy is made, and that's trickier
+     * than it sounds because, e.g., PyUnicode_FromObject() may return
+     * a shared object (which must not be mutated).
+     */
+    if (! PyUnicode_Check(res) && ! PyString_Check(res)) {
+        PyErr_Format(PyExc_TypeError,
+                "sequence item 0: expected string or Unicode,"
+               " %.80s found",
+              res->ob_type->tp_name);
+       Py_XDECREF(item);
+        goto onError;
+    }
+    temp = PyUnicode_FromObject((PyObject *)res);
+    if (temp == NULL) {
+        Py_XDECREF(item);
+        goto onError;
+    }
+    Py_DECREF(res);
+    if (item == NULL) {
+       /* res was the only item */
+        res = (PyUnicodeObject *)temp;
+        goto Done;
+    }
+    /* There are at least two items.  As above, temp may be a shared object,
+     * so we need to copy it.
+     */
+    reslen = PyUnicode_GET_SIZE(temp);
+    sz = reslen + 100;  /* breathing room */
+    if (sz < reslen || sz > INT_MAX) /* overflow -- no breathing room */
+       sz = reslen;
+    res = _PyUnicode_New(sz);
+    if (res == NULL) {
+        Py_DECREF(item);
+        goto onError;
+    }
+    p = PyUnicode_AS_UNICODE(res);
+    Py_UNICODE_COPY(p, PyUnicode_AS_UNICODE(temp), (int)reslen);
+    p += reslen;
+    Py_DECREF(temp);
+
     if (separator == NULL) {
        Py_UNICODE blank = ' ';
        sep = &blank;
        seplen = 1;
     }
     else {
-       separator = PyUnicode_FromObject(separator);
-       if (separator == NULL)
+       internal_separator = PyUnicode_FromObject(separator);
+       if (internal_separator == NULL) {
+           Py_DECREF(item);
            goto onError;
-       sep = PyUnicode_AS_UNICODE(separator);
-       seplen = PyUnicode_GET_SIZE(separator);
+       }
+       sep = PyUnicode_AS_UNICODE(internal_separator);
+       seplen = PyUnicode_GET_SIZE(internal_separator);
     }
 
-    res = _PyUnicode_New(sz);
-    if (res == NULL)
-       goto onError;
-    p = PyUnicode_AS_UNICODE(res);
-    reslen = 0;
+    i = 1;
+    do {
+       size_t itemlen;
+       size_t newreslen;
 
-    for (i = 0; ; ++i) {
-       int itemlen;
-       PyObject *item = PyIter_Next(it);
-       if (item == NULL) {
-           if (PyErr_Occurred())
-               goto onError;
-           break;
-       }
+       /* Catenate the separator, then item. */
+       /* First convert item to Unicode. */
        if (!PyUnicode_Check(item)) {
            PyObject *v;
            if (!PyString_Check(item)) {
@@ -4034,36 +4095,55 @@ PyObject *PyUnicode_Join(PyObject *separator,
            if (item == NULL)
                goto onError;
        }
+        /* Make sure we have enough space for the separator and the item. */
        itemlen = PyUnicode_GET_SIZE(item);
-       while (reslen + itemlen + seplen >= sz) {
-           if (_PyUnicode_Resize(&res, sz*2) < 0) {
+       newreslen = reslen + seplen + itemlen;
+       if (newreslen < reslen ||  newreslen > INT_MAX)
+           goto Overflow;
+       if (newreslen > sz) {
+           do {
+               size_t oldsize = sz;
+               sz += sz;
+               if (sz < oldsize || sz > INT_MAX)
+                   goto Overflow;
+           } while (newreslen > sz);
+           if (_PyUnicode_Resize(&res, (int)sz) < 0) {
                Py_DECREF(item);
                goto onError;
            }
-           sz *= 2;
-           p = PyUnicode_AS_UNICODE(res) + reslen;
-       }
-       if (i > 0) {
-           Py_UNICODE_COPY(p, sep, seplen);
-           p += seplen;
-           reslen += seplen;
+            p = PyUnicode_AS_UNICODE(res) + reslen;
        }
-       Py_UNICODE_COPY(p, PyUnicode_AS_UNICODE(item), itemlen);
+       Py_UNICODE_COPY(p, sep, (int)seplen);
+       p += seplen;
+       Py_UNICODE_COPY(p, PyUnicode_AS_UNICODE(item), (int)itemlen);
        p += itemlen;
-       reslen += itemlen;
        Py_DECREF(item);
-    }
-    if (_PyUnicode_Resize(&res, reslen) < 0)
+       reslen = newreslen;
+
+        ++i;
+       item = PyIter_Next(it);
+    } while (item != NULL);
+    if (PyErr_Occurred())
        goto onError;
 
-    Py_XDECREF(separator);
+    if (_PyUnicode_Resize(&res, (int)reslen) < 0)
+       goto onError;
+
+ Done:
+    Py_XDECREF(internal_separator);
     Py_DECREF(it);
     return (PyObject *)res;
 
+ Overflow:
+    PyErr_SetString(PyExc_OverflowError,
+                    "join() is too long for a Python string");
+    Py_DECREF(item);
+    /* fall through */
+
  onError:
-    Py_XDECREF(separator);
-    Py_XDECREF(res);
+    Py_XDECREF(internal_separator);
     Py_DECREF(it);
+    Py_XDECREF(res);
     return NULL;
 }