import functools
import socket
+import time
+import numbers
+import datetime
from tornado.concurrent import Future
from tornado.ioloop import IOLoop
from tornado import gen
from tornado.netutil import Resolver
from tornado.platform.auto import set_close_exec
+from tornado.gen import TimeoutError
+from tornado.util import timedelta_to_seconds
_INITIAL_CONNECT_TIMEOUT = 0.3
self.future = Future()
self.timeout = None
+ self.connect_timeout = None
self.last_error = None
self.remaining = len(addrinfo)
self.primary_addrs, self.secondary_addrs = self.split(addrinfo)
+ self.streams = set()
@staticmethod
def split(addrinfo):
secondary.append((af, addr))
return primary, secondary
- def start(self, timeout=_INITIAL_CONNECT_TIMEOUT):
+ def start(self, timeout=_INITIAL_CONNECT_TIMEOUT, connect_timeout=None):
self.try_connect(iter(self.primary_addrs))
- self.set_timout(timeout)
+ self.set_timeout(timeout)
+ if connect_timeout is not None:
+ self.set_connect_timeout(connect_timeout)
return self.future
def try_connect(self, addrs):
self.future.set_exception(self.last_error or
IOError("connection failed"))
return
- future = self.connect(af, addr)
+ stream, future = self.connect(af, addr)
+ self.streams.add(stream)
future.add_done_callback(functools.partial(self.on_connect_done,
addrs, af, addr))
self.io_loop.remove_timeout(self.timeout)
self.on_timeout()
return
- self.clear_timeout()
+ self.clear_timeouts()
if self.future.done():
# This is a late arrival; just drop it.
stream.close()
else:
+ self.streams.discard(stream)
self.future.set_result((af, addr, stream))
+ self.close_streams()
- def set_timout(self, timeout):
+ def set_timeout(self, timeout):
self.timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout,
self.on_timeout)
def on_timeout(self):
self.timeout = None
- self.try_connect(iter(self.secondary_addrs))
+ if not self.future.done():
+ self.try_connect(iter(self.secondary_addrs))
def clear_timeout(self):
if self.timeout is not None:
self.io_loop.remove_timeout(self.timeout)
+ def set_connect_timeout(self, connect_timeout):
+ self.connect_timeout = self.io_loop.add_timeout(
+ connect_timeout, self.on_connect_timeout)
+
+ def on_connect_timeout(self):
+ if not self.future.done():
+ self.future.set_exception(TimeoutError())
+ self.close_streams()
+
+ def clear_timeouts(self):
+ if self.timeout is not None:
+ self.io_loop.remove_timeout(self.timeout)
+ if self.connect_timeout is not None:
+ self.io_loop.remove_timeout(self.connect_timeout)
+
+ def close_streams(self):
+ for stream in self.streams:
+ stream.close()
+
class TCPClient(object):
"""A non-blocking TCP connection factory.
@gen.coroutine
def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None,
- max_buffer_size=None, source_ip=None, source_port=None):
+ max_buffer_size=None, source_ip=None, source_port=None,
+ timeout=None):
"""Connect to the given host and port.
Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
use a specific interface, it has to be handled outside
of Tornado as this depends very much on the platform.
+ Raises `TimeoutError` if the input future does not complete before
+ ``timeout``, which may be specified in any form allowed by
+ `.IOLoop.add_timeout` (i.e. a `datetime.timedelta` or an absolute time
+ relative to `.IOLoop.time`)
+
Similarly, when the user requires a certain source port, it can
be specified using the ``source_port`` arg.
.. versionchanged:: 4.5
Added the ``source_ip`` and ``source_port`` arguments.
"""
- addrinfo = yield self.resolver.resolve(host, port, af)
+ if timeout is not None:
+ if isinstance(timeout, numbers.Real):
+ timeout = IOLoop.current().time() + timeout
+ elif isinstance(timeout, datetime.timedelta):
+ timeout = IOLoop.current().time() + timedelta_to_seconds(timeout)
+ else:
+ raise TypeError("Unsupported timeout %r" % timeout)
+ if timeout is not None:
+ addrinfo = yield gen.with_timeout(
+ timeout, self.resolver.resolve(host, port, af))
+ else:
+ addrinfo = yield self.resolver.resolve(host, port, af)
connector = _Connector(
addrinfo,
functools.partial(self._create_stream, max_buffer_size,
source_ip=source_ip, source_port=source_port)
)
- af, addr, stream = yield connector.start()
+ af, addr, stream = yield connector.start(connect_timeout=timeout)
# TODO: For better performance we could cache the (af, addr)
# information here and re-use it on subsequent connections to
# the same host. (http://tools.ietf.org/html/rfc6555#section-4.2)
if ssl_options is not None:
- stream = yield stream.start_tls(False, ssl_options=ssl_options,
- server_hostname=host)
+ if timeout is not None:
+ stream = yield gen.with_timeout(timeout, stream.start_tls(
+ False, ssl_options=ssl_options, server_hostname=host))
+ else:
+ stream = yield stream.start_tls(False, ssl_options=ssl_options,
+ server_hostname=host)
raise gen.Return(stream)
def _create_stream(self, max_buffer_size, af, addr, source_ip=None,
fu.set_exception(e)
return fu
else:
- return stream.connect(addr)
+ return stream, stream.connect(addr)
from tornado.tcpclient import TCPClient, _Connector
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, gen_test
-from tornado.test.util import skipIfNoIPv6, unittest, refusing_port, skipIfNonUnix
+from tornado.test.util import skipIfNoIPv6, unittest, refusing_port, skipIfNonUnix, skipOnTravis
+from tornado.gen import TimeoutError
# Fake address families for testing. Used in place of AF_INET
# and AF_INET6 because some installations do not have AF_INET6.
'127.0.0.1',
source_port=1)
+ @gen_test
+ def test_connect_timeout(self):
+ timeout = 0.05
+
+ class TimeoutResolver(Resolver):
+ def resolve(self, *args, **kwargs):
+ return Future() # never completes
+ with self.assertRaises(TimeoutError):
+ yield TCPClient(resolver=TimeoutResolver()).connect(
+ '1.2.3.4', 12345, timeout=timeout)
+
class TestConnectorSplit(unittest.TestCase):
def test_one_family(self):
super(ConnectorTest, self).tearDown()
def create_stream(self, af, addr):
+ stream = ConnectorTest.FakeStream()
+ self.streams[addr] = stream
future = Future()
self.connect_futures[(af, addr)] = future
- return future
+ return stream, future
def assert_pending(self, *keys):
self.assertEqual(sorted(self.connect_futures.keys()), sorted(keys))
def resolve_connect(self, af, addr, success):
future = self.connect_futures.pop((af, addr))
if success:
- self.streams[addr] = ConnectorTest.FakeStream()
future.set_result(self.streams[addr])
else:
+ self.streams.pop(addr)
future.set_exception(IOError())
+ def assert_connector_streams_closed(self, conn):
+ for stream in conn.streams:
+ self.assertTrue(stream.closed)
+
def start_connect(self, addrinfo):
conn = _Connector(addrinfo, self.create_stream)
# Give it a huge timeout; we'll trigger timeouts manually.
- future = conn.start(3600)
+ future = conn.start(3600, connect_timeout=self.io_loop.time() + 3600)
return conn, future
def test_immediate_success(self):
self.assertFalse(future.done())
self.resolve_connect(AF1, 'b', False)
self.assertRaises(IOError, future.result)
+
+ def test_one_family_timeout_after_connect_timeout(self):
+ conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
+ self.assert_pending((AF1, 'a'))
+ conn.on_connect_timeout()
+ # the connector will close all streams on connect timeout, we
+ # should explicitly pop the connect_future.
+ self.connect_futures.pop((AF1, 'a'))
+ self.assertTrue(self.streams.pop('a').closed)
+ conn.on_timeout()
+ # if the future is set with TimeoutError, we will not iterate next
+ # possible address.
+ self.assert_pending()
+ self.assertEqual(len(conn.streams), 1)
+ self.assert_connector_streams_closed(conn)
+ self.assertRaises(TimeoutError, future.result)
+
+ def test_one_family_success_before_connect_timeout(self):
+ conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
+ self.assert_pending((AF1, 'a'))
+ self.resolve_connect(AF1, 'a', True)
+ conn.on_connect_timeout()
+ self.assert_pending()
+ self.assertEqual(self.streams['a'].closed, False)
+ # success stream will be pop
+ self.assertEqual(len(conn.streams), 0)
+ # streams in connector should be closed after connect timeout
+ self.assert_connector_streams_closed(conn)
+ self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
+
+ def test_one_family_second_try_after_connect_timeout(self):
+ conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
+ self.assert_pending((AF1, 'a'))
+ self.resolve_connect(AF1, 'a', False)
+ self.assert_pending((AF1, 'b'))
+ conn.on_connect_timeout()
+ self.connect_futures.pop((AF1, 'b'))
+ self.assertTrue(self.streams.pop('b').closed)
+ self.assert_pending()
+ self.assertEqual(len(conn.streams), 2)
+ self.assert_connector_streams_closed(conn)
+ self.assertRaises(TimeoutError, future.result)
+
+ def test_one_family_second_try_failure_before_connect_timeout(self):
+ conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
+ self.assert_pending((AF1, 'a'))
+ self.resolve_connect(AF1, 'a', False)
+ self.assert_pending((AF1, 'b'))
+ self.resolve_connect(AF1, 'b', False)
+ conn.on_connect_timeout()
+ self.assert_pending()
+ self.assertEqual(len(conn.streams), 2)
+ self.assert_connector_streams_closed(conn)
+ self.assertRaises(IOError, future.result)
+
+ def test_two_family_timeout_before_connect_timeout(self):
+ conn, future = self.start_connect(self.addrinfo)
+ self.assert_pending((AF1, 'a'))
+ conn.on_timeout()
+ self.assert_pending((AF1, 'a'), (AF2, 'c'))
+ conn.on_connect_timeout()
+ self.connect_futures.pop((AF1, 'a'))
+ self.assertTrue(self.streams.pop('a').closed)
+ self.connect_futures.pop((AF2, 'c'))
+ self.assertTrue(self.streams.pop('c').closed)
+ self.assert_pending()
+ self.assertEqual(len(conn.streams), 2)
+ self.assert_connector_streams_closed(conn)
+ self.assertRaises(TimeoutError, future.result)
+
+ def test_two_family_success_after_timeout(self):
+ conn, future = self.start_connect(self.addrinfo)
+ self.assert_pending((AF1, 'a'))
+ conn.on_timeout()
+ self.assert_pending((AF1, 'a'), (AF2, 'c'))
+ self.resolve_connect(AF1, 'a', True)
+ # if one of streams succeed, connector will close all other streams
+ self.connect_futures.pop((AF2, 'c'))
+ self.assertTrue(self.streams.pop('c').closed)
+ self.assert_pending()
+ self.assertEqual(len(conn.streams), 1)
+ self.assert_connector_streams_closed(conn)
+ self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
+
+ def test_two_family_timeout_after_connect_timeout(self):
+ conn, future = self.start_connect(self.addrinfo)
+ self.assert_pending((AF1, 'a'))
+ conn.on_connect_timeout()
+ self.connect_futures.pop((AF1, 'a'))
+ self.assertTrue(self.streams.pop('a').closed)
+ self.assert_pending()
+ conn.on_timeout()
+ # if the future is set with TimeoutError, connector will not
+ # trigger secondary address.
+ self.assert_pending()
+ self.assertEqual(len(conn.streams), 1)
+ self.assert_connector_streams_closed(conn)
+ self.assertRaises(TimeoutError, future.result)