]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Refactor connection logic from simple_httpclient to a new tcpclient module.
authorBen Darnell <ben@bendarnell.com>
Sat, 3 May 2014 15:40:01 +0000 (11:40 -0400)
committerBen Darnell <ben@bendarnell.com>
Sun, 18 May 2014 14:18:53 +0000 (10:18 -0400)
This is preparation for introducing connection pooling and better
handling of ipv6.

If an IOStream was closed due to an exception, its Futures will
now raise that exception instead of StreamClosedError.

tornado/http1connection.py
tornado/iostream.py
tornado/simple_httpclient.py
tornado/tcpclient.py [new file with mode: 0644]
tornado/test/httpserver_test.py
tornado/test/runtests.py
tornado/test/tcpclient_test.py [new file with mode: 0644]
tornado/websocket.py

index b8109eed486ace5ec56d71aeddd503e262e76a56..d71a3244a0b04cb52532743dbaf7d4104c211fdd 100644 (file)
@@ -604,7 +604,8 @@ class HTTP1ServerConnection(object):
                 request_delegate = delegate.start_request(self, conn)
                 try:
                     ret = yield conn.read_response(request_delegate)
-                except iostream.StreamClosedError:
+                except (iostream.StreamClosedError,
+                        iostream.UnsatisfiableReadError):
                     return
                 except Exception:
                     # TODO: this is probably too broad; it would be better to
index c8c0749433a2bd65b52c8ebf9442ed7cf1ab0b60..b5440b2a8033640ce5194aeaae615179fb7fa1fa 100644 (file)
@@ -72,6 +72,7 @@ class StreamClosedError(IOError):
     pass
 
 
+
 class UnsatisfiableReadError(Exception):
     """Exception raised when a read cannot be satisfied.
 
@@ -364,7 +365,7 @@ class BaseIOStream(object):
                 futures.append(self._connect_future)
                 self._connect_future = None
             for future in futures:
-                future.set_exception(StreamClosedError())
+                future.set_exception(self.error or StreamClosedError())
             if self._close_callback is not None:
                 cb = self._close_callback
                 self._close_callback = None
@@ -408,13 +409,19 @@ class BaseIOStream(object):
             gen_log.warning("Got events for closed stream %s", fd)
             return
         try:
+            if self._connecting:
+                # Most IOLoops will report a write failed connect
+                # with the WRITE event, but SelectIOLoop reports a
+                # READ as well so we must check for connecting before
+                # either.
+                self._handle_connect()
+            if self.closed():
+                return
             if events & self.io_loop.READ:
                 self._handle_read()
             if self.closed():
                 return
             if events & self.io_loop.WRITE:
-                if self._connecting:
-                    self._handle_connect()
                 self._handle_write()
             if self.closed():
                 return
index b7be7950f0ba32e3cd87b933cf54530becab5a7d..06d2ca80168e22253952cf5537fe45f01638b2e6 100644 (file)
@@ -10,6 +10,7 @@ from tornado.iostream import IOStream, SSLIOStream, StreamClosedError
 from tornado.netutil import Resolver, OverrideResolver
 from tornado.log import gen_log
 from tornado import stack_context
+from tornado.tcpclient import TCPClient
 
 import base64
 import collections
@@ -88,6 +89,8 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
         self.waiting = {}
         self.max_buffer_size = max_buffer_size
         self.max_header_size = max_header_size
+        # TCPClient could create a Resolver for us, but we have to do it
+        # ourselves to support hostname_mapping.
         if resolver:
             self.resolver = resolver
             self.own_resolver = False
@@ -97,11 +100,13 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
         if hostname_mapping is not None:
             self.resolver = OverrideResolver(resolver=self.resolver,
                                              mapping=hostname_mapping)
+        self.tcp_client = TCPClient(resolver=self.resolver, io_loop=io_loop)
 
     def close(self):
         super(SimpleAsyncHTTPClient, self).close()
         if self.own_resolver:
             self.resolver.close()
+        self.tcp_client.close()
 
     def fetch_impl(self, request, callback):
         key = object()
@@ -133,7 +138,7 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
 
     def _handle_request(self, request, release_callback, final_callback):
         _HTTPConnection(self.io_loop, self, request, release_callback,
-                        final_callback, self.max_buffer_size, self.resolver,
+                        final_callback, self.max_buffer_size, self.tcp_client,
                         self.max_header_size)
 
     def _release_fetch(self, key):
@@ -161,7 +166,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
     _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
 
     def __init__(self, io_loop, client, request, release_callback,
-                 final_callback, max_buffer_size, resolver,
+                 final_callback, max_buffer_size, tcp_client,
                  max_header_size):
         self.start_time = io_loop.time()
         self.io_loop = io_loop
@@ -170,7 +175,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
         self.release_callback = release_callback
         self.final_callback = final_callback
         self.max_buffer_size = max_buffer_size
-        self.resolver = resolver
+        self.tcp_client = tcp_client
         self.max_header_size = max_header_size
         self.code = None
         self.headers = None
@@ -208,28 +213,19 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
                 # so restrict to ipv4 by default.
                 af = socket.AF_INET
 
+            ssl_options = self._get_ssl_options(self.parsed.scheme)
+
             timeout = min(self.request.connect_timeout, self.request.request_timeout)
             if timeout:
                 self._timeout = self.io_loop.add_timeout(
                     self.start_time + timeout,
                     stack_context.wrap(self._on_timeout))
-            self.resolver.resolve(host, port, af, callback=self._on_resolve)
+            self.tcp_client.connect(host, port, af=af,
+                                    ssl_options=ssl_options,
+                                    callback=self._on_connect)
 
-    def _on_resolve(self, addrinfo):
-        if self.final_callback is None:
-            # final_callback is cleared if we've hit our timeout
-            return
-        self.stream = self._create_stream(addrinfo)
-        self.stream.set_close_callback(self._on_close)
-        # ipv6 addresses are broken (in self.parsed.hostname) until
-        # 2.7, here is correctly parsed value calculated in __init__
-        self._sockaddr = addrinfo[0][1]
-        self.stream.connect(self._sockaddr, self._on_connect,
-                            server_hostname=self.parsed_hostname)
-
-    def _create_stream(self, addrinfo):
-        af = addrinfo[0][0]
-        if self.parsed.scheme == "https":
+    def _get_ssl_options(self, scheme):
+        if scheme == "https":
             ssl_options = {}
             if self.request.validate_cert:
                 ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
@@ -262,15 +258,8 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
                 # of openssl, but python 2.6 doesn't expose version
                 # information.
                 ssl_options["ssl_version"] = ssl.PROTOCOL_TLSv1
-
-            return SSLIOStream(socket.socket(af),
-                               io_loop=self.io_loop,
-                               ssl_options=ssl_options,
-                               max_buffer_size=self.max_buffer_size)
-        else:
-            return IOStream(socket.socket(af),
-                            io_loop=self.io_loop,
-                            max_buffer_size=self.max_buffer_size)
+            return ssl_options
+        return None
 
     def _on_timeout(self):
         self._timeout = None
@@ -282,7 +271,13 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
             self.io_loop.remove_timeout(self._timeout)
             self._timeout = None
 
-    def _on_connect(self):
+    def _on_connect(self, stream):
+        if self.final_callback is None:
+            # final_callback is cleared if we've hit our timeout.
+            stream.close()
+            return
+        self.stream = stream
+        self.stream.set_close_callback(self._on_close)
         self._remove_timeout()
         if self.final_callback is None:
             return
diff --git a/tornado/tcpclient.py b/tornado/tcpclient.py
new file mode 100644 (file)
index 0000000..e41c390
--- /dev/null
@@ -0,0 +1,67 @@
+#!/usr/bin/env python
+#
+# Copyright 2014 Facebook
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""A non-blocking TCP connection factory.
+"""
+from __future__ import absolute_import, division, print_function, with_statement
+
+import socket
+
+from tornado.ioloop import IOLoop
+from tornado.iostream import IOStream, SSLIOStream
+from tornado import gen
+from tornado.netutil import Resolver
+
+class TCPClient(object):
+    """A non-blocking TCP connection factory.
+    """
+    def __init__(self, resolver=None, io_loop=None):
+        self.io_loop = io_loop or IOLoop.current()
+        if resolver is not None:
+            self.resolver = resolver
+            self._own_resolver = False
+        else:
+            self.resolver = Resolver(io_loop=io_loop)
+            self._own_resolver = True
+
+    def close(self):
+        if self._own_resolver:
+            self.resolver.close()
+
+    @gen.coroutine
+    def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None,
+                max_buffer_size=None):
+        """Connect to the given host and port.
+
+        Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
+        ``ssl_options`` is not None).
+        """
+        addrinfo = yield self.resolver.resolve(host, port, af)
+        af, addr = addrinfo[0]
+        stream = self._create_stream(af, ssl_options, max_buffer_size)
+        yield stream.connect(addr, server_hostname=host)
+        raise gen.Return(stream)
+
+    def _create_stream(self, af, ssl_options, max_buffer_size):
+        if ssl_options is None:
+            return IOStream(socket.socket(af),
+                            io_loop=self.io_loop,
+                            max_buffer_size=max_buffer_size)
+        else:
+            return SSLIOStream(socket.socket(af),
+                               io_loop=self.io_loop,
+                               ssl_options=ssl_options,
+                               max_buffer_size=max_buffer_size)
index f9e920cd9a995ab4d91d3056e9e00889c78567c2..8387c23cf5bb7f4c256c77e096abb5013ddab9a9 100644 (file)
@@ -9,7 +9,7 @@ from tornado.http1connection import HTTP1Connection
 from tornado.httpserver import HTTPServer
 from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine
 from tornado.iostream import IOStream
-from tornado.log import gen_log
+from tornado.log import gen_log, app_log
 from tornado.netutil import ssl_options_to_context
 from tornado.simple_httpclient import SimpleAsyncHTTPClient
 from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test
@@ -111,11 +111,15 @@ class SSLTestMixin(object):
         # connection, rather than waiting for a timeout or otherwise
         # misbehaving.
         with ExpectLog(gen_log, '(SSL Error|uncaught exception)'):
-            self.http_client.fetch(self.get_url("/").replace('https:', 'http:'),
-                                   self.stop,
-                                   request_timeout=3600,
-                                   connect_timeout=3600)
-            response = self.wait()
+            # TODO: this should go to gen_log, not app_log.  See TODO
+            # in http1connection.py (_server_request_loop)
+            with ExpectLog(app_log, 'Uncaught exception', required=False):
+                self.http_client.fetch(
+                    self.get_url("/").replace('https:', 'http:'),
+                    self.stop,
+                    request_timeout=3600,
+                    connect_timeout=3600)
+                response = self.wait()
         self.assertEqual(response.code, 599)
 
 # Python's SSL implementation differs significantly between versions.
index a1fb329523a68031dcb2eb369ec57d2c4f2aee52..c1c5746b07680a7b477de93b636d4364b0c94ad1 100644 (file)
@@ -40,6 +40,7 @@ TEST_MODULES = [
     'tornado.test.process_test',
     'tornado.test.simple_httpclient_test',
     'tornado.test.stack_context_test',
+    'tornado.test.tcpclient_test',
     'tornado.test.template_test',
     'tornado.test.testing_test',
     'tornado.test.twisted_test',
diff --git a/tornado/test/tcpclient_test.py b/tornado/test/tcpclient_test.py
new file mode 100644 (file)
index 0000000..73bc3c1
--- /dev/null
@@ -0,0 +1,67 @@
+#!/usr/bin/env python
+#
+# Copyright 2014 Facebook
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from __future__ import absolute_import, division, print_function, with_statement
+
+from contextlib import closing
+
+from tornado.log import gen_log
+from tornado.tcpclient import TCPClient
+from tornado.tcpserver import TCPServer
+from tornado.testing import AsyncTestCase, bind_unused_port, gen_test, ExpectLog
+
+class TestTCPServer(TCPServer):
+    def __init__(self):
+        super(TestTCPServer, self).__init__()
+        self.streams = []
+        socket, self.port = bind_unused_port()
+        self.add_socket(socket)
+
+    def handle_stream(self, stream, address):
+        self.streams.append(stream)
+
+    def stop(self):
+        super(TestTCPServer, self).stop()
+        for stream in self.streams:
+            stream.close()
+
+class TCPClientTest(AsyncTestCase):
+    def setUp(self):
+        super(TCPClientTest, self).setUp()
+        self.server = TestTCPServer()
+        self.port = self.server.port
+        self.client = TCPClient()
+
+    def tearDown(self):
+        self.client.close()
+        self.server.stop()
+        super(TCPClientTest, self).tearDown()
+
+    @gen_test
+    def test_connect_ipv4(self):
+        stream = yield self.client.connect('127.0.0.1', self.port)
+        with closing(stream):
+            stream.write(b"hello")
+            data = yield self.server.streams[0].read_bytes(5)
+            self.assertEqual(data, b"hello")
+
+    @gen_test
+    def test_refused_ipv4(self):
+        sock, port = bind_unused_port()
+        sock.close()
+        with ExpectLog(gen_log, 'Connect error'):
+            with self.assertRaises(IOError):
+                yield self.client.connect('127.0.0.1', port)
index 218413f8924d7574c4c5f1693f5fd230c3699498..3db215dba3f2db181670c27a88c51228eff094d4 100644 (file)
@@ -37,8 +37,8 @@ from tornado import httpclient, httputil
 from tornado.ioloop import IOLoop
 from tornado.iostream import StreamClosedError
 from tornado.log import gen_log, app_log
-from tornado.netutil import Resolver
 from tornado import simple_httpclient
+from tornado.tcpclient import TCPClient
 from tornado.util import bytes_type, unicode_type
 
 try:
@@ -821,10 +821,10 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
             'Sec-WebSocket-Version': '13',
         })
 
-        self.resolver = Resolver(io_loop=io_loop)
+        self.tcp_client = TCPClient(io_loop=io_loop)
         super(WebSocketClientConnection, self).__init__(
             io_loop, None, request, lambda: None, self._on_http_response,
-            104857600, self.resolver, 65536)
+            104857600, self.tcp_client, 65536)
 
     def close(self, code=None, reason=None):
         """Closes the websocket connection.