]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add follow_redirects support to SimpleAsyncHTTPClient.
authorBen Darnell <ben@bendarnell.com>
Tue, 15 Feb 2011 04:51:19 +0000 (20:51 -0800)
committerBen Darnell <ben@bendarnell.com>
Tue, 15 Feb 2011 04:51:19 +0000 (20:51 -0800)
tornado/simple_httpclient.py
tornado/test/simple_httpclient_test.py

index 7f4f64979032179efbe10a5e715b3f5cbc8dd688..2be988f422d2811ebe7c7cf8e0e5950e932af5f5 100644 (file)
@@ -10,6 +10,7 @@ from tornado import stack_context
 
 import collections
 import contextlib
+import copy
 import errno
 import functools
 import logging
@@ -279,8 +280,23 @@ class _HTTPConnection(object):
             buffer = StringIO()
         else:
             buffer = StringIO(data) # TODO: don't require one big string?
-        response = HTTPResponse(self.request, self.code, headers=self.headers,
-                                buffer=buffer)
+        original_request = getattr(self.request, "original_request",
+                                   self.request)
+        if (self.request.follow_redirects and
+            self.request.max_redirects > 0 and
+            self.code in (301, 302)):
+            new_request = copy.copy(self.request)
+            new_request.url = urlparse.urljoin(self.request.url,
+                                               self.headers["Location"])
+            new_request.max_redirects -= 1
+            new_request.original_request = original_request
+            self.client.fetch(new_request, self.callback)
+            self.callback = None
+            return
+        response = HTTPResponse(original_request,
+                                self.code, headers=self.headers,
+                                buffer=buffer,
+                                effective_url=self.request.url)
         self.callback(response)
         self.callback = None
 
index 479024ece64ab0853e04186bd2c998e5542971ae..2abdaecfe234bb054eb2d78d2c8566ac1d0bc3a4 100644 (file)
@@ -11,7 +11,7 @@ from contextlib import closing
 from tornado.ioloop import IOLoop
 from tornado.simple_httpclient import SimpleAsyncHTTPClient
 from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, get_unused_port
-from tornado.web import Application, RequestHandler, asynchronous
+from tornado.web import Application, RequestHandler, asynchronous, url
 
 class HelloWorldHandler(RequestHandler):
     def get(self):
@@ -50,18 +50,27 @@ class TriggerHandler(RequestHandler):
         self.queue.append(self.finish)
         self.wake_callback()
 
+class CountdownHandler(RequestHandler):
+    def get(self, count):
+        count = int(count)
+        if count > 0:
+            self.redirect(self.reverse_url("countdown", count - 1))
+        else:
+            self.write("Zero")
+
 class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
     def get_app(self):
         # callable objects to finish pending /trigger requests
         self.triggers = collections.deque()
         return Application([
-            ("/hello", HelloWorldHandler),
-            ("/post", PostHandler),
-            ("/chunk", ChunkHandler),
-            ("/auth", AuthHandler),
-            ("/hang", HangHandler),
-            ("/trigger", TriggerHandler, dict(queue=self.triggers,
-                                              wake_callback=self.stop)),
+            url("/hello", HelloWorldHandler),
+            url("/post", PostHandler),
+            url("/chunk", ChunkHandler),
+            url("/auth", AuthHandler),
+            url("/hang", HangHandler),
+            url("/trigger", TriggerHandler, dict(queue=self.triggers,
+                                                 wake_callback=self.stop)),
+            url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
             ], gzip=True)
 
     def setUp(self):
@@ -176,3 +185,22 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
         self.assertEqual(seen, [0, 1])
         self.assertEqual(len(client.queue), 0)
 
+    def test_follow_redirect(self):
+        response = self.fetch("/countdown/2", follow_redirects=False)
+        self.assertEqual(302, response.code)
+        self.assertTrue(response.headers["Location"].endswith("/countdown/1"))
+
+        response = self.fetch("/countdown/2")
+        self.assertEqual(200, response.code)
+        self.assertTrue(response.effective_url.endswith("/countdown/0"))
+        self.assertEqual("Zero", response.body)
+
+    def test_max_redirects(self):
+        response = self.fetch("/countdown/5", max_redirects=3)
+        self.assertEqual(302, response.code)
+        # We requested 5, followed three redirects for 4, 3, 2, then the last
+        # unfollowed redirect is to 1.
+        self.assertTrue(response.request.url.endswith("/countdown/5"))
+        self.assertTrue(response.effective_url.endswith("/countdown/2"))
+        self.assertTrue(response.headers["Location"].endswith("/countdown/1"))
+