]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add a simple mechanism to override DNS lookups in SimpleAsyncHTTPClient.
authorBen Darnell <ben@bendarnell.com>
Mon, 6 Dec 2010 01:07:46 +0000 (17:07 -0800)
committerBen Darnell <ben@bendarnell.com>
Mon, 14 Feb 2011 23:40:32 +0000 (15:40 -0800)
Intended for use in SSL unittests, where we will need to make requests to
localhost using different domain names.

tornado/simple_httpclient.py

index 5819a0cd170d9fa94de07daa6c010a5df27927af..8fa4e7e28f19a9e33dbf7fba7b9b2c7f2d5d5008 100644 (file)
@@ -49,7 +49,8 @@ class SimpleAsyncHTTPClient(object):
 
     def __new__(cls, io_loop=None, max_clients=10,
                 max_simultaneous_connections=None,
-                force_instance=False):
+                force_instance=False,
+                hostname_mapping=None):
         """Creates a SimpleAsyncHTTPClient.
 
         Only a single SimpleAsyncHTTPClient instance exists per IOLoop
@@ -61,6 +62,11 @@ class SimpleAsyncHTTPClient(object):
         only for compatibility with the curl-based AsyncHTTPClient.  Note
         that these arguments are only used when the client is first created,
         and will be ignored when an existing client is reused.
+
+        hostname_mapping is a dictionary mapping hostnames to IP addresses.
+        It can be used to make local DNS changes when modifying system-wide
+        settings like /etc/hosts is not possible or desirable (e.g. in
+        unittests).
         """
         io_loop = io_loop or IOLoop.instance()
         if io_loop in cls._ASYNC_CLIENTS and not force_instance:
@@ -71,6 +77,7 @@ class SimpleAsyncHTTPClient(object):
             instance.max_clients = max_clients
             instance.queue = collections.deque()
             instance.active = {}
+            instance.hostname_mapping = hostname_mapping
             if not force_instance:
                 cls._ASYNC_CLIENTS[io_loop] = instance
             return instance
@@ -97,7 +104,7 @@ class SimpleAsyncHTTPClient(object):
                 request, callback = self.queue.popleft()
                 key = object()
                 self.active[key] = (request, callback)
-                _HTTPConnection(self.io_loop, request,
+                _HTTPConnection(self.io_loop, self, request,
                                 functools.partial(self._on_fetch_complete,
                                                   key, callback))
 
@@ -111,9 +118,10 @@ class SimpleAsyncHTTPClient(object):
 class _HTTPConnection(object):
     _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE"])
 
-    def __init__(self, io_loop, request, callback):
+    def __init__(self, io_loop, client, request, callback):
         self.start_time = time.time()
         self.io_loop = io_loop
+        self.client = client
         self.request = request
         self.callback = callback
         self.code = None
@@ -130,6 +138,8 @@ class _HTTPConnection(object):
             else:
                 host = parsed.netloc
                 port = 443 if parsed.scheme == "https" else 80
+            if self.client.hostname_mapping is not None:
+                host = self.client.hostname_mapping.get(host, host)
 
             if parsed.scheme == "https":
                 # TODO: cert verification, etc