git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Implement the "happy eyeballs" algorithm for ipv4/ipv6 selection.
author      Ben Darnell <ben@bendarnell.com>
            Sun, 4 May 2014 17:13:49 +0000 (13:13 -0400)
committer   Ben Darnell <ben@bendarnell.com>
            Sun, 18 May 2014 14:19:31 +0000 (10:19 -0400)
IOStream.connect's Future resolves to the stream itself,
which streamlines some connection scenarios.
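
Because the Future now resolves to the stream itself, a caller can chain directly off the yielded result. A minimal sketch against the new behavior (host and port are placeholders; an IOLoop is assumed to be running the coroutine, mirroring the updated iostream test below):

    import socket

    from tornado import gen
    from tornado.iostream import IOStream

    @gen.coroutine
    def fetch_first_line(host, port):
        # connect() now resolves to the stream itself, so the result of
        # the yield expression can be used directly for follow-up I/O.
        stream = yield IOStream(socket.socket()).connect((host, port))
        yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
        first_line = yield stream.read_until(b"\r\n")
        stream.close()
        raise gen.Return(first_line)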

tornado/iostream.py
tornado/tcpclient.py
tornado/test/iostream_test.py
tornado/test/simple_httpclient_test.py
tornado/test/tcpclient_test.py
tornado/test/util.py

tornado/iostream.py
index b5440b2a8033640ce5194aeaae615179fb7fa1fa..c29494e2ebeb3667cf8e2a26e903f30d6d6a5510 100644 (file)
@@ -919,8 +919,9 @@ class IOStream(BaseIOStream):
         not previously connected.  The address parameter is in the
         same format as for `socket.connect <socket.socket.connect>`,
         i.e. a ``(host, port)`` tuple.  If ``callback`` is specified,
-        it will be called when the connection is completed; if not
-        this method returns a `.Future`.
+        it will be called with no arguments when the connection is
+        completed; if not this method returns a `.Future` (whose result
+        after a successful connection will be the stream itself).
 
         If specified, the ``server_hostname`` parameter will be used
         in SSL connections for certificate validation (if requested in
@@ -980,7 +981,7 @@ class IOStream(BaseIOStream):
         if self._connect_future is not None:
             future = self._connect_future
             self._connect_future = None
-            future.set_result(None)
+            future.set_result(self)
         self._connecting = False
 
     def set_nodelay(self, value):
tornado/tcpclient.py
index e41c39094fb01bbcb7a379e615e089b842431d45..1bb1253f1f35dd94d1c840d5770cc673957e3807 100644 (file)
 """
 from __future__ import absolute_import, division, print_function, with_statement
 
+import functools
 import socket
 
+from tornado.concurrent import Future
 from tornado.ioloop import IOLoop
 from tornado.iostream import IOStream, SSLIOStream
 from tornado import gen
 from tornado.netutil import Resolver
 
+_INITIAL_CONNECT_TIMEOUT = 0.3
+
+class _Connector(object):
+    """A stateless implementation of the "Happy Eyeballs" algorithm.
+
+    "Happy Eyeballs" is documented in RFC6555 as the recommended practice
+    for when both IPv4 and IPv6 addresses are available.
+
+    In this implementation, we partition the addresses by family, and
+    make the first connection attempt to whichever address was
+    returned first by ``getaddrinfo``.  If that connection fails or
+    times out, we begin a connection in parallel to the first address
+    of the other family.  If there are additional failures we retry
+    with other addresses, keeping one connection attempt per family
+    in flight at a time.
+
+    http://tools.ietf.org/html/rfc6555
+
+    """
+    def __init__(self, addrinfo, io_loop, connect):
+        self.io_loop = io_loop
+        self.connect = connect
+
+        self.future = Future()
+        self.timeout = None
+        self.last_error = None
+        self.remaining = len(addrinfo)
+        self.primary_addrs, self.secondary_addrs = self.split(addrinfo)
+
+    @staticmethod
+    def split(addrinfo):
+        """Partition the ``addrinfo`` list by address family.
+
+        Returns two lists.  The first list contains the first entry from
+        ``addrinfo`` and all others with the same family, and the
+        second list contains all other addresses (normally one list will
+        be AF_INET and the other AF_INET6, although non-standard resolvers
+        may return additional families).
+        """
+        primary = []
+        secondary = []
+        primary_af = addrinfo[0][0]
+        for af, addr in addrinfo:
+            if af == primary_af:
+                primary.append((af, addr))
+            else:
+                secondary.append((af, addr))
+        return primary, secondary
+
+    def start(self, timeout=_INITIAL_CONNECT_TIMEOUT):
+        self.try_connect(iter(self.primary_addrs))
+        self.set_timeout(timeout)
+        return self.future
+
+    def try_connect(self, addrs):
+        try:
+            af, addr = next(addrs)
+        except StopIteration:
+            # We've reached the end of our queue, but the other queue
+            # might still be working.  Send a final error on the future
+            # only when both queues are finished.
+            if self.remaining == 0 and not self.future.done():
+                self.future.set_exception(self.last_error or
+                                          IOError("connection failed"))
+            return
+        future = self.connect(af, addr)
+        future.add_done_callback(functools.partial(self.on_connect_done,
+                                                   addrs, af, addr))
+
+    def on_connect_done(self, addrs, af, addr, future):
+        self.remaining -= 1
+        try:
+            stream = future.result()
+        except Exception as e:
+            if self.future.done():
+                return
+            # Error: try again (but remember what happened so we have an
+            # error to raise in the end)
+            self.last_error = e
+            self.try_connect(addrs)
+            if self.timeout is not None:
+                # If the first attempt failed, don't wait for the
+                # timeout to try an address from the secondary queue.
+                self.on_timeout()
+            return
+        self.clear_timeout()
+        if self.future.done():
+            # This is a late arrival; just drop it.
+            stream.close()
+        else:
+            self.future.set_result((af, addr, stream))
+
+    def set_timeout(self, timeout):
+        self.timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout,
+                                                self.on_timeout)
+
+    def on_timeout(self):
+        self.timeout = None
+        self.try_connect(iter(self.secondary_addrs))
+
+    def clear_timeout(self):
+        if self.timeout is not None:
+            self.io_loop.remove_timeout(self.timeout)
+
+
 class TCPClient(object):
     """A non-blocking TCP connection factory.
     """
@@ -50,18 +157,26 @@ class TCPClient(object):
         ``ssl_options`` is not None).
         """
         addrinfo = yield self.resolver.resolve(host, port, af)
-        af, addr = addrinfo[0]
-        stream = self._create_stream(af, ssl_options, max_buffer_size)
-        yield stream.connect(addr, server_hostname=host)
+        connector = _Connector(
+            addrinfo, self.io_loop,
+            functools.partial(self._create_stream,
+                              host, ssl_options, max_buffer_size))
+        af, addr, stream = yield connector.start()
+        # TODO: For better performance we could cache the (af, addr)
+        # information here and re-use it on subsequent connections to
+        # the same host. (http://tools.ietf.org/html/rfc6555#section-4.2)
         raise gen.Return(stream)
 
-    def _create_stream(self, af, ssl_options, max_buffer_size):
+    def _create_stream(self, host, ssl_options, max_buffer_size, af, addr):
+        # TODO: we should connect in plaintext mode and start the
+        # ssl handshake only after stopping the _Connector.
         if ssl_options is None:
-            return IOStream(socket.socket(af),
-                            io_loop=self.io_loop,
-                            max_buffer_size=max_buffer_size)
+            stream = IOStream(socket.socket(af),
+                              io_loop=self.io_loop,
+                              max_buffer_size=max_buffer_size)
         else:
-            return SSLIOStream(socket.socket(af),
-                               io_loop=self.io_loop,
-                               ssl_options=ssl_options,
-                               max_buffer_size=max_buffer_size)
+            stream = SSLIOStream(socket.socket(af),
+                                 io_loop=self.io_loop,
+                                 ssl_options=ssl_options,
+                                 max_buffer_size=max_buffer_size)
+        return stream.connect(addr, server_hostname=host)
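
From the caller's side the dual-stack racing is invisible. A minimal sketch of exercising the new path through TCPClient (host, port, and payload are placeholders):

    from tornado import gen
    from tornado.tcpclient import TCPClient

    @gen.coroutine
    def send_hello(host, port):
        # Internally, connect() resolves the host, partitions the
        # addresses by family with _Connector.split, and keeps one
        # attempt per family in flight, giving the first family a
        # 300 ms head start before trying the other one.
        stream = yield TCPClient().connect(host, port)
        yield stream.write(b"hello")
        stream.close()
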
tornado/test/iostream_test.py
index baca340af209378978a16ad32bc8e9403b039adf..ac91cbd4e601c1ff2f6e1bc5141e0adbf698b135 100644 (file)
@@ -111,7 +111,9 @@ class TestIOStreamWebMixin(object):
     def test_future_interface(self):
         """Basic test of IOStream's ability to return Futures."""
         stream = self._make_client_iostream()
-        yield stream.connect(("localhost", self.get_http_port()))
+        connect_result = yield stream.connect(
+            ("localhost", self.get_http_port()))
+        self.assertIs(connect_result, stream)
         yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
         first_line = yield stream.read_until(b"\r\n")
         self.assertEqual(first_line, b"HTTP/1.0 200 OK\r\n")
tornado/test/simple_httpclient_test.py
index 16e48944f3075b8adbb8057faef736ae8920e9fe..e8349ed7e424839b132dbbc2ba7ee9142804062c 100644 (file)
@@ -20,7 +20,7 @@ from tornado.simple_httpclient import SimpleAsyncHTTPClient, _default_ca_certs
 from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler
 from tornado.test import httpclient_test
 from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog
-from tornado.test.util import unittest, skipOnTravis
+from tornado.test.util import skipOnTravis, skipIfNoIPv6
 from tornado.web import RequestHandler, Application, asynchronous, url, stream_request_body
 
 
@@ -242,7 +242,7 @@ class SimpleHTTPClientTestMixin(object):
         # trigger the hanging request to let it clean up after itself
         self.triggers.popleft()()
 
-    @unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present')
+    @skipIfNoIPv6
     def test_ipv6(self):
         try:
             self.http_server.listen(self.get_http_port(), address='::1')
tornado/test/tcpclient_test.py
index 73bc3c103cbb19f606400925dc2e160f40f31df0..7a9882e85b3a716463cc5423feb6f8dceb73595c 100644 (file)
 from __future__ import absolute_import, division, print_function, with_statement
 
 from contextlib import closing
+import socket
 
+from tornado.concurrent import Future
 from tornado.log import gen_log
-from tornado.tcpclient import TCPClient
+from tornado.netutil import bind_sockets, Resolver
+from tornado.tcpclient import TCPClient, _Connector
 from tornado.tcpserver import TCPServer
 from tornado.testing import AsyncTestCase, bind_unused_port, gen_test, ExpectLog
+from tornado.test.util import skipIfNoIPv6, unittest
+
+# Fake address families for testing.  Used in place of AF_INET
+# and AF_INET6 because some installations do not have AF_INET6.
+AF1, AF2 = 1, 2
 
 class TestTCPServer(TCPServer):
-    def __init__(self):
+    def __init__(self, family):
         super(TestTCPServer, self).__init__()
         self.streams = []
-        socket, self.port = bind_unused_port()
-        self.add_socket(socket)
+        sockets = bind_sockets(None, 'localhost', family)
+        self.add_sockets(sockets)
+        self.port = sockets[0].getsockname()[1]
 
     def handle_stream(self, stream, address):
         self.streams.append(stream)
@@ -41,23 +50,60 @@ class TestTCPServer(TCPServer):
 class TCPClientTest(AsyncTestCase):
     def setUp(self):
         super(TCPClientTest, self).setUp()
-        self.server = TestTCPServer()
-        self.port = self.server.port
+        self.server = None
         self.client = TCPClient()
 
+    def start_server(self, family):
+        self.server = TestTCPServer(family)
+        return self.server.port
+
+    def stop_server(self):
+        if self.server is not None:
+            self.server.stop()
+            self.server = None
+
     def tearDown(self):
         self.client.close()
-        self.server.stop()
+        self.stop_server()
         super(TCPClientTest, self).tearDown()
 
     @gen_test
-    def test_connect_ipv4(self):
-        stream = yield self.client.connect('127.0.0.1', self.port)
+    def do_test_connect(self, family, host):
+        port = self.start_server(family)
+        stream = yield self.client.connect(host, port)
         with closing(stream):
             stream.write(b"hello")
             data = yield self.server.streams[0].read_bytes(5)
             self.assertEqual(data, b"hello")
 
+    def test_connect_ipv4_ipv4(self):
+        self.do_test_connect(socket.AF_INET, '127.0.0.1')
+
+    def test_connect_ipv4_dual(self):
+        with ExpectLog(gen_log, 'Connect error', required=False):
+            self.do_test_connect(socket.AF_INET, 'localhost')
+
+    @skipIfNoIPv6
+    def test_connect_ipv6_ipv6(self):
+        self.do_test_connect(socket.AF_INET6, '::1')
+
+    @skipIfNoIPv6
+    def test_connect_ipv6_dual(self):
+        if Resolver.configured_class().__name__.endswith('TwistedResolver'):
+            self.skipTest('TwistedResolver does not support multiple addresses')
+        with ExpectLog(gen_log, 'Connect error', required=False):
+            self.do_test_connect(socket.AF_INET6, 'localhost')
+
+    def test_connect_unspec_ipv4(self):
+        self.do_test_connect(socket.AF_UNSPEC, '127.0.0.1')
+
+    @skipIfNoIPv6
+    def test_connect_unspec_ipv6(self):
+        self.do_test_connect(socket.AF_UNSPEC, '::1')
+
+    def test_connect_unspec_dual(self):
+        self.do_test_connect(socket.AF_UNSPEC, 'localhost')
+
     @gen_test
     def test_refused_ipv4(self):
         sock, port = bind_unused_port()
@@ -65,3 +111,157 @@ class TCPClientTest(AsyncTestCase):
         with ExpectLog(gen_log, 'Connect error'):
             with self.assertRaises(IOError):
                 yield self.client.connect('127.0.0.1', port)
+
+
+class TestConnectorSplit(unittest.TestCase):
+    def test_one_family(self):
+        # These addresses aren't in the right format, but split doesn't care.
+        primary, secondary = _Connector.split(
+            [(AF1, 'a'),
+             (AF1, 'b')])
+        self.assertEqual(primary, [(AF1, 'a'),
+                                   (AF1, 'b')])
+        self.assertEqual(secondary, [])
+
+    def test_mixed(self):
+        primary, secondary = _Connector.split(
+            [(AF1, 'a'),
+             (AF2, 'b'),
+             (AF1, 'c'),
+             (AF2, 'd')])
+        self.assertEqual(primary, [(AF1, 'a'), (AF1, 'c')])
+        self.assertEqual(secondary, [(AF2, 'b'), (AF2, 'd')])
+
+
+class ConnectorTest(AsyncTestCase):
+    class FakeStream(object):
+        def __init__(self):
+            self.closed = False
+
+        def close(self):
+            self.closed = True
+
+    def setUp(self):
+        super(ConnectorTest, self).setUp()
+        self.connect_futures = {}
+        self.streams = {}
+        self.addrinfo = [(AF1, 'a'), (AF1, 'b'),
+                         (AF2, 'c'), (AF2, 'd')]
+
+    def tearDown(self):
+        # Unless explicitly checked (and popped) in the test, we shouldn't
+        # be closing any streams.
+        for stream in self.streams.values():
+            self.assertFalse(stream.closed)
+        super(ConnectorTest, self).tearDown()
+
+    def create_stream(self, af, addr):
+        future = Future()
+        self.connect_futures[(af, addr)] = future
+        return future
+
+    def assert_pending(self, *keys):
+        self.assertEqual(sorted(self.connect_futures.keys()), sorted(keys))
+
+    def resolve_connect(self, af, addr, success):
+        future = self.connect_futures.pop((af, addr))
+        if success:
+            self.streams[addr] = ConnectorTest.FakeStream()
+            future.set_result(self.streams[addr])
+        else:
+            future.set_exception(IOError())
+
+    def start_connect(self, addrinfo):
+        conn = _Connector(addrinfo, self.io_loop, self.create_stream)
+        # Give it a huge timeout; we'll trigger timeouts manually.
+        future = conn.start(3600)
+        return conn, future
+
+    def test_immediate_success(self):
+        conn, future = self.start_connect(self.addrinfo)
+        self.assertEqual(list(self.connect_futures.keys()),
+                         [(AF1, 'a')])
+        self.resolve_connect(AF1, 'a', True)
+        self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
+
+    def test_immediate_failure(self):
+        # Fail with just one address.
+        conn, future = self.start_connect([(AF1, 'a')])
+        self.assert_pending((AF1, 'a'))
+        self.resolve_connect(AF1, 'a', False)
+        self.assertRaises(IOError, future.result)
+
+    def test_one_family_second_try(self):
+        conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
+        self.assert_pending((AF1, 'a'))
+        self.resolve_connect(AF1, 'a', False)
+        self.assert_pending((AF1, 'b'))
+        self.resolve_connect(AF1, 'b', True)
+        self.assertEqual(future.result(), (AF1, 'b', self.streams['b']))
+
+    def test_one_family_second_try_failure(self):
+        conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
+        self.assert_pending((AF1, 'a'))
+        self.resolve_connect(AF1, 'a', False)
+        self.assert_pending((AF1, 'b'))
+        self.resolve_connect(AF1, 'b', False)
+        self.assertRaises(IOError, future.result)
+
+    def test_one_family_second_try_timeout(self):
+        conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
+        self.assert_pending((AF1, 'a'))
+        # trigger the timeout while the first lookup is pending;
+        # nothing happens.
+        conn.on_timeout()
+        self.assert_pending((AF1, 'a'))
+        self.resolve_connect(AF1, 'a', False)
+        self.assert_pending((AF1, 'b'))
+        self.resolve_connect(AF1, 'b', True)
+        self.assertEqual(future.result(), (AF1, 'b', self.streams['b']))
+
+    def test_two_families_immediate_failure(self):
+        conn, future = self.start_connect(self.addrinfo)
+        self.assert_pending((AF1, 'a'))
+        self.resolve_connect(AF1, 'a', False)
+        self.assert_pending((AF1, 'b'), (AF2, 'c'))
+        self.resolve_connect(AF1, 'b', False)
+        self.resolve_connect(AF2, 'c', True)
+        self.assertEqual(future.result(), (AF2, 'c', self.streams['c']))
+
+    def test_two_families_timeout(self):
+        conn, future = self.start_connect(self.addrinfo)
+        self.assert_pending((AF1, 'a'))
+        conn.on_timeout()
+        self.assert_pending((AF1, 'a'), (AF2, 'c'))
+        self.resolve_connect(AF2, 'c', True)
+        self.assertEqual(future.result(), (AF2, 'c', self.streams['c']))
+        # resolving 'a' after the connection has completed doesn't start 'b'
+        self.resolve_connect(AF1, 'a', False)
+        self.assert_pending()
+
+    def test_success_after_timeout(self):
+        conn, future = self.start_connect(self.addrinfo)
+        self.assert_pending((AF1, 'a'))
+        conn.on_timeout()
+        self.assert_pending((AF1, 'a'), (AF2, 'c'))
+        self.resolve_connect(AF1, 'a', True)
+        self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
+        # resolving 'c' after completion closes the connection.
+        self.resolve_connect(AF2, 'c', True)
+        self.assertTrue(self.streams.pop('c').closed)
+
+    def test_all_fail(self):
+        conn, future = self.start_connect(self.addrinfo)
+        self.assert_pending((AF1, 'a'))
+        conn.on_timeout()
+        self.assert_pending((AF1, 'a'), (AF2, 'c'))
+        self.resolve_connect(AF2, 'c', False)
+        self.assert_pending((AF1, 'a'), (AF2, 'd'))
+        self.resolve_connect(AF2, 'd', False)
+        # one queue is now empty
+        self.assert_pending((AF1, 'a'))
+        self.resolve_connect(AF1, 'a', False)
+        self.assert_pending((AF1, 'b'))
+        self.assertFalse(future.done())
+        self.resolve_connect(AF1, 'b', False)
+        self.assertRaises(IOError, future.result)
tornado/test/util.py
index 76fc2687369d762c7af8fe04db89c06376b180b8..d31bbba33d8019865d1c44abd43eb246fa639c1f 100644 (file)
@@ -1,6 +1,7 @@
 from __future__ import absolute_import, division, print_function, with_statement
 
 import os
+import socket
 import sys
 
 # Encapsulate the choice of unittest or unittest2 here.
@@ -25,3 +26,5 @@ skipOnTravis = unittest.skipIf('TRAVIS' in os.environ,
 # depend on an external network.
 skipIfNoNetwork = unittest.skipIf('NO_NETWORK' in os.environ,
                                   'network access disabled')
+
+skipIfNoIPv6 = unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present')