"""
from __future__ import absolute_import, division, print_function, with_statement
+import functools
import socket
+from tornado.concurrent import Future
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream, SSLIOStream
from tornado import gen
from tornado.netutil import Resolver
+_INITIAL_CONNECT_TIMEOUT = 0.3
+
+class _Connector(object):
+ """A stateless implementation of the "Happy Eyeballs" algorithm.
+
+ "Happy Eyeballs" is documented in RFC6555 as the recommended practice
+ for when both IPv4 and IPv6 addresses are available.
+
+ In this implementation, we partition the addresses by family, and
+ make the first connection attempt to whichever address was
+ returned first by ``getaddrinfo``. If that connection fails or
+ times out, we begin a connection in parallel to the first address
+ of the other family. If there are additional failures we retry
+ with other addresses, keeping one connection attempt per family
+ in flight at a time.
+
+ http://tools.ietf.org/html/rfc6555
+
+ """
+ def __init__(self, addrinfo, io_loop, connect):
+ self.io_loop = io_loop
+ self.connect = connect
+
+ self.future = Future()
+ self.timeout = None
+ self.last_error = None
+ self.remaining = len(addrinfo)
+ self.primary_addrs, self.secondary_addrs = self.split(addrinfo)
+
+ @staticmethod
+ def split(addrinfo):
+ """Partition the ``addrinfo`` list by address family.
+
+ Returns two lists. The first list contains the first entry from
+ ``addrinfo`` and all others with the same family, and the
+ second list contains all other addresses (normally one list will
+ be AF_INET and the other AF_INET6, although non-standard resolvers
+ may return additional families).
+ """
+ primary = []
+ secondary = []
+ primary_af = addrinfo[0][0]
+ for af, addr in addrinfo:
+ if af == primary_af:
+ primary.append((af, addr))
+ else:
+ secondary.append((af, addr))
+ return primary, secondary
+
+ def start(self, timeout=_INITIAL_CONNECT_TIMEOUT):
+ self.try_connect(iter(self.primary_addrs))
+ self.set_timout(timeout)
+ return self.future
+
+ def try_connect(self, addrs):
+ try:
+ af, addr = next(addrs)
+ except StopIteration:
+ # We've reached the end of our queue, but the other queue
+ # might still be working. Send a final error on the future
+ # only when both queues are finished.
+ if self.remaining == 0 and not self.future.done():
+ self.future.set_exception(self.last_error or
+ IOError("connection failed"))
+ return
+ future = self.connect(af, addr)
+ future.add_done_callback(functools.partial(self.on_connect_done,
+ addrs, af, addr))
+
+ def on_connect_done(self, addrs, af, addr, future):
+ self.remaining -= 1
+ try:
+ stream = future.result()
+ except Exception as e:
+ if self.future.done():
+ return
+ # Error: try again (but remember what happened so we have an
+ # error to raise in the end)
+ self.last_error = e
+ self.try_connect(addrs)
+ if self.timeout is not None:
+ # If the first attempt failed, don't wait for the
+ # timeout to try an address from the secondary queue.
+ self.on_timeout()
+ return
+ self.clear_timeout()
+ if self.future.done():
+ # This is a late arrival; just drop it.
+ stream.close()
+ else:
+ self.future.set_result((af, addr, stream))
+
+ def set_timout(self, timeout):
+ self.timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout,
+ self.on_timeout)
+
+ def on_timeout(self):
+ self.timeout = None
+ self.try_connect(iter(self.secondary_addrs))
+
+ def clear_timeout(self):
+ if self.timeout is not None:
+ self.io_loop.remove_timeout(self.timeout)
+
+
class TCPClient(object):
"""A non-blocking TCP connection factory.
"""
``ssl_options`` is not None).
"""
addrinfo = yield self.resolver.resolve(host, port, af)
- af, addr = addrinfo[0]
- stream = self._create_stream(af, ssl_options, max_buffer_size)
- yield stream.connect(addr, server_hostname=host)
+ connector = _Connector(
+ addrinfo, self.io_loop,
+ functools.partial(self._create_stream,
+ host, ssl_options, max_buffer_size))
+ af, addr, stream = yield connector.start()
+ # TODO: For better performance we could cache the (af, addr)
+ # information here and re-use it on sbusequent connections to
+ # the same host. (http://tools.ietf.org/html/rfc6555#section-4.2)
raise gen.Return(stream)
- def _create_stream(self, af, ssl_options, max_buffer_size):
+ def _create_stream(self, host, ssl_options, max_buffer_size, af, addr):
+ # TODO: we should connect in plaintext mode and start the
+ # ssl handshake only after stopping the _Connector.
if ssl_options is None:
- return IOStream(socket.socket(af),
- io_loop=self.io_loop,
- max_buffer_size=max_buffer_size)
+ stream = IOStream(socket.socket(af),
+ io_loop=self.io_loop,
+ max_buffer_size=max_buffer_size)
else:
- return SSLIOStream(socket.socket(af),
- io_loop=self.io_loop,
- ssl_options=ssl_options,
- max_buffer_size=max_buffer_size)
+ stream = SSLIOStream(socket.socket(af),
+ io_loop=self.io_loop,
+ ssl_options=ssl_options,
+ max_buffer_size=max_buffer_size)
+ return stream.connect(addr, server_hostname=host)
from __future__ import absolute_import, division, print_function, with_statement
from contextlib import closing
+import socket
+from tornado.concurrent import Future
from tornado.log import gen_log
-from tornado.tcpclient import TCPClient
+from tornado.netutil import bind_sockets, Resolver
+from tornado.tcpclient import TCPClient, _Connector
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, bind_unused_port, gen_test, ExpectLog
+from tornado.test.util import skipIfNoIPv6, unittest
+
+# Fake address families for testing. Used in place of AF_INET
+# and AF_INET6 because some installations do not have AF_INET6.
+AF1, AF2 = 1, 2
class TestTCPServer(TCPServer):
- def __init__(self):
+ def __init__(self, family):
super(TestTCPServer, self).__init__()
self.streams = []
- socket, self.port = bind_unused_port()
- self.add_socket(socket)
+ sockets = bind_sockets(None, 'localhost', family)
+ self.add_sockets(sockets)
+ self.port = sockets[0].getsockname()[1]
def handle_stream(self, stream, address):
self.streams.append(stream)
class TCPClientTest(AsyncTestCase):
def setUp(self):
super(TCPClientTest, self).setUp()
- self.server = TestTCPServer()
- self.port = self.server.port
+ self.server = None
self.client = TCPClient()
+ def start_server(self, family):
+ self.server = TestTCPServer(family)
+ return self.server.port
+
+ def stop_server(self):
+ if self.server is not None:
+ self.server.stop()
+ self.server = None
+
def tearDown(self):
self.client.close()
- self.server.stop()
+ self.stop_server()
super(TCPClientTest, self).tearDown()
@gen_test
- def test_connect_ipv4(self):
- stream = yield self.client.connect('127.0.0.1', self.port)
+ def do_test_connect(self, family, host):
+ port = self.start_server(family)
+ stream = yield self.client.connect(host, port)
with closing(stream):
stream.write(b"hello")
data = yield self.server.streams[0].read_bytes(5)
self.assertEqual(data, b"hello")
+ def test_connect_ipv4_ipv4(self):
+ self.do_test_connect(socket.AF_INET, '127.0.0.1')
+
+ def test_connect_ipv4_dual(self):
+ with ExpectLog(gen_log, 'Connect error', required=False):
+ self.do_test_connect(socket.AF_INET, 'localhost')
+
+ @skipIfNoIPv6
+ def test_connect_ipv6_ipv6(self):
+ self.do_test_connect(socket.AF_INET6, '::1')
+
+ @skipIfNoIPv6
+ def test_connect_ipv6_dual(self):
+ if Resolver.configured_class().__name__.endswith('TwistedResolver'):
+ self.skipTest('TwistedResolver does not support multiple addresses')
+ with ExpectLog(gen_log, 'Connect error', required=False):
+ self.do_test_connect(socket.AF_INET6, 'localhost')
+
+ def test_connect_unspec_ipv4(self):
+ self.do_test_connect(socket.AF_UNSPEC, '127.0.0.1')
+
+ @skipIfNoIPv6
+ def test_connect_unspec_ipv6(self):
+ self.do_test_connect(socket.AF_UNSPEC, '::1')
+
+ def test_connect_unspec_dual(self):
+ self.do_test_connect(socket.AF_UNSPEC, 'localhost')
+
@gen_test
def test_refused_ipv4(self):
sock, port = bind_unused_port()
with ExpectLog(gen_log, 'Connect error'):
with self.assertRaises(IOError):
yield self.client.connect('127.0.0.1', port)
+
+
+class TestConnectorSplit(unittest.TestCase):
+ def test_one_family(self):
+ # These addresses aren't in the right format, but split doesn't care.
+ primary, secondary = _Connector.split(
+ [(AF1, 'a'),
+ (AF1, 'b')])
+ self.assertEqual(primary, [(AF1, 'a'),
+ (AF1, 'b')])
+ self.assertEqual(secondary, [])
+
+ def test_mixed(self):
+ primary, secondary = _Connector.split(
+ [(AF1, 'a'),
+ (AF2, 'b'),
+ (AF1, 'c'),
+ (AF2, 'd')])
+ self.assertEqual(primary, [(AF1, 'a'), (AF1, 'c')])
+ self.assertEqual(secondary, [(AF2, 'b'), (AF2, 'd')])
+
+
+class ConnectorTest(AsyncTestCase):
+ class FakeStream(object):
+ def __init__(self):
+ self.closed = False
+
+ def close(self):
+ self.closed = True
+
+ def setUp(self):
+ super(ConnectorTest, self).setUp()
+ self.connect_futures = {}
+ self.streams = {}
+ self.addrinfo = [(AF1, 'a'), (AF1, 'b'),
+ (AF2, 'c'), (AF2, 'd')]
+
+ def tearDown(self):
+ # Unless explicitly checked (and popped) in the test, we shouldn't
+ # be closing any streams
+ for stream in self.streams.values():
+ self.assertFalse(stream.closed)
+ super(ConnectorTest, self).tearDown()
+
+ def create_stream(self, af, addr):
+ future = Future()
+ self.connect_futures[(af, addr)] = future
+ return future
+
+ def assert_pending(self, *keys):
+ self.assertEqual(sorted(self.connect_futures.keys()), sorted(keys))
+
+ def resolve_connect(self, af, addr, success):
+ future = self.connect_futures.pop((af, addr))
+ if success:
+ self.streams[addr] = ConnectorTest.FakeStream()
+ future.set_result(self.streams[addr])
+ else:
+ future.set_exception(IOError())
+
+ def start_connect(self, addrinfo):
+ conn = _Connector(addrinfo, self.io_loop, self.create_stream)
+ # Give it a huge timeout; we'll trigger timeouts manually.
+ future = conn.start(3600)
+ return conn, future
+
+ def test_immediate_success(self):
+ conn, future = self.start_connect(self.addrinfo)
+ self.assertEqual(list(self.connect_futures.keys()),
+ [(AF1, 'a')])
+ self.resolve_connect(AF1, 'a', True)
+ self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
+
+ def test_immediate_failure(self):
+ # Fail with just one address.
+ conn, future = self.start_connect([(AF1, 'a')])
+ self.assert_pending((AF1, 'a'))
+ self.resolve_connect(AF1, 'a', False)
+ self.assertRaises(IOError, future.result)
+
+ def test_one_family_second_try(self):
+ conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
+ self.assert_pending((AF1, 'a'))
+ self.resolve_connect(AF1, 'a', False)
+ self.assert_pending((AF1, 'b'))
+ self.resolve_connect(AF1, 'b', True)
+ self.assertEqual(future.result(), (AF1, 'b', self.streams['b']))
+
+ def test_one_family_second_try_failure(self):
+ conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
+ self.assert_pending((AF1, 'a'))
+ self.resolve_connect(AF1, 'a', False)
+ self.assert_pending((AF1, 'b'))
+ self.resolve_connect(AF1, 'b', False)
+ self.assertRaises(IOError, future.result)
+
+ def test_one_family_second_try_timeout(self):
+ conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
+ self.assert_pending((AF1, 'a'))
+ # trigger the timeout while the first lookup is pending;
+ # nothing happens.
+ conn.on_timeout()
+ self.assert_pending((AF1, 'a'))
+ self.resolve_connect(AF1, 'a', False)
+ self.assert_pending((AF1, 'b'))
+ self.resolve_connect(AF1, 'b', True)
+ self.assertEqual(future.result(), (AF1, 'b', self.streams['b']))
+
+ def test_two_families_immediate_failure(self):
+ conn, future = self.start_connect(self.addrinfo)
+ self.assert_pending((AF1, 'a'))
+ self.resolve_connect(AF1, 'a', False)
+ self.assert_pending((AF1, 'b'), (AF2, 'c'))
+ self.resolve_connect(AF1, 'b', False)
+ self.resolve_connect(AF2, 'c', True)
+ self.assertEqual(future.result(), (AF2, 'c', self.streams['c']))
+
+ def test_two_families_timeout(self):
+ conn, future = self.start_connect(self.addrinfo)
+ self.assert_pending((AF1, 'a'))
+ conn.on_timeout()
+ self.assert_pending((AF1, 'a'), (AF2, 'c'))
+ self.resolve_connect(AF2, 'c', True)
+ self.assertEqual(future.result(), (AF2, 'c', self.streams['c']))
+ # resolving 'a' after the connection has completed doesn't start 'b'
+ self.resolve_connect(AF1, 'a', False)
+ self.assert_pending()
+
+ def test_success_after_timeout(self):
+ conn, future = self.start_connect(self.addrinfo)
+ self.assert_pending((AF1, 'a'))
+ conn.on_timeout()
+ self.assert_pending((AF1, 'a'), (AF2, 'c'))
+ self.resolve_connect(AF1, 'a', True)
+ self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
+ # resolving 'c' after completion closes the connection.
+ self.resolve_connect(AF2, 'c', True)
+ self.assertTrue(self.streams.pop('c').closed)
+
+ def test_all_fail(self):
+ conn, future = self.start_connect(self.addrinfo)
+ self.assert_pending((AF1, 'a'))
+ conn.on_timeout()
+ self.assert_pending((AF1, 'a'), (AF2, 'c'))
+ self.resolve_connect(AF2, 'c', False)
+ self.assert_pending((AF1, 'a'), (AF2, 'd'))
+ self.resolve_connect(AF2, 'd', False)
+ # one queue is now empty
+ self.assert_pending((AF1, 'a'))
+ self.resolve_connect(AF1, 'a', False)
+ self.assert_pending((AF1, 'b'))
+ self.assertFalse(future.done())
+ self.resolve_connect(AF1, 'b', False)
+ self.assertRaises(IOError, future.result)