]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Use new async resolver interface in simple_httpclient.
authorBen Darnell <ben@bendarnell.com>
Fri, 31 Aug 2012 20:44:06 +0000 (16:44 -0400)
committerBen Darnell <ben@bendarnell.com>
Fri, 31 Aug 2012 20:44:06 +0000 (16:44 -0400)
tornado/simple_httpclient.py
tornado/test/httpserver_test.py

index 1223620ab5915395449d27faffe229fac8e443a1..67107de4bdffad7bb9d532195feb2ca6f6c86727 100644 (file)
@@ -5,6 +5,7 @@ from tornado.escape import utf8, _unicode, native_str
 from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError, AsyncHTTPClient, main
 from tornado.httputil import HTTPHeaders
 from tornado.iostream import IOStream, SSLIOStream
+from tornado.netutil import Resolver
 from tornado import stack_context
 from tornado.util import b, GzipDecompressor
 
@@ -61,7 +62,8 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
 
     """
     def initialize(self, io_loop=None, max_clients=10,
-                   hostname_mapping=None, max_buffer_size=104857600):
+                   hostname_mapping=None, max_buffer_size=104857600,
+                   resolver=None):
         """Creates a AsyncHTTPClient.
 
         Only a single AsyncHTTPClient instance exists per IOLoop
@@ -87,6 +89,7 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
         self.active = {}
         self.hostname_mapping = hostname_mapping
         self.max_buffer_size = max_buffer_size
+        self.resolver = resolver or Resolver(io_loop=io_loop)
 
     def fetch(self, request, callback, **kwargs):
         if not isinstance(request, HTTPRequest):
@@ -130,6 +133,7 @@ class _HTTPConnection(object):
         self.request = request
         self.release_callback = release_callback
         self.final_callback = final_callback
+        self.max_buffer_size = max_buffer_size
         self.code = None
         self.headers = None
         self.chunks = None
@@ -137,16 +141,16 @@ class _HTTPConnection(object):
         # Timeout handle returned by IOLoop.add_timeout
         self._timeout = None
         with stack_context.StackContext(self.cleanup):
-            parsed = urlparse.urlsplit(_unicode(self.request.url))
-            if ssl is None and parsed.scheme == "https":
+            self.parsed = urlparse.urlsplit(_unicode(self.request.url))
+            if ssl is None and self.parsed.scheme == "https":
                 raise ValueError("HTTPS requires either python2.6+ or "
                                  "curl_httpclient")
-            if parsed.scheme not in ("http", "https"):
+            if self.parsed.scheme not in ("http", "https"):
                 raise ValueError("Unsupported url scheme: %s" %
                                  self.request.url)
             # urlsplit results have hostname and port results, but they
             # didn't support ipv6 literals until python 2.7.
-            netloc = parsed.netloc
+            netloc = self.parsed.netloc
             if "@" in netloc:
                 userpass, _, netloc = netloc.rpartition("@")
             match = re.match(r'^(.+):(\d+)$', netloc)
@@ -155,11 +159,11 @@ class _HTTPConnection(object):
                 port = int(match.group(2))
             else:
                 host = netloc
-                port = 443 if parsed.scheme == "https" else 80
+                port = 443 if self.parsed.scheme == "https" else 80
             if re.match(r'^\[.*\]$', host):
                 # raw ipv6 addresses in urls are enclosed in brackets
                 host = host[1:-1]
-            parsed_hostname = host  # save final parsed host for _on_connect
+            self.parsed_hostname = host  # save final host for _on_connect
             if self.client.hostname_mapping is not None:
                 host = self.client.hostname_mapping.get(host, host)
 
@@ -170,66 +174,67 @@ class _HTTPConnection(object):
                 # so restrict to ipv4 by default.
                 af = socket.AF_INET
 
-            addrinfo = socket.getaddrinfo(host, port, af, socket.SOCK_STREAM,
-                                          0, 0)
-            af, socktype, proto, canonname, sockaddr = addrinfo[0]
-
-            if parsed.scheme == "https":
-                ssl_options = {}
-                if request.validate_cert:
-                    ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
-                if request.ca_certs is not None:
-                    ssl_options["ca_certs"] = request.ca_certs
-                else:
-                    ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
-                if request.client_key is not None:
-                    ssl_options["keyfile"] = request.client_key
-                if request.client_cert is not None:
-                    ssl_options["certfile"] = request.client_cert
-
-                # SSL interoperability is tricky.  We want to disable
-                # SSLv2 for security reasons; it wasn't disabled by default
-                # until openssl 1.0.  The best way to do this is to use
-                # the SSL_OP_NO_SSLv2, but that wasn't exposed to python
-                # until 3.2.  Python 2.7 adds the ciphers argument, which
-                # can also be used to disable SSLv2.  As a last resort
-                # on python 2.6, we set ssl_version to SSLv3.  This is
-                # more narrow than we'd like since it also breaks
-                # compatibility with servers configured for TLSv1 only,
-                # but nearly all servers support SSLv3:
-                # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
-                if sys.version_info >= (2, 7):
-                    ssl_options["ciphers"] = "DEFAULT:!SSLv2"
-                else:
-                    # This is really only necessary for pre-1.0 versions
-                    # of openssl, but python 2.6 doesn't expose version
-                    # information.
-                    ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3
-
-                self.stream = SSLIOStream(socket.socket(af, socktype, proto),
-                                          io_loop=self.io_loop,
-                                          ssl_options=ssl_options,
-                                          max_buffer_size=max_buffer_size)
+            self.client.resolver.getaddrinfo(
+                host, port, af, socket.SOCK_STREAM, 0, 0,
+                callback=self._on_resolve)
+
+    def _on_resolve(self, future):
+        af, socktype, proto, canonname, sockaddr = future.result()[0]
+
+        if self.parsed.scheme == "https":
+            ssl_options = {}
+            if self.request.validate_cert:
+                ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
+            if self.request.ca_certs is not None:
+                ssl_options["ca_certs"] = self.request.ca_certs
+            else:
+                ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
+            if self.request.client_key is not None:
+                ssl_options["keyfile"] = self.request.client_key
+            if self.request.client_cert is not None:
+                ssl_options["certfile"] = self.request.client_cert
+
+            # SSL interoperability is tricky.  We want to disable
+            # SSLv2 for security reasons; it wasn't disabled by default
+            # until openssl 1.0.  The best way to do this is to use
+            # the SSL_OP_NO_SSLv2, but that wasn't exposed to python
+            # until 3.2.  Python 2.7 adds the ciphers argument, which
+            # can also be used to disable SSLv2.  As a last resort
+            # on python 2.6, we set ssl_version to SSLv3.  This is
+            # more narrow than we'd like since it also breaks
+            # compatibility with servers configured for TLSv1 only,
+            # but nearly all servers support SSLv3:
+            # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
+            if sys.version_info >= (2, 7):
+                ssl_options["ciphers"] = "DEFAULT:!SSLv2"
             else:
-                self.stream = IOStream(socket.socket(af, socktype, proto),
-                                       io_loop=self.io_loop,
-                                       max_buffer_size=max_buffer_size)
-            timeout = min(request.connect_timeout, request.request_timeout)
-            if timeout:
-                self._timeout = self.io_loop.add_timeout(
-                    self.start_time + timeout,
-                    stack_context.wrap(self._on_timeout))
-            self.stream.set_close_callback(self._on_close)
-            self.stream.connect(sockaddr,
-                                functools.partial(self._on_connect, parsed,
-                                                  parsed_hostname))
+                # This is really only necessary for pre-1.0 versions
+                # of openssl, but python 2.6 doesn't expose version
+                # information.
+                ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3
+
+            self.stream = SSLIOStream(socket.socket(af, socktype, proto),
+                                      io_loop=self.io_loop,
+                                      ssl_options=ssl_options,
+                                      max_buffer_size=self.max_buffer_size)
+        else:
+            self.stream = IOStream(socket.socket(af, socktype, proto),
+                                   io_loop=self.io_loop,
+                                   max_buffer_size=self.max_buffer_size)
+        timeout = min(self.request.connect_timeout, self.request.request_timeout)
+        if timeout:
+            self._timeout = self.io_loop.add_timeout(
+                self.start_time + timeout,
+                stack_context.wrap(self._on_timeout))
+        self.stream.set_close_callback(self._on_close)
+        self.stream.connect(sockaddr, self._on_connect)
 
     def _on_timeout(self):
         self._timeout = None
         if self.final_callback is not None:
             raise HTTPError(599, "Timeout")
 
-    def _on_connect(self, parsed, parsed_hostname):
+    def _on_connect(self):
         if self._timeout is not None:
             self.io_loop.remove_timeout(self._timeout)
             self._timeout = None
@@ -241,10 +246,10 @@ class _HTTPConnection(object):
             isinstance(self.stream, SSLIOStream)):
             match_hostname(self.stream.socket.getpeercert(),
                            # ipv6 addresses are broken (in
-                           # parsed.hostname) until 2.7, here is
+                           # self.parsed.hostname) until 2.7, here is
                            # correctly parsed value calculated in
                            # __init__
-                           parsed_hostname)
+                           self.parsed_hostname)
         if (self.request.method not in self._SUPPORTED_METHODS and
             not self.request.allow_nonstandard_methods):
             raise KeyError("unknown method %s" % self.request.method)
@@ -256,13 +261,13 @@ class _HTTPConnection(object):
         if "Connection" not in self.request.headers:
             self.request.headers["Connection"] = "close"
         if "Host" not in self.request.headers:
-            if '@' in parsed.netloc:
-                self.request.headers["Host"] = parsed.netloc.rpartition('@')[-1]
+            if '@' in self.parsed.netloc:
+                self.request.headers["Host"] = self.parsed.netloc.rpartition('@')[-1]
             else:
-                self.request.headers["Host"] = parsed.netloc
+                self.request.headers["Host"] = self.parsed.netloc
         username, password = None, None
-        if parsed.username is not None:
-            username, password = parsed.username, parsed.password
+        if self.parsed.username is not None:
+            username, password = self.parsed.username, self.parsed.password
         elif self.request.auth_username is not None:
             username = self.request.auth_username
             password = self.request.auth_password or ''
@@ -285,8 +290,8 @@ class _HTTPConnection(object):
             self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
         if self.request.use_gzip:
             self.request.headers["Accept-Encoding"] = "gzip"
-        req_path = ((parsed.path or '/') +
-                (('?' + parsed.query) if parsed.query else ''))
+        req_path = ((self.parsed.path or '/') +
+                (('?' + self.parsed.query) if self.parsed.query else ''))
         request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method,
                                                   req_path))]
         for k, v in self.request.headers.get_all():
index 69af3a0ed4e47f7a23facf25477ac49e1099eaa9..e1df785a5e4f01772429de98962d2e04a8c0df48 100644 (file)
@@ -158,7 +158,7 @@ class RawRequestHTTPConnection(simple_httpclient._HTTPConnection):
     def set_request(self, request):
         self.__next_request = request
 
-    def _on_connect(self, parsed, parsed_hostname):
+    def _on_connect(self):
         self.stream.write(self.__next_request)
         self.__next_request = None
         self.stream.read_until(b("\r\n\r\n"), self._on_headers)