if self.request.expect_100_continue and first_line.code == 100:
self._write_body(False)
return
- self.headers = headers
self.code = first_line.code
self.reason = first_line.reason
+ self.headers = headers
+
+ if self._should_follow_redirect():
+ return
if self.request.header_callback is not None:
# Reassemble the start line.
self.request.header_callback("%s: %s\r\n" % (k, v))
self.request.header_callback('\r\n')
+ def _should_follow_redirect(self):
+ return (self.request.follow_redirects and
+ self.request.max_redirects > 0 and
+ self.code in (301, 302, 303, 307))
+
def finish(self):
data = b''.join(self.chunks)
self._remove_timeout()
original_request = getattr(self.request, "original_request",
self.request)
- if (self.request.follow_redirects and
- self.request.max_redirects > 0 and
- self.code in (301, 302, 303, 307)):
+ if self._should_follow_redirect():
assert isinstance(self.request, _RequestProxy)
new_request = copy.copy(self.request.request)
new_request.url = urlparse.urljoin(self.request.url,
self.stream.close()
def data_received(self, chunk):
+ if self._should_follow_redirect():
+ # We're going to follow a redirect so just discard the body.
+ return
if self.request.streaming_callback is not None:
self.request.streaming_callback(chunk)
else:
import datetime
from io import BytesIO
-from tornado.escape import utf8
+from tornado.escape import utf8, to_unicode
from tornado import gen
from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient
from tornado.httpserver import HTTPServer
class RedirectHandler(RequestHandler):
def prepare(self):
+ self.write('redirects can have bodies too')
self.redirect(self.get_argument("url"),
status=int(self.get_argument("status", "302")))
response.rethrow()
self.assertEqual(response.body, b"Put body: hello")
+ def test_streaming_follow_redirects(self):
+ # When following redirects, header and streaming callbacks
+ # should only be called for the final result.
+ headers = []
+ chunks = []
+ self.fetch("/redirect?url=/hello",
+ header_callback=headers.append,
+ streaming_callback=chunks.append)
+ chunks = list(map(to_unicode, chunks))
+ self.assertEqual(chunks, ['Hello world!'])
+ # Make sure we only got one set of headers.
+ num_start_lines = len([h for h in headers if h.startswith("HTTP/")])
+ self.assertEqual(num_start_lines, 1)
+
+
class RequestProxyTest(unittest.TestCase):
def test_request_set(self):