]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add timeout support to simple_httpclient
authorBen Darnell <ben@bendarnell.com>
Sun, 14 Nov 2010 22:45:17 +0000 (14:45 -0800)
committerBen Darnell <ben@bendarnell.com>
Sun, 14 Nov 2010 22:45:17 +0000 (14:45 -0800)
tornado/simple_httpclient.py
tornado/test/simple_httpclient_test.py

index 2c657a8a89e933811cd57ff645f4788f8e08a941..86a6c0dd783d53b291d17c420ad9ca278128b3c9 100644 (file)
@@ -14,6 +14,7 @@ import functools
 import logging
 import re
 import socket
+import time
 import urlparse
 import zlib
 
@@ -63,6 +64,7 @@ class _HTTPConnection(object):
     _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE"])
 
     def __init__(self, io_loop, request, callback):
+        self.start_time = time.time()
         self.io_loop = io_loop
         self.request = request
         self.callback = callback
@@ -70,6 +72,8 @@ class _HTTPConnection(object):
         self.headers = None
         self.chunks = None
         self._decompressor = None
+        # Timeout handle returned by IOLoop.add_timeout
+        self._timeout = None
         with stack_context.StackContext(self.cleanup):
             parsed = urlparse.urlsplit(self.request.url)
             if ":" in parsed.netloc:
@@ -86,10 +90,30 @@ class _HTTPConnection(object):
             else:
                 self.stream = IOStream(socket.socket(),
                                        io_loop=self.io_loop)
+            timeout = min(request.connect_timeout, request.request_timeout)
+            if timeout:
+                self._connect_timeout = self.io_loop.add_timeout(
+                    self.start_time + timeout,
+                    self._on_timeout)
             self.stream.connect((host, port),
                                 functools.partial(self._on_connect, parsed))
 
+    def _on_timeout(self):
+        self._timeout = None
+        self.stream.close()
+        if self.callback is not None:
+            self.callback(HTTPResponse(self.request, 599,
+                                       error=HTTPError(599, "Timeout")))
+            self.callback = None
+
     def _on_connect(self, parsed):
+        if self._timeout is not None:
+            self.io_loop.remove_callback(self._timeout)
+            self._timeout = None
+        if self.request.request_timeout:
+            self._timeout = self.io_loop.add_timeout(
+                self.start_time + self.request.request_timeout,
+                self._on_timeout)
         if (self.request.method not in self._SUPPORTED_METHODS and
             not self.request.allow_nonstandard_methods):
             raise KeyError("unknown method %s" % self.request.method)
@@ -167,6 +191,9 @@ class _HTTPConnection(object):
                             "don't know how to read %s", self.request.url)
 
     def _on_body(self, data):
+        if self._timeout is not None:
+            self.io_loop.remove_timeout(self._timeout)
+            self._timeout = None
         if self._decompressor:
             data = self._decompressor.decompress(data)
         if self.request.streaming_callback:
index d72f01cc14c65148e391c65ba6e437433e186175..c99e547af9c6cbfcf5d04675925bd68f4911ab37 100644 (file)
@@ -2,10 +2,12 @@
 
 import gzip
 import logging
+import socket
 
+from contextlib import closing
 from tornado.simple_httpclient import SimpleAsyncHTTPClient
-from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase
-from tornado.web import Application, RequestHandler
+from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, get_unused_port
+from tornado.web import Application, RequestHandler, asynchronous
 
 class HelloWorldHandler(RequestHandler):
     def get(self):
@@ -28,6 +30,11 @@ class AuthHandler(RequestHandler):
     def get(self):
         self.finish(self.request.headers["Authorization"])
 
+class HangHandler(RequestHandler):
+    @asynchronous
+    def get(self):
+        pass
+
 class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
     def get_app(self):
         return Application([
@@ -35,6 +42,7 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
             ("/post", PostHandler),
             ("/chunk", ChunkHandler),
             ("/auth", AuthHandler),
+            ("/hang", HangHandler),
             ], gzip=True)
 
     def setUp(self):
@@ -94,3 +102,23 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
         self.assertEqual(len(response.body), 34)
         f = gzip.GzipFile(mode="r", fileobj=response.buffer)
         self.assertEqual(f.read(), "asdfqwer")
+
+    def test_connect_timeout(self):
+        # create a socket and bind it to a port, but don't
+        # call accept so the connection will timeout.
+        #get_unused_port()
+        port = get_unused_port()
+
+        with closing(socket.socket()) as sock:
+            sock.bind(('', port))
+            self.http_client.fetch("http://localhost:%d/" % port,
+                                   self.stop,
+                                   connect_timeout=0.1)
+            response = self.wait()
+            self.assertEqual(response.code, 599)
+            self.assertEqual(str(response.error), "HTTP 599: Timeout")
+
+    def test_request_timeout(self):
+        response = self.fetch('/hang', request_timeout=0.1)
+        self.assertEqual(response.code, 599)
+        self.assertEqual(str(response.error), "HTTP 599: Timeout")