]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add hooks for HTTPClient and HTTPConnection customization.
authorVladlen Y. Koshelev <vlad.kosh@gmail.com>
Wed, 13 Mar 2013 15:21:23 +0000 (19:21 +0400)
committerVladlen Y. Koshelev <vlad.kosh@gmail.com>
Wed, 13 Mar 2013 15:21:23 +0000 (19:21 +0400)
tornado/simple_httpclient.py

index ed344e7291c75c2cd9d13fd820ba5cad92a46150..9b6fb6cf1413d9452742777e1da5a9dc9746c1d8 100644 (file)
@@ -92,10 +92,12 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
                 request, callback = self.queue.popleft()
                 key = object()
                 self.active[key] = (request, callback)
-                _HTTPConnection(self.io_loop, self, request,
-                                functools.partial(self._release_fetch, key),
-                                callback,
-                                self.max_buffer_size, self.resolver)
+                release_callback = functools.partial(self._release_fetch, key)
+                self._handle_request(request, release_callback, callback)
+
+    def _handle_request(self, request, release_callback, final_callback):
+        _HTTPConnection(self.io_loop, self, request, release_callback,
+                        final_callback, self.max_buffer_size, self.resolver)
 
     def _release_fetch(self, key):
         del self.active[key]
@@ -153,8 +155,21 @@ class _HTTPConnection(object):
             self.resolver.resolve(host, port, af, callback=self._on_resolve)
 
     def _on_resolve(self, addrinfo):
-        af, sockaddr = addrinfo[0]
+        self.stream = self._create_stream(addrinfo)
+        timeout = min(self.request.connect_timeout, self.request.request_timeout)
+        if timeout:
+            self._timeout = self.io_loop.add_timeout(
+                self.start_time + timeout,
+                stack_context.wrap(self._on_timeout))
+        self.stream.set_close_callback(self._on_close)
+        # ipv6 addresses are broken (in self.parsed.hostname) until
+        # 2.7, here is correctly parsed value calculated in __init__
+        sockaddr = addrinfo[0][1]
+        self.stream.connect(sockaddr, self._on_connect,
+                            server_hostname=self.parsed_hostname)
 
+    def _create_stream(self, addrinfo):
+        af = addrinfo[0][0]
         if self.parsed.scheme == "https":
             ssl_options = {}
             if self.request.validate_cert:
@@ -187,24 +202,14 @@ class _HTTPConnection(object):
                 # information.
                 ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3
 
-            self.stream = SSLIOStream(socket.socket(af),
-                                      io_loop=self.io_loop,
-                                      ssl_options=ssl_options,
-                                      max_buffer_size=self.max_buffer_size)
+            return SSLIOStream(socket.socket(af),
+                               io_loop=self.io_loop,
+                               ssl_options=ssl_options,
+                               max_buffer_size=self.max_buffer_size)
         else:
-            self.stream = IOStream(socket.socket(af),
-                                   io_loop=self.io_loop,
-                                   max_buffer_size=self.max_buffer_size)
-        timeout = min(self.request.connect_timeout, self.request.request_timeout)
-        if timeout:
-            self._timeout = self.io_loop.add_timeout(
-                self.start_time + timeout,
-                stack_context.wrap(self._on_timeout))
-        self.stream.set_close_callback(self._on_close)
-        # ipv6 addresses are broken (in self.parsed.hostname) until
-        # 2.7, here is correctly parsed value calculated in __init__
-        self.stream.connect(sockaddr, self._on_connect,
-                            server_hostname=self.parsed_hostname)
+            return IOStream(socket.socket(af),
+                            io_loop=self.io_loop,
+                            max_buffer_size=self.max_buffer_size)
 
     def _on_timeout(self):
         self._timeout = None
@@ -412,7 +417,7 @@ class _HTTPConnection(object):
             self.final_callback = None
             self._release()
             self.client.fetch(new_request, final_callback)
-            self.stream.close()
+            self._on_end_request()
             return
         if self._decompressor:
             data = (self._decompressor.decompress(data) +
@@ -432,6 +437,9 @@ class _HTTPConnection(object):
                                 buffer=buffer,
                                 effective_url=self.request.url)
         self._run_callback(response)
+        self._on_end_request()
+
+    def _on_end_request(self):
         self.stream.close()
 
     def _on_chunk_length(self, data):