From: Ben Darnell Date: Fri, 31 Aug 2012 20:44:06 +0000 (-0400) Subject: Use new async resolver interface in simple_httpclient. X-Git-Tag: v3.0.0~263^2~7 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=df620e43046135603ba707db58821955ca909b4e;p=thirdparty%2Ftornado.git Use new async resolver interface in simple_httpclient. --- diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py index 1223620ab..67107de4b 100644 --- a/tornado/simple_httpclient.py +++ b/tornado/simple_httpclient.py @@ -5,6 +5,7 @@ from tornado.escape import utf8, _unicode, native_str from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError, AsyncHTTPClient, main from tornado.httputil import HTTPHeaders from tornado.iostream import IOStream, SSLIOStream +from tornado.netutil import Resolver from tornado import stack_context from tornado.util import b, GzipDecompressor @@ -61,7 +62,8 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): """ def initialize(self, io_loop=None, max_clients=10, - hostname_mapping=None, max_buffer_size=104857600): + hostname_mapping=None, max_buffer_size=104857600, + resolver=None): """Creates a AsyncHTTPClient. Only a single AsyncHTTPClient instance exists per IOLoop @@ -87,6 +89,7 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): self.active = {} self.hostname_mapping = hostname_mapping self.max_buffer_size = max_buffer_size + self.resolver = resolver or Resolver(io_loop=io_loop) def fetch(self, request, callback, **kwargs): if not isinstance(request, HTTPRequest): @@ -130,6 +133,7 @@ class _HTTPConnection(object): self.request = request self.release_callback = release_callback self.final_callback = final_callback + self.max_buffer_size = max_buffer_size self.code = None self.headers = None self.chunks = None @@ -137,16 +141,16 @@ class _HTTPConnection(object): # Timeout handle returned by IOLoop.add_timeout self._timeout = None with stack_context.StackContext(self.cleanup): - parsed = urlparse.urlsplit(_unicode(self.request.url)) - if ssl is None and parsed.scheme == "https": + self.parsed = urlparse.urlsplit(_unicode(self.request.url)) + if ssl is None and self.parsed.scheme == "https": raise ValueError("HTTPS requires either python2.6+ or " "curl_httpclient") - if parsed.scheme not in ("http", "https"): + if self.parsed.scheme not in ("http", "https"): raise ValueError("Unsupported url scheme: %s" % self.request.url) # urlsplit results have hostname and port results, but they # didn't support ipv6 literals until python 2.7. - netloc = parsed.netloc + netloc = self.parsed.netloc if "@" in netloc: userpass, _, netloc = netloc.rpartition("@") match = re.match(r'^(.+):(\d+)$', netloc) @@ -155,11 +159,11 @@ class _HTTPConnection(object): port = int(match.group(2)) else: host = netloc - port = 443 if parsed.scheme == "https" else 80 + port = 443 if self.parsed.scheme == "https" else 80 if re.match(r'^\[.*\]$', host): # raw ipv6 addresses in urls are enclosed in brackets host = host[1:-1] - parsed_hostname = host # save final parsed host for _on_connect + self.parsed_hostname = host # save final host for _on_connect if self.client.hostname_mapping is not None: host = self.client.hostname_mapping.get(host, host) @@ -170,66 +174,67 @@ class _HTTPConnection(object): # so restrict to ipv4 by default. af = socket.AF_INET - addrinfo = socket.getaddrinfo(host, port, af, socket.SOCK_STREAM, - 0, 0) - af, socktype, proto, canonname, sockaddr = addrinfo[0] - - if parsed.scheme == "https": - ssl_options = {} - if request.validate_cert: - ssl_options["cert_reqs"] = ssl.CERT_REQUIRED - if request.ca_certs is not None: - ssl_options["ca_certs"] = request.ca_certs - else: - ssl_options["ca_certs"] = _DEFAULT_CA_CERTS - if request.client_key is not None: - ssl_options["keyfile"] = request.client_key - if request.client_cert is not None: - ssl_options["certfile"] = request.client_cert - - # SSL interoperability is tricky. We want to disable - # SSLv2 for security reasons; it wasn't disabled by default - # until openssl 1.0. The best way to do this is to use - # the SSL_OP_NO_SSLv2, but that wasn't exposed to python - # until 3.2. Python 2.7 adds the ciphers argument, which - # can also be used to disable SSLv2. As a last resort - # on python 2.6, we set ssl_version to SSLv3. This is - # more narrow than we'd like since it also breaks - # compatibility with servers configured for TLSv1 only, - # but nearly all servers support SSLv3: - # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html - if sys.version_info >= (2, 7): - ssl_options["ciphers"] = "DEFAULT:!SSLv2" - else: - # This is really only necessary for pre-1.0 versions - # of openssl, but python 2.6 doesn't expose version - # information. - ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3 - - self.stream = SSLIOStream(socket.socket(af, socktype, proto), - io_loop=self.io_loop, - ssl_options=ssl_options, - max_buffer_size=max_buffer_size) + self.client.resolver.getaddrinfo( + host, port, af, socket.SOCK_STREAM, 0, 0, + callback=self._on_resolve) + + def _on_resolve(self, future): + af, socktype, proto, canonname, sockaddr = future.result()[0] + + if self.parsed.scheme == "https": + ssl_options = {} + if self.request.validate_cert: + ssl_options["cert_reqs"] = ssl.CERT_REQUIRED + if self.request.ca_certs is not None: + ssl_options["ca_certs"] = self.request.ca_certs + else: + ssl_options["ca_certs"] = _DEFAULT_CA_CERTS + if self.request.client_key is not None: + ssl_options["keyfile"] = self.request.client_key + if self.request.client_cert is not None: + ssl_options["certfile"] = self.request.client_cert + + # SSL interoperability is tricky. We want to disable + # SSLv2 for security reasons; it wasn't disabled by default + # until openssl 1.0. The best way to do this is to use + # the SSL_OP_NO_SSLv2, but that wasn't exposed to python + # until 3.2. Python 2.7 adds the ciphers argument, which + # can also be used to disable SSLv2. As a last resort + # on python 2.6, we set ssl_version to SSLv3. This is + # more narrow than we'd like since it also breaks + # compatibility with servers configured for TLSv1 only, + # but nearly all servers support SSLv3: + # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html + if sys.version_info >= (2, 7): + ssl_options["ciphers"] = "DEFAULT:!SSLv2" else: - self.stream = IOStream(socket.socket(af, socktype, proto), - io_loop=self.io_loop, - max_buffer_size=max_buffer_size) - timeout = min(request.connect_timeout, request.request_timeout) - if timeout: - self._timeout = self.io_loop.add_timeout( - self.start_time + timeout, - stack_context.wrap(self._on_timeout)) - self.stream.set_close_callback(self._on_close) - self.stream.connect(sockaddr, - functools.partial(self._on_connect, parsed, - parsed_hostname)) + # This is really only necessary for pre-1.0 versions + # of openssl, but python 2.6 doesn't expose version + # information. + ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3 + + self.stream = SSLIOStream(socket.socket(af, socktype, proto), + io_loop=self.io_loop, + ssl_options=ssl_options, + max_buffer_size=self.max_buffer_size) + else: + self.stream = IOStream(socket.socket(af, socktype, proto), + io_loop=self.io_loop, + max_buffer_size=self.max_buffer_size) + timeout = min(self.request.connect_timeout, self.request.request_timeout) + if timeout: + self._timeout = self.io_loop.add_timeout( + self.start_time + timeout, + stack_context.wrap(self._on_timeout)) + self.stream.set_close_callback(self._on_close) + self.stream.connect(sockaddr, self._on_connect) def _on_timeout(self): self._timeout = None if self.final_callback is not None: raise HTTPError(599, "Timeout") - def _on_connect(self, parsed, parsed_hostname): + def _on_connect(self): if self._timeout is not None: self.io_loop.remove_timeout(self._timeout) self._timeout = None @@ -241,10 +246,10 @@ class _HTTPConnection(object): isinstance(self.stream, SSLIOStream)): match_hostname(self.stream.socket.getpeercert(), # ipv6 addresses are broken (in - # parsed.hostname) until 2.7, here is + # self.parsed.hostname) until 2.7, here is # correctly parsed value calculated in # __init__ - parsed_hostname) + self.parsed_hostname) if (self.request.method not in self._SUPPORTED_METHODS and not self.request.allow_nonstandard_methods): raise KeyError("unknown method %s" % self.request.method) @@ -256,13 +261,13 @@ class _HTTPConnection(object): if "Connection" not in self.request.headers: self.request.headers["Connection"] = "close" if "Host" not in self.request.headers: - if '@' in parsed.netloc: - self.request.headers["Host"] = parsed.netloc.rpartition('@')[-1] + if '@' in self.parsed.netloc: + self.request.headers["Host"] = self.parsed.netloc.rpartition('@')[-1] else: - self.request.headers["Host"] = parsed.netloc + self.request.headers["Host"] = self.parsed.netloc username, password = None, None - if parsed.username is not None: - username, password = parsed.username, parsed.password + if self.parsed.username is not None: + username, password = self.parsed.username, self.parsed.password elif self.request.auth_username is not None: username = self.request.auth_username password = self.request.auth_password or '' @@ -285,8 +290,8 @@ class _HTTPConnection(object): self.request.headers["Content-Type"] = "application/x-www-form-urlencoded" if self.request.use_gzip: self.request.headers["Accept-Encoding"] = "gzip" - req_path = ((parsed.path or '/') + - (('?' + parsed.query) if parsed.query else '')) + req_path = ((self.parsed.path or '/') + + (('?' + self.parsed.query) if self.parsed.query else '')) request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method, req_path))] for k, v in self.request.headers.get_all(): diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py index 69af3a0ed..e1df785a5 100644 --- a/tornado/test/httpserver_test.py +++ b/tornado/test/httpserver_test.py @@ -158,7 +158,7 @@ class RawRequestHTTPConnection(simple_httpclient._HTTPConnection): def set_request(self, request): self.__next_request = request - def _on_connect(self, parsed, parsed_hostname): + def _on_connect(self): self.stream.write(self.__next_request) self.__next_request = None self.stream.read_until(b("\r\n\r\n"), self._on_headers)