From: Ben Darnell Date: Mon, 18 Feb 2013 00:08:10 +0000 (-0500) Subject: Make simple_httpclient's hostname_mapping a Resolver wrapper. X-Git-Tag: v3.0.0~111 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=20a3d626096079f938f089e882899d4947ad6fa8;p=thirdparty%2Ftornado.git Make simple_httpclient's hostname_mapping a Resolver wrapper. --- diff --git a/tornado/netutil.py b/tornado/netutil.py index fd01c91af..05d8468ba 100644 --- a/tornado/netutil.py +++ b/tornado/netutil.py @@ -141,7 +141,23 @@ def add_accept_handler(sock, callback, io_loop=None): io_loop.add_handler(sock.fileno(), accept_handler, IOLoop.READ) -class Resolver(object): +class BaseResolver(object): + def getaddrinfo(self, *args, **kwargs): + """Resolves an address. + + The arguments to this function are the same as to + `socket.getaddrinfo`, with the addition of an optional + keyword-only ``callback`` argument. + + Returns a `Future` whose result is the same as the return + value of `socket.getaddrinfo`. If a callback is passed, + it will be run with the `Future` as an argument when it + is complete. + """ + raise NotImplementedError() + + +class Resolver(BaseResolver): def __init__(self, io_loop=None, executor=None): self.io_loop = io_loop or IOLoop.instance() self.executor = executor or dummy_executor @@ -150,6 +166,26 @@ class Resolver(object): def getaddrinfo(self, *args, **kwargs): return socket.getaddrinfo(*args, **kwargs) +class OverrideResolver(BaseResolver): + """Wraps a resolver with a mapping of overrides. + + This can be used to make local DNS changes (e.g. for testing) + without modifying system-wide settings. + + The mapping can contain either host strings or host-port pairs. + """ + def __init__(self, resolver, mapping): + self.resolver = resolver + self.mapping = mapping + + def getaddrinfo(self, host, port, *args, **kwargs): + if (host, port) in self.mapping: + host, port = self.mapping[(host, port)] + elif host in self.mapping: + host = self.mapping[host] + return self.resolver.getaddrinfo(host, port, *args, **kwargs) + + # These are the keyword arguments to ssl.wrap_socket that must be translated # to their SSLContext equivalents (the other arguments are still passed diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py index a2d18a339..f4d05dbdc 100644 --- a/tornado/simple_httpclient.py +++ b/tornado/simple_httpclient.py @@ -5,7 +5,7 @@ from tornado.escape import utf8, _unicode, native_str from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError, AsyncHTTPClient, main, _RequestProxy from tornado.httputil import HTTPHeaders from tornado.iostream import IOStream, SSLIOStream -from tornado.netutil import Resolver +from tornado.netutil import Resolver, OverrideResolver from tornado.log import gen_log from tornado import stack_context from tornado.util import GzipDecompressor @@ -72,9 +72,10 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): self.max_clients = max_clients self.queue = collections.deque() self.active = {} - self.hostname_mapping = hostname_mapping self.max_buffer_size = max_buffer_size self.resolver = resolver or Resolver(io_loop=io_loop) + if hostname_mapping is not None: + self.resolver = OverrideResolver(self.resolver, hostname_mapping) def fetch_impl(self, request, callback): callback = stack_context.wrap(callback) @@ -140,8 +141,6 @@ class _HTTPConnection(object): # raw ipv6 addresses in urls are enclosed in brackets host = host[1:-1] self.parsed_hostname = host # save final host for _on_connect - if self.client.hostname_mapping is not None: - host = self.client.hostname_mapping.get(host, host) if request.allow_ipv6: af = socket.AF_UNSPEC diff --git a/tornado/test/simple_httpclient_test.py b/tornado/test/simple_httpclient_test.py index a729fb5a4..0c243e94c 100644 --- a/tornado/test/simple_httpclient_test.py +++ b/tornado/test/simple_httpclient_test.py @@ -363,7 +363,11 @@ class HostnameMappingTestCase(AsyncHTTPTestCase): def setUp(self): super(HostnameMappingTestCase, self).setUp() self.http_client = SimpleAsyncHTTPClient( - self.io_loop, hostname_mapping={'www.example.com': '127.0.0.1'}) + self.io_loop, + hostname_mapping={ + 'www.example.com': '127.0.0.1', + ('foo.example.com', 8000): ('127.0.0.1', self.get_http_port()), + }) def get_app(self): return Application([url("/hello", HelloWorldHandler),]) @@ -374,3 +378,9 @@ class HostnameMappingTestCase(AsyncHTTPTestCase): response = self.wait() response.rethrow() self.assertEqual(response.body, b'Hello world!') + + def test_port_mapping(self): + self.http_client.fetch('http://foo.example.com:8000/hello', self.stop) + response = self.wait() + response.rethrow() + self.assertEqual(response.body, b'Hello world!')