io_loop.add_handler(sock.fileno(), accept_handler, IOLoop.READ)
-class Resolver(object):
+class BaseResolver(object):
+ def getaddrinfo(self, *args, **kwargs):
+ """Resolves an address.
+
+ The arguments to this function are the same as to
+ `socket.getaddrinfo`, with the addition of an optional
+ keyword-only ``callback`` argument.
+
+ Returns a `Future` whose result is the same as the return
+ value of `socket.getaddrinfo`. If a callback is passed,
+ it will be run with the `Future` as an argument when it
+ is complete.
+ """
+ raise NotImplementedError()
+
+
+class Resolver(BaseResolver):
def __init__(self, io_loop=None, executor=None):
self.io_loop = io_loop or IOLoop.instance()
self.executor = executor or dummy_executor
def getaddrinfo(self, *args, **kwargs):
return socket.getaddrinfo(*args, **kwargs)
+class OverrideResolver(BaseResolver):
+ """Wraps a resolver with a mapping of overrides.
+
+ This can be used to make local DNS changes (e.g. for testing)
+ without modifying system-wide settings.
+
+ The mapping can contain either host strings or host-port pairs.
+ """
+ def __init__(self, resolver, mapping):
+ self.resolver = resolver
+ self.mapping = mapping
+
+ def getaddrinfo(self, host, port, *args, **kwargs):
+ if (host, port) in self.mapping:
+ host, port = self.mapping[(host, port)]
+ elif host in self.mapping:
+ host = self.mapping[host]
+ return self.resolver.getaddrinfo(host, port, *args, **kwargs)
+
+
# These are the keyword arguments to ssl.wrap_socket that must be translated
# to their SSLContext equivalents (the other arguments are still passed
from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError, AsyncHTTPClient, main, _RequestProxy
from tornado.httputil import HTTPHeaders
from tornado.iostream import IOStream, SSLIOStream
-from tornado.netutil import Resolver
+from tornado.netutil import Resolver, OverrideResolver
from tornado.log import gen_log
from tornado import stack_context
from tornado.util import GzipDecompressor
self.max_clients = max_clients
self.queue = collections.deque()
self.active = {}
- self.hostname_mapping = hostname_mapping
self.max_buffer_size = max_buffer_size
self.resolver = resolver or Resolver(io_loop=io_loop)
+ if hostname_mapping is not None:
+ self.resolver = OverrideResolver(self.resolver, hostname_mapping)
def fetch_impl(self, request, callback):
callback = stack_context.wrap(callback)
# raw ipv6 addresses in urls are enclosed in brackets
host = host[1:-1]
self.parsed_hostname = host # save final host for _on_connect
- if self.client.hostname_mapping is not None:
- host = self.client.hostname_mapping.get(host, host)
if request.allow_ipv6:
af = socket.AF_UNSPEC
def setUp(self):
super(HostnameMappingTestCase, self).setUp()
self.http_client = SimpleAsyncHTTPClient(
- self.io_loop, hostname_mapping={'www.example.com': '127.0.0.1'})
+ self.io_loop,
+ hostname_mapping={
+ 'www.example.com': '127.0.0.1',
+ ('foo.example.com', 8000): ('127.0.0.1', self.get_http_port()),
+ })
def get_app(self):
return Application([url("/hello", HelloWorldHandler),])
response = self.wait()
response.rethrow()
self.assertEqual(response.body, b'Hello world!')
+
+ def test_port_mapping(self):
+ self.http_client.fetch('http://foo.example.com:8000/hello', self.stop)
+ response = self.wait()
+ response.rethrow()
+ self.assertEqual(response.body, b'Hello world!')