]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Make simple_httpclient's hostname_mapping a Resolver wrapper.
authorBen Darnell <ben@bendarnell.com>
Mon, 18 Feb 2013 00:08:10 +0000 (19:08 -0500)
committerBen Darnell <ben@bendarnell.com>
Mon, 18 Feb 2013 00:08:10 +0000 (19:08 -0500)
tornado/netutil.py
tornado/simple_httpclient.py
tornado/test/simple_httpclient_test.py

index fd01c91af0f7d57488005d3eaef18f5fc3921559..05d8468ba17d5a2fda19a057998ffb33c8cbb8ba 100644 (file)
@@ -141,7 +141,23 @@ def add_accept_handler(sock, callback, io_loop=None):
     io_loop.add_handler(sock.fileno(), accept_handler, IOLoop.READ)
 
 
-class Resolver(object):
+class BaseResolver(object):
+    def getaddrinfo(self, *args, **kwargs):
+        """Resolves an address.
+
+        The arguments to this function are the same as to
+        `socket.getaddrinfo`, with the addition of an optional
+        keyword-only ``callback`` argument.
+
+        Returns a `Future` whose result is the same as the return
+        value of `socket.getaddrinfo`.  If a callback is passed,
+        it will be run with the `Future` as an argument when it
+        is complete.
+        """
+        raise NotImplementedError()
+
+
+class Resolver(BaseResolver):
     def __init__(self, io_loop=None, executor=None):
         self.io_loop = io_loop or IOLoop.instance()
         self.executor = executor or dummy_executor
@@ -150,6 +166,26 @@ class Resolver(object):
     def getaddrinfo(self, *args, **kwargs):
         return socket.getaddrinfo(*args, **kwargs)
 
+class OverrideResolver(BaseResolver):
+    """Wraps a resolver with a mapping of overrides.
+
+    This can be used to make local DNS changes (e.g. for testing)
+    without modifying system-wide settings.
+
+    The mapping can contain either host strings or host-port pairs.
+    """
+    def __init__(self, resolver, mapping):
+        self.resolver = resolver
+        self.mapping = mapping
+
+    def getaddrinfo(self, host, port, *args, **kwargs):
+        if (host, port) in self.mapping:
+            host, port = self.mapping[(host, port)]
+        elif host in self.mapping:
+            host = self.mapping[host]
+        return self.resolver.getaddrinfo(host, port, *args, **kwargs)
+
+
 
 # These are the keyword arguments to ssl.wrap_socket that must be translated
 # to their SSLContext equivalents (the other arguments are still passed
index a2d18a3392e703118947d618bde0ba322427db70..f4d05dbdcb356dbaf457b3eb2f9c745fcee92f7d 100644 (file)
@@ -5,7 +5,7 @@ from tornado.escape import utf8, _unicode, native_str
 from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError, AsyncHTTPClient, main, _RequestProxy
 from tornado.httputil import HTTPHeaders
 from tornado.iostream import IOStream, SSLIOStream
-from tornado.netutil import Resolver
+from tornado.netutil import Resolver, OverrideResolver
 from tornado.log import gen_log
 from tornado import stack_context
 from tornado.util import GzipDecompressor
@@ -72,9 +72,10 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
         self.max_clients = max_clients
         self.queue = collections.deque()
         self.active = {}
-        self.hostname_mapping = hostname_mapping
         self.max_buffer_size = max_buffer_size
         self.resolver = resolver or Resolver(io_loop=io_loop)
+        if hostname_mapping is not None:
+            self.resolver = OverrideResolver(self.resolver, hostname_mapping)
 
     def fetch_impl(self, request, callback):
         callback = stack_context.wrap(callback)
@@ -140,8 +141,6 @@ class _HTTPConnection(object):
                 # raw ipv6 addresses in urls are enclosed in brackets
                 host = host[1:-1]
             self.parsed_hostname = host  # save final host for _on_connect
-            if self.client.hostname_mapping is not None:
-                host = self.client.hostname_mapping.get(host, host)
 
             if request.allow_ipv6:
                 af = socket.AF_UNSPEC
index a729fb5a4232823252507bbc314c9366386ae311..0c243e94ca897fe01fc35b3fbd10c2a97fb66eac 100644 (file)
@@ -363,7 +363,11 @@ class HostnameMappingTestCase(AsyncHTTPTestCase):
     def setUp(self):
         super(HostnameMappingTestCase, self).setUp()
         self.http_client = SimpleAsyncHTTPClient(
-            self.io_loop, hostname_mapping={'www.example.com': '127.0.0.1'})
+            self.io_loop,
+            hostname_mapping={
+                'www.example.com': '127.0.0.1',
+                ('foo.example.com', 8000): ('127.0.0.1', self.get_http_port()),
+                })
 
     def get_app(self):
         return Application([url("/hello", HelloWorldHandler),])
@@ -374,3 +378,9 @@ class HostnameMappingTestCase(AsyncHTTPTestCase):
         response = self.wait()
         response.rethrow()
         self.assertEqual(response.body, b'Hello world!')
+
+    def test_port_mapping(self):
+        self.http_client.fetch('http://foo.example.com:8000/hello', self.stop)
+        response = self.wait()
+        response.rethrow()
+        self.assertEqual(response.body, b'Hello world!')