]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Make unicode.join() work nice with iterators. This also required a change
authorTim Peters <tim.peters@gmail.com>
Sat, 5 May 2001 05:36:48 +0000 (05:36 +0000)
committerTim Peters <tim.peters@gmail.com>
Sat, 5 May 2001 05:36:48 +0000 (05:36 +0000)
to string.join(), so that when the latter figures out in midstream that
it really needs unicode.join() instead, unicode.join() can actually get
all the sequence elements (i.e., there's no guarantee that the sequence
passed to string.join() can be iterated over *again* by unicode.join(),
so string.join() must not pass on the original sequence object anymore).

Lib/test/test_iter.py
Misc/NEWS
Objects/stringobject.c
Objects/unicodeobject.c

index bfe032fc52c0a757a089e9086b5c40275b97e7c4..073ffb452e8826a4288904871524f5adcd188af8 100644 (file)
@@ -431,4 +431,45 @@ class TestCase(unittest.TestCase):
         d = {"one": 1, "two": 2, "three": 3}
         self.assertEqual(reduce(add, d), "".join(d.keys()))
 
+    def test_unicode_join_endcase(self):
+
+        # This class inserts a Unicode object into its argument's natural
+        # iteration, in the 3rd position.
+        class OhPhooey:
+            def __init__(self, seq):
+                self.it = iter(seq)
+                self.i = 0
+
+            def __iter__(self):
+                return self
+
+            def next(self):
+                i = self.i
+                self.i = i+1
+                if i == 2:
+                    return u"fooled you!"
+                return self.it.next()
+
+        f = open(TESTFN, "w")
+        try:
+            f.write("a\n" + "b\n" + "c\n")
+        finally:
+            f.close()
+
+        f = open(TESTFN, "r")
+        # Nasty:  string.join(s) can't know whether unicode.join() is needed
+        # until it's seen all of s's elements.  But in this case, f's
+        # iterator cannot be restarted.  So what we're testing here is
+        # whether string.join() can manage to remember everything it's seen
+        # and pass that on to unicode.join().
+        try:
+            got = " - ".join(OhPhooey(f))
+            self.assertEqual(got, u"a\n - b\n - fooled you! - c\n")
+        finally:
+            f.close()
+            try:
+                unlink(TESTFN)
+            except OSError:
+                pass
+
 run_unittest(TestCase)
index 0d7857f1a5b1627c823f16195af719ac58b242ad..d556afaf49789fad6c07e4587aa73bb984b86a39 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -25,7 +25,7 @@ Core
     reduce()
     string.join()
     tuple()
-    XXX TODO unicode.join()
+    unicode.join()
     XXX TODO zip()
     XXX TODO 'x in y'
 
index b9056792f20e699e3f552d4c5d637e8e2a26e9cc..87d7c1957ed119a0ebde7a02c818e9db500fa001 100644 (file)
@@ -861,8 +861,15 @@ string_join(PyStringObject *self, PyObject *args)
                item = PySequence_Fast_GET_ITEM(seq, i);
                if (!PyString_Check(item)){
                        if (PyUnicode_Check(item)) {
+                               /* Defer to Unicode join.
+                                * CAUTION:  There's no gurantee that the
+                                * original sequence can be iterated over
+                                * again, so we must pass seq here.
+                                */
+                               PyObject *result;
+                               result = PyUnicode_Join((PyObject *)self, seq);
                                Py_DECREF(seq);
-                               return PyUnicode_Join((PyObject *)self, orig);
+                               return result;
                        }
                        PyErr_Format(PyExc_TypeError,
                                     "sequence item %i: expected string,"
index e52d628a88a9d2d89cf75456c5096ad832ec8ba2..5da4d2f032efd33fa3d850cff0207780aa5b0a3b 100644 (file)
@@ -2724,10 +2724,11 @@ PyObject *PyUnicode_Join(PyObject *separator,
     int seqlen = 0;
     int sz = 100;
     int i;
+    PyObject *it;
 
-    seqlen = PySequence_Size(seq);
-    if (seqlen < 0 && PyErr_Occurred())
-       return NULL;
+    it = PyObject_GetIter(seq);
+    if (it == NULL)
+        return NULL;
 
     if (separator == NULL) {
        Py_UNICODE blank = ' ';
@@ -2737,7 +2738,7 @@ PyObject *PyUnicode_Join(PyObject *separator,
     else {
        separator = PyUnicode_FromObject(separator);
        if (separator == NULL)
-           return NULL;
+           goto onError;
        sep = PyUnicode_AS_UNICODE(separator);
        seplen = PyUnicode_GET_SIZE(separator);
     }
@@ -2748,13 +2749,14 @@ PyObject *PyUnicode_Join(PyObject *separator,
     p = PyUnicode_AS_UNICODE(res);
     reslen = 0;
 
-    for (i = 0; i < seqlen; i++) {
+    for (i = 0; ; ++i) {
        int itemlen;
-       PyObject *item;
-
-       item = PySequence_GetItem(seq, i);
-       if (item == NULL)
-           goto onError;
+       PyObject *item = PyIter_Next(it);
+       if (item == NULL) {
+           if (PyErr_Occurred())
+               goto onError;
+           break;
+       }
        if (!PyUnicode_Check(item)) {
            PyObject *v;
            v = PyUnicode_FromObject(item);
@@ -2784,11 +2786,13 @@ PyObject *PyUnicode_Join(PyObject *separator,
        goto onError;
 
     Py_XDECREF(separator);
+    Py_DECREF(it);
     return (PyObject *)res;
 
  onError:
     Py_XDECREF(separator);
-    Py_DECREF(res);
+    Py_XDECREF(res);
+    Py_DECREF(it);
     return NULL;
 }