From: Ben Darnell Date: Tue, 15 Feb 2011 04:51:19 +0000 (-0800) Subject: Add follow_redirects support to SimpleAsyncHTTPClient. X-Git-Tag: v1.2.0~22 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6ac2c58db792953f846f18d417343980f5eb1d48;p=thirdparty%2Ftornado.git Add follow_redirects support to SimpleAsyncHTTPClient. --- diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py index 7f4f64979..2be988f42 100644 --- a/tornado/simple_httpclient.py +++ b/tornado/simple_httpclient.py @@ -10,6 +10,7 @@ from tornado import stack_context import collections import contextlib +import copy import errno import functools import logging @@ -279,8 +280,23 @@ class _HTTPConnection(object): buffer = StringIO() else: buffer = StringIO(data) # TODO: don't require one big string? - response = HTTPResponse(self.request, self.code, headers=self.headers, - buffer=buffer) + original_request = getattr(self.request, "original_request", + self.request) + if (self.request.follow_redirects and + self.request.max_redirects > 0 and + self.code in (301, 302)): + new_request = copy.copy(self.request) + new_request.url = urlparse.urljoin(self.request.url, + self.headers["Location"]) + new_request.max_redirects -= 1 + new_request.original_request = original_request + self.client.fetch(new_request, self.callback) + self.callback = None + return + response = HTTPResponse(original_request, + self.code, headers=self.headers, + buffer=buffer, + effective_url=self.request.url) self.callback(response) self.callback = None diff --git a/tornado/test/simple_httpclient_test.py b/tornado/test/simple_httpclient_test.py index 479024ece..2abdaecfe 100644 --- a/tornado/test/simple_httpclient_test.py +++ b/tornado/test/simple_httpclient_test.py @@ -11,7 +11,7 @@ from contextlib import closing from tornado.ioloop import IOLoop from tornado.simple_httpclient import SimpleAsyncHTTPClient from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, get_unused_port -from tornado.web import Application, RequestHandler, asynchronous +from tornado.web import Application, RequestHandler, asynchronous, url class HelloWorldHandler(RequestHandler): def get(self): @@ -50,18 +50,27 @@ class TriggerHandler(RequestHandler): self.queue.append(self.finish) self.wake_callback() +class CountdownHandler(RequestHandler): + def get(self, count): + count = int(count) + if count > 0: + self.redirect(self.reverse_url("countdown", count - 1)) + else: + self.write("Zero") + class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase): def get_app(self): # callable objects to finish pending /trigger requests self.triggers = collections.deque() return Application([ - ("/hello", HelloWorldHandler), - ("/post", PostHandler), - ("/chunk", ChunkHandler), - ("/auth", AuthHandler), - ("/hang", HangHandler), - ("/trigger", TriggerHandler, dict(queue=self.triggers, - wake_callback=self.stop)), + url("/hello", HelloWorldHandler), + url("/post", PostHandler), + url("/chunk", ChunkHandler), + url("/auth", AuthHandler), + url("/hang", HangHandler), + url("/trigger", TriggerHandler, dict(queue=self.triggers, + wake_callback=self.stop)), + url("/countdown/([0-9]+)", CountdownHandler, name="countdown"), ], gzip=True) def setUp(self): @@ -176,3 +185,22 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase): self.assertEqual(seen, [0, 1]) self.assertEqual(len(client.queue), 0) + def test_follow_redirect(self): + response = self.fetch("/countdown/2", follow_redirects=False) + self.assertEqual(302, response.code) + self.assertTrue(response.headers["Location"].endswith("/countdown/1")) + + response = self.fetch("/countdown/2") + self.assertEqual(200, response.code) + self.assertTrue(response.effective_url.endswith("/countdown/0")) + self.assertEqual("Zero", response.body) + + def test_max_redirects(self): + response = self.fetch("/countdown/5", max_redirects=3) + self.assertEqual(302, response.code) + # We requested 5, followed three redirects for 4, 3, 2, then the last + # unfollowed redirect is to 1. + self.assertTrue(response.request.url.endswith("/countdown/5")) + self.assertTrue(response.effective_url.endswith("/countdown/2")) + self.assertTrue(response.headers["Location"].endswith("/countdown/1")) +