From cc0e7624f16e19c3976fa122b5797b3424a11298 Mon Sep 17 00:00:00 2001 From: Ben Darnell Date: Sun, 4 May 2014 13:13:49 -0400 Subject: [PATCH] Implement the "happy eyeballs" algorithm for ipv4/ipv6 selection. IOStream.connect's Future resolves to the stream itself, which streamlines some connection scenarios. --- tornado/iostream.py | 7 +- tornado/tcpclient.py | 137 ++++++++++++++-- tornado/test/iostream_test.py | 4 +- tornado/test/simple_httpclient_test.py | 4 +- tornado/test/tcpclient_test.py | 218 ++++++++++++++++++++++++- tornado/test/util.py | 3 + 6 files changed, 347 insertions(+), 26 deletions(-) diff --git a/tornado/iostream.py b/tornado/iostream.py index b5440b2a8..c29494e2e 100644 --- a/tornado/iostream.py +++ b/tornado/iostream.py @@ -919,8 +919,9 @@ class IOStream(BaseIOStream): not previously connected. The address parameter is in the same format as for `socket.connect `, i.e. a ``(host, port)`` tuple. If ``callback`` is specified, - it will be called when the connection is completed; if not - this method returns a `.Future`. + it will be called with no arguments when the connection is + completed; if not this method returns a `.Future` (whose result + after a successful connection will be the stream itself). If specified, the ``server_hostname`` parameter will be used in SSL connections for certificate validation (if requested in @@ -980,7 +981,7 @@ class IOStream(BaseIOStream): if self._connect_future is not None: future = self._connect_future self._connect_future = None - future.set_result(None) + future.set_result(self) self._connecting = False def set_nodelay(self, value): diff --git a/tornado/tcpclient.py b/tornado/tcpclient.py index e41c39094..1bb1253f1 100644 --- a/tornado/tcpclient.py +++ b/tornado/tcpclient.py @@ -18,13 +18,120 @@ """ from __future__ import absolute_import, division, print_function, with_statement +import functools import socket +from tornado.concurrent import Future from tornado.ioloop import IOLoop from tornado.iostream import IOStream, SSLIOStream from tornado import gen from tornado.netutil import Resolver +_INITIAL_CONNECT_TIMEOUT = 0.3 + +class _Connector(object): + """A stateless implementation of the "Happy Eyeballs" algorithm. + + "Happy Eyeballs" is documented in RFC6555 as the recommended practice + for when both IPv4 and IPv6 addresses are available. + + In this implementation, we partition the addresses by family, and + make the first connection attempt to whichever address was + returned first by ``getaddrinfo``. If that connection fails or + times out, we begin a connection in parallel to the first address + of the other family. If there are additional failures we retry + with other addresses, keeping one connection attempt per family + in flight at a time. + + http://tools.ietf.org/html/rfc6555 + + """ + def __init__(self, addrinfo, io_loop, connect): + self.io_loop = io_loop + self.connect = connect + + self.future = Future() + self.timeout = None + self.last_error = None + self.remaining = len(addrinfo) + self.primary_addrs, self.secondary_addrs = self.split(addrinfo) + + @staticmethod + def split(addrinfo): + """Partition the ``addrinfo`` list by address family. + + Returns two lists. The first list contains the first entry from + ``addrinfo`` and all others with the same family, and the + second list contains all other addresses (normally one list will + be AF_INET and the other AF_INET6, although non-standard resolvers + may return additional families). + """ + primary = [] + secondary = [] + primary_af = addrinfo[0][0] + for af, addr in addrinfo: + if af == primary_af: + primary.append((af, addr)) + else: + secondary.append((af, addr)) + return primary, secondary + + def start(self, timeout=_INITIAL_CONNECT_TIMEOUT): + self.try_connect(iter(self.primary_addrs)) + self.set_timout(timeout) + return self.future + + def try_connect(self, addrs): + try: + af, addr = next(addrs) + except StopIteration: + # We've reached the end of our queue, but the other queue + # might still be working. Send a final error on the future + # only when both queues are finished. + if self.remaining == 0 and not self.future.done(): + self.future.set_exception(self.last_error or + IOError("connection failed")) + return + future = self.connect(af, addr) + future.add_done_callback(functools.partial(self.on_connect_done, + addrs, af, addr)) + + def on_connect_done(self, addrs, af, addr, future): + self.remaining -= 1 + try: + stream = future.result() + except Exception as e: + if self.future.done(): + return + # Error: try again (but remember what happened so we have an + # error to raise in the end) + self.last_error = e + self.try_connect(addrs) + if self.timeout is not None: + # If the first attempt failed, don't wait for the + # timeout to try an address from the secondary queue. + self.on_timeout() + return + self.clear_timeout() + if self.future.done(): + # This is a late arrival; just drop it. + stream.close() + else: + self.future.set_result((af, addr, stream)) + + def set_timout(self, timeout): + self.timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout, + self.on_timeout) + + def on_timeout(self): + self.timeout = None + self.try_connect(iter(self.secondary_addrs)) + + def clear_timeout(self): + if self.timeout is not None: + self.io_loop.remove_timeout(self.timeout) + + class TCPClient(object): """A non-blocking TCP connection factory. """ @@ -50,18 +157,26 @@ class TCPClient(object): ``ssl_options`` is not None). """ addrinfo = yield self.resolver.resolve(host, port, af) - af, addr = addrinfo[0] - stream = self._create_stream(af, ssl_options, max_buffer_size) - yield stream.connect(addr, server_hostname=host) + connector = _Connector( + addrinfo, self.io_loop, + functools.partial(self._create_stream, + host, ssl_options, max_buffer_size)) + af, addr, stream = yield connector.start() + # TODO: For better performance we could cache the (af, addr) + # information here and re-use it on sbusequent connections to + # the same host. (http://tools.ietf.org/html/rfc6555#section-4.2) raise gen.Return(stream) - def _create_stream(self, af, ssl_options, max_buffer_size): + def _create_stream(self, host, ssl_options, max_buffer_size, af, addr): + # TODO: we should connect in plaintext mode and start the + # ssl handshake only after stopping the _Connector. if ssl_options is None: - return IOStream(socket.socket(af), - io_loop=self.io_loop, - max_buffer_size=max_buffer_size) + stream = IOStream(socket.socket(af), + io_loop=self.io_loop, + max_buffer_size=max_buffer_size) else: - return SSLIOStream(socket.socket(af), - io_loop=self.io_loop, - ssl_options=ssl_options, - max_buffer_size=max_buffer_size) + stream = SSLIOStream(socket.socket(af), + io_loop=self.io_loop, + ssl_options=ssl_options, + max_buffer_size=max_buffer_size) + return stream.connect(addr, server_hostname=host) diff --git a/tornado/test/iostream_test.py b/tornado/test/iostream_test.py index baca340af..ac91cbd4e 100644 --- a/tornado/test/iostream_test.py +++ b/tornado/test/iostream_test.py @@ -111,7 +111,9 @@ class TestIOStreamWebMixin(object): def test_future_interface(self): """Basic test of IOStream's ability to return Futures.""" stream = self._make_client_iostream() - yield stream.connect(("localhost", self.get_http_port())) + connect_result = yield stream.connect( + ("localhost", self.get_http_port())) + self.assertIs(connect_result, stream) yield stream.write(b"GET / HTTP/1.0\r\n\r\n") first_line = yield stream.read_until(b"\r\n") self.assertEqual(first_line, b"HTTP/1.0 200 OK\r\n") diff --git a/tornado/test/simple_httpclient_test.py b/tornado/test/simple_httpclient_test.py index 16e48944f..e8349ed7e 100644 --- a/tornado/test/simple_httpclient_test.py +++ b/tornado/test/simple_httpclient_test.py @@ -20,7 +20,7 @@ from tornado.simple_httpclient import SimpleAsyncHTTPClient, _default_ca_certs from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler from tornado.test import httpclient_test from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog -from tornado.test.util import unittest, skipOnTravis +from tornado.test.util import skipOnTravis, skipIfNoIPv6 from tornado.web import RequestHandler, Application, asynchronous, url, stream_request_body @@ -242,7 +242,7 @@ class SimpleHTTPClientTestMixin(object): # trigger the hanging request to let it clean up after itself self.triggers.popleft()() - @unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present') + @skipIfNoIPv6 def test_ipv6(self): try: self.http_server.listen(self.get_http_port(), address='::1') diff --git a/tornado/test/tcpclient_test.py b/tornado/test/tcpclient_test.py index 73bc3c103..7a9882e85 100644 --- a/tornado/test/tcpclient_test.py +++ b/tornado/test/tcpclient_test.py @@ -17,18 +17,27 @@ from __future__ import absolute_import, division, print_function, with_statement from contextlib import closing +import socket +from tornado.concurrent import Future from tornado.log import gen_log -from tornado.tcpclient import TCPClient +from tornado.netutil import bind_sockets, Resolver +from tornado.tcpclient import TCPClient, _Connector from tornado.tcpserver import TCPServer from tornado.testing import AsyncTestCase, bind_unused_port, gen_test, ExpectLog +from tornado.test.util import skipIfNoIPv6, unittest + +# Fake address families for testing. Used in place of AF_INET +# and AF_INET6 because some installations do not have AF_INET6. +AF1, AF2 = 1, 2 class TestTCPServer(TCPServer): - def __init__(self): + def __init__(self, family): super(TestTCPServer, self).__init__() self.streams = [] - socket, self.port = bind_unused_port() - self.add_socket(socket) + sockets = bind_sockets(None, 'localhost', family) + self.add_sockets(sockets) + self.port = sockets[0].getsockname()[1] def handle_stream(self, stream, address): self.streams.append(stream) @@ -41,23 +50,60 @@ class TestTCPServer(TCPServer): class TCPClientTest(AsyncTestCase): def setUp(self): super(TCPClientTest, self).setUp() - self.server = TestTCPServer() - self.port = self.server.port + self.server = None self.client = TCPClient() + def start_server(self, family): + self.server = TestTCPServer(family) + return self.server.port + + def stop_server(self): + if self.server is not None: + self.server.stop() + self.server = None + def tearDown(self): self.client.close() - self.server.stop() + self.stop_server() super(TCPClientTest, self).tearDown() @gen_test - def test_connect_ipv4(self): - stream = yield self.client.connect('127.0.0.1', self.port) + def do_test_connect(self, family, host): + port = self.start_server(family) + stream = yield self.client.connect(host, port) with closing(stream): stream.write(b"hello") data = yield self.server.streams[0].read_bytes(5) self.assertEqual(data, b"hello") + def test_connect_ipv4_ipv4(self): + self.do_test_connect(socket.AF_INET, '127.0.0.1') + + def test_connect_ipv4_dual(self): + with ExpectLog(gen_log, 'Connect error', required=False): + self.do_test_connect(socket.AF_INET, 'localhost') + + @skipIfNoIPv6 + def test_connect_ipv6_ipv6(self): + self.do_test_connect(socket.AF_INET6, '::1') + + @skipIfNoIPv6 + def test_connect_ipv6_dual(self): + if Resolver.configured_class().__name__.endswith('TwistedResolver'): + self.skipTest('TwistedResolver does not support multiple addresses') + with ExpectLog(gen_log, 'Connect error', required=False): + self.do_test_connect(socket.AF_INET6, 'localhost') + + def test_connect_unspec_ipv4(self): + self.do_test_connect(socket.AF_UNSPEC, '127.0.0.1') + + @skipIfNoIPv6 + def test_connect_unspec_ipv6(self): + self.do_test_connect(socket.AF_UNSPEC, '::1') + + def test_connect_unspec_dual(self): + self.do_test_connect(socket.AF_UNSPEC, 'localhost') + @gen_test def test_refused_ipv4(self): sock, port = bind_unused_port() @@ -65,3 +111,157 @@ class TCPClientTest(AsyncTestCase): with ExpectLog(gen_log, 'Connect error'): with self.assertRaises(IOError): yield self.client.connect('127.0.0.1', port) + + +class TestConnectorSplit(unittest.TestCase): + def test_one_family(self): + # These addresses aren't in the right format, but split doesn't care. + primary, secondary = _Connector.split( + [(AF1, 'a'), + (AF1, 'b')]) + self.assertEqual(primary, [(AF1, 'a'), + (AF1, 'b')]) + self.assertEqual(secondary, []) + + def test_mixed(self): + primary, secondary = _Connector.split( + [(AF1, 'a'), + (AF2, 'b'), + (AF1, 'c'), + (AF2, 'd')]) + self.assertEqual(primary, [(AF1, 'a'), (AF1, 'c')]) + self.assertEqual(secondary, [(AF2, 'b'), (AF2, 'd')]) + + +class ConnectorTest(AsyncTestCase): + class FakeStream(object): + def __init__(self): + self.closed = False + + def close(self): + self.closed = True + + def setUp(self): + super(ConnectorTest, self).setUp() + self.connect_futures = {} + self.streams = {} + self.addrinfo = [(AF1, 'a'), (AF1, 'b'), + (AF2, 'c'), (AF2, 'd')] + + def tearDown(self): + # Unless explicitly checked (and popped) in the test, we shouldn't + # be closing any streams + for stream in self.streams.values(): + self.assertFalse(stream.closed) + super(ConnectorTest, self).tearDown() + + def create_stream(self, af, addr): + future = Future() + self.connect_futures[(af, addr)] = future + return future + + def assert_pending(self, *keys): + self.assertEqual(sorted(self.connect_futures.keys()), sorted(keys)) + + def resolve_connect(self, af, addr, success): + future = self.connect_futures.pop((af, addr)) + if success: + self.streams[addr] = ConnectorTest.FakeStream() + future.set_result(self.streams[addr]) + else: + future.set_exception(IOError()) + + def start_connect(self, addrinfo): + conn = _Connector(addrinfo, self.io_loop, self.create_stream) + # Give it a huge timeout; we'll trigger timeouts manually. + future = conn.start(3600) + return conn, future + + def test_immediate_success(self): + conn, future = self.start_connect(self.addrinfo) + self.assertEqual(list(self.connect_futures.keys()), + [(AF1, 'a')]) + self.resolve_connect(AF1, 'a', True) + self.assertEqual(future.result(), (AF1, 'a', self.streams['a'])) + + def test_immediate_failure(self): + # Fail with just one address. + conn, future = self.start_connect([(AF1, 'a')]) + self.assert_pending((AF1, 'a')) + self.resolve_connect(AF1, 'a', False) + self.assertRaises(IOError, future.result) + + def test_one_family_second_try(self): + conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')]) + self.assert_pending((AF1, 'a')) + self.resolve_connect(AF1, 'a', False) + self.assert_pending((AF1, 'b')) + self.resolve_connect(AF1, 'b', True) + self.assertEqual(future.result(), (AF1, 'b', self.streams['b'])) + + def test_one_family_second_try_failure(self): + conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')]) + self.assert_pending((AF1, 'a')) + self.resolve_connect(AF1, 'a', False) + self.assert_pending((AF1, 'b')) + self.resolve_connect(AF1, 'b', False) + self.assertRaises(IOError, future.result) + + def test_one_family_second_try_timeout(self): + conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')]) + self.assert_pending((AF1, 'a')) + # trigger the timeout while the first lookup is pending; + # nothing happens. + conn.on_timeout() + self.assert_pending((AF1, 'a')) + self.resolve_connect(AF1, 'a', False) + self.assert_pending((AF1, 'b')) + self.resolve_connect(AF1, 'b', True) + self.assertEqual(future.result(), (AF1, 'b', self.streams['b'])) + + def test_two_families_immediate_failure(self): + conn, future = self.start_connect(self.addrinfo) + self.assert_pending((AF1, 'a')) + self.resolve_connect(AF1, 'a', False) + self.assert_pending((AF1, 'b'), (AF2, 'c')) + self.resolve_connect(AF1, 'b', False) + self.resolve_connect(AF2, 'c', True) + self.assertEqual(future.result(), (AF2, 'c', self.streams['c'])) + + def test_two_families_timeout(self): + conn, future = self.start_connect(self.addrinfo) + self.assert_pending((AF1, 'a')) + conn.on_timeout() + self.assert_pending((AF1, 'a'), (AF2, 'c')) + self.resolve_connect(AF2, 'c', True) + self.assertEqual(future.result(), (AF2, 'c', self.streams['c'])) + # resolving 'a' after the connection has completed doesn't start 'b' + self.resolve_connect(AF1, 'a', False) + self.assert_pending() + + def test_success_after_timeout(self): + conn, future = self.start_connect(self.addrinfo) + self.assert_pending((AF1, 'a')) + conn.on_timeout() + self.assert_pending((AF1, 'a'), (AF2, 'c')) + self.resolve_connect(AF1, 'a', True) + self.assertEqual(future.result(), (AF1, 'a', self.streams['a'])) + # resolving 'c' after completion closes the connection. + self.resolve_connect(AF2, 'c', True) + self.assertTrue(self.streams.pop('c').closed) + + def test_all_fail(self): + conn, future = self.start_connect(self.addrinfo) + self.assert_pending((AF1, 'a')) + conn.on_timeout() + self.assert_pending((AF1, 'a'), (AF2, 'c')) + self.resolve_connect(AF2, 'c', False) + self.assert_pending((AF1, 'a'), (AF2, 'd')) + self.resolve_connect(AF2, 'd', False) + # one queue is now empty + self.assert_pending((AF1, 'a')) + self.resolve_connect(AF1, 'a', False) + self.assert_pending((AF1, 'b')) + self.assertFalse(future.done()) + self.resolve_connect(AF1, 'b', False) + self.assertRaises(IOError, future.result) diff --git a/tornado/test/util.py b/tornado/test/util.py index 76fc26873..d31bbba33 100644 --- a/tornado/test/util.py +++ b/tornado/test/util.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division, print_function, with_statement import os +import socket import sys # Encapsulate the choice of unittest or unittest2 here. @@ -25,3 +26,5 @@ skipOnTravis = unittest.skipIf('TRAVIS' in os.environ, # depend on an external network. skipIfNoNetwork = unittest.skipIf('NO_NETWORK' in os.environ, 'network access disabled') + +skipIfNoIPv6 = unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present') -- 2.47.2