From 3cc3e1332aba7779c273300e2beea198f07d002e Mon Sep 17 00:00:00 2001 From: Lancher Date: Sun, 16 Jul 2017 11:13:00 +0800 Subject: [PATCH] tcpclient: Introduce connect timeout (#2094) Fixes #1219 --- tornado/tcpclient.py | 77 ++++++++++++++++---- tornado/test/tcpclient_test.py | 124 +++++++++++++++++++++++++++++++-- 2 files changed, 185 insertions(+), 16 deletions(-) diff --git a/tornado/tcpclient.py b/tornado/tcpclient.py index f8e0d019c..6d5355b32 100644 --- a/tornado/tcpclient.py +++ b/tornado/tcpclient.py @@ -20,6 +20,9 @@ from __future__ import absolute_import, division, print_function import functools import socket +import time +import numbers +import datetime from tornado.concurrent import Future from tornado.ioloop import IOLoop @@ -27,6 +30,8 @@ from tornado.iostream import IOStream from tornado import gen from tornado.netutil import Resolver from tornado.platform.auto import set_close_exec +from tornado.gen import TimeoutError +from tornado.util import timedelta_to_seconds _INITIAL_CONNECT_TIMEOUT = 0.3 @@ -54,9 +59,11 @@ class _Connector(object): self.future = Future() self.timeout = None + self.connect_timeout = None self.last_error = None self.remaining = len(addrinfo) self.primary_addrs, self.secondary_addrs = self.split(addrinfo) + self.streams = set() @staticmethod def split(addrinfo): @@ -78,9 +85,11 @@ class _Connector(object): secondary.append((af, addr)) return primary, secondary - def start(self, timeout=_INITIAL_CONNECT_TIMEOUT): + def start(self, timeout=_INITIAL_CONNECT_TIMEOUT, connect_timeout=None): self.try_connect(iter(self.primary_addrs)) - self.set_timout(timeout) + self.set_timeout(timeout) + if connect_timeout is not None: + self.set_connect_timeout(connect_timeout) return self.future def try_connect(self, addrs): @@ -94,7 +103,8 @@ class _Connector(object): self.future.set_exception(self.last_error or IOError("connection failed")) return - future = self.connect(af, addr) + stream, future = self.connect(af, addr) + self.streams.add(stream) future.add_done_callback(functools.partial(self.on_connect_done, addrs, af, addr)) @@ -115,25 +125,47 @@ class _Connector(object): self.io_loop.remove_timeout(self.timeout) self.on_timeout() return - self.clear_timeout() + self.clear_timeouts() if self.future.done(): # This is a late arrival; just drop it. stream.close() else: + self.streams.discard(stream) self.future.set_result((af, addr, stream)) + self.close_streams() - def set_timout(self, timeout): + def set_timeout(self, timeout): self.timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout, self.on_timeout) def on_timeout(self): self.timeout = None - self.try_connect(iter(self.secondary_addrs)) + if not self.future.done(): + self.try_connect(iter(self.secondary_addrs)) def clear_timeout(self): if self.timeout is not None: self.io_loop.remove_timeout(self.timeout) + def set_connect_timeout(self, connect_timeout): + self.connect_timeout = self.io_loop.add_timeout( + connect_timeout, self.on_connect_timeout) + + def on_connect_timeout(self): + if not self.future.done(): + self.future.set_exception(TimeoutError()) + self.close_streams() + + def clear_timeouts(self): + if self.timeout is not None: + self.io_loop.remove_timeout(self.timeout) + if self.connect_timeout is not None: + self.io_loop.remove_timeout(self.connect_timeout) + + def close_streams(self): + for stream in self.streams: + stream.close() + class TCPClient(object): """A non-blocking TCP connection factory. @@ -155,7 +187,8 @@ class TCPClient(object): @gen.coroutine def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None, - max_buffer_size=None, source_ip=None, source_port=None): + max_buffer_size=None, source_ip=None, source_port=None, + timeout=None): """Connect to the given host and port. Asynchronously returns an `.IOStream` (or `.SSLIOStream` if @@ -167,25 +200,45 @@ class TCPClient(object): use a specific interface, it has to be handled outside of Tornado as this depends very much on the platform. + Raises `TimeoutError` if the input future does not complete before + ``timeout``, which may be specified in any form allowed by + `.IOLoop.add_timeout` (i.e. a `datetime.timedelta` or an absolute time + relative to `.IOLoop.time`) + Similarly, when the user requires a certain source port, it can be specified using the ``source_port`` arg. .. versionchanged:: 4.5 Added the ``source_ip`` and ``source_port`` arguments. """ - addrinfo = yield self.resolver.resolve(host, port, af) + if timeout is not None: + if isinstance(timeout, numbers.Real): + timeout = IOLoop.current().time() + timeout + elif isinstance(timeout, datetime.timedelta): + timeout = IOLoop.current().time() + timedelta_to_seconds(timeout) + else: + raise TypeError("Unsupported timeout %r" % timeout) + if timeout is not None: + addrinfo = yield gen.with_timeout( + timeout, self.resolver.resolve(host, port, af)) + else: + addrinfo = yield self.resolver.resolve(host, port, af) connector = _Connector( addrinfo, functools.partial(self._create_stream, max_buffer_size, source_ip=source_ip, source_port=source_port) ) - af, addr, stream = yield connector.start() + af, addr, stream = yield connector.start(connect_timeout=timeout) # TODO: For better performance we could cache the (af, addr) # information here and re-use it on subsequent connections to # the same host. (http://tools.ietf.org/html/rfc6555#section-4.2) if ssl_options is not None: - stream = yield stream.start_tls(False, ssl_options=ssl_options, - server_hostname=host) + if timeout is not None: + stream = yield gen.with_timeout(timeout, stream.start_tls( + False, ssl_options=ssl_options, server_hostname=host)) + else: + stream = yield stream.start_tls(False, ssl_options=ssl_options, + server_hostname=host) raise gen.Return(stream) def _create_stream(self, max_buffer_size, af, addr, source_ip=None, @@ -219,4 +272,4 @@ class TCPClient(object): fu.set_exception(e) return fu else: - return stream.connect(addr) + return stream, stream.connect(addr) diff --git a/tornado/test/tcpclient_test.py b/tornado/test/tcpclient_test.py index 117f28de1..3c3abd1f1 100644 --- a/tornado/test/tcpclient_test.py +++ b/tornado/test/tcpclient_test.py @@ -26,7 +26,8 @@ from tornado.queues import Queue from tornado.tcpclient import TCPClient, _Connector from tornado.tcpserver import TCPServer from tornado.testing import AsyncTestCase, gen_test -from tornado.test.util import skipIfNoIPv6, unittest, refusing_port, skipIfNonUnix +from tornado.test.util import skipIfNoIPv6, unittest, refusing_port, skipIfNonUnix, skipOnTravis +from tornado.gen import TimeoutError # Fake address families for testing. Used in place of AF_INET # and AF_INET6 because some installations do not have AF_INET6. @@ -158,6 +159,17 @@ class TCPClientTest(AsyncTestCase): '127.0.0.1', source_port=1) + @gen_test + def test_connect_timeout(self): + timeout = 0.05 + + class TimeoutResolver(Resolver): + def resolve(self, *args, **kwargs): + return Future() # never completes + with self.assertRaises(TimeoutError): + yield TCPClient(resolver=TimeoutResolver()).connect( + '1.2.3.4', 12345, timeout=timeout) + class TestConnectorSplit(unittest.TestCase): def test_one_family(self): @@ -202,9 +214,11 @@ class ConnectorTest(AsyncTestCase): super(ConnectorTest, self).tearDown() def create_stream(self, af, addr): + stream = ConnectorTest.FakeStream() + self.streams[addr] = stream future = Future() self.connect_futures[(af, addr)] = future - return future + return stream, future def assert_pending(self, *keys): self.assertEqual(sorted(self.connect_futures.keys()), sorted(keys)) @@ -212,15 +226,19 @@ class ConnectorTest(AsyncTestCase): def resolve_connect(self, af, addr, success): future = self.connect_futures.pop((af, addr)) if success: - self.streams[addr] = ConnectorTest.FakeStream() future.set_result(self.streams[addr]) else: + self.streams.pop(addr) future.set_exception(IOError()) + def assert_connector_streams_closed(self, conn): + for stream in conn.streams: + self.assertTrue(stream.closed) + def start_connect(self, addrinfo): conn = _Connector(addrinfo, self.create_stream) # Give it a huge timeout; we'll trigger timeouts manually. - future = conn.start(3600) + future = conn.start(3600, connect_timeout=self.io_loop.time() + 3600) return conn, future def test_immediate_success(self): @@ -311,3 +329,101 @@ class ConnectorTest(AsyncTestCase): self.assertFalse(future.done()) self.resolve_connect(AF1, 'b', False) self.assertRaises(IOError, future.result) + + def test_one_family_timeout_after_connect_timeout(self): + conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')]) + self.assert_pending((AF1, 'a')) + conn.on_connect_timeout() + # the connector will close all streams on connect timeout, we + # should explicitly pop the connect_future. + self.connect_futures.pop((AF1, 'a')) + self.assertTrue(self.streams.pop('a').closed) + conn.on_timeout() + # if the future is set with TimeoutError, we will not iterate next + # possible address. + self.assert_pending() + self.assertEqual(len(conn.streams), 1) + self.assert_connector_streams_closed(conn) + self.assertRaises(TimeoutError, future.result) + + def test_one_family_success_before_connect_timeout(self): + conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')]) + self.assert_pending((AF1, 'a')) + self.resolve_connect(AF1, 'a', True) + conn.on_connect_timeout() + self.assert_pending() + self.assertEqual(self.streams['a'].closed, False) + # success stream will be pop + self.assertEqual(len(conn.streams), 0) + # streams in connector should be closed after connect timeout + self.assert_connector_streams_closed(conn) + self.assertEqual(future.result(), (AF1, 'a', self.streams['a'])) + + def test_one_family_second_try_after_connect_timeout(self): + conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')]) + self.assert_pending((AF1, 'a')) + self.resolve_connect(AF1, 'a', False) + self.assert_pending((AF1, 'b')) + conn.on_connect_timeout() + self.connect_futures.pop((AF1, 'b')) + self.assertTrue(self.streams.pop('b').closed) + self.assert_pending() + self.assertEqual(len(conn.streams), 2) + self.assert_connector_streams_closed(conn) + self.assertRaises(TimeoutError, future.result) + + def test_one_family_second_try_failure_before_connect_timeout(self): + conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')]) + self.assert_pending((AF1, 'a')) + self.resolve_connect(AF1, 'a', False) + self.assert_pending((AF1, 'b')) + self.resolve_connect(AF1, 'b', False) + conn.on_connect_timeout() + self.assert_pending() + self.assertEqual(len(conn.streams), 2) + self.assert_connector_streams_closed(conn) + self.assertRaises(IOError, future.result) + + def test_two_family_timeout_before_connect_timeout(self): + conn, future = self.start_connect(self.addrinfo) + self.assert_pending((AF1, 'a')) + conn.on_timeout() + self.assert_pending((AF1, 'a'), (AF2, 'c')) + conn.on_connect_timeout() + self.connect_futures.pop((AF1, 'a')) + self.assertTrue(self.streams.pop('a').closed) + self.connect_futures.pop((AF2, 'c')) + self.assertTrue(self.streams.pop('c').closed) + self.assert_pending() + self.assertEqual(len(conn.streams), 2) + self.assert_connector_streams_closed(conn) + self.assertRaises(TimeoutError, future.result) + + def test_two_family_success_after_timeout(self): + conn, future = self.start_connect(self.addrinfo) + self.assert_pending((AF1, 'a')) + conn.on_timeout() + self.assert_pending((AF1, 'a'), (AF2, 'c')) + self.resolve_connect(AF1, 'a', True) + # if one of streams succeed, connector will close all other streams + self.connect_futures.pop((AF2, 'c')) + self.assertTrue(self.streams.pop('c').closed) + self.assert_pending() + self.assertEqual(len(conn.streams), 1) + self.assert_connector_streams_closed(conn) + self.assertEqual(future.result(), (AF1, 'a', self.streams['a'])) + + def test_two_family_timeout_after_connect_timeout(self): + conn, future = self.start_connect(self.addrinfo) + self.assert_pending((AF1, 'a')) + conn.on_connect_timeout() + self.connect_futures.pop((AF1, 'a')) + self.assertTrue(self.streams.pop('a').closed) + self.assert_pending() + conn.on_timeout() + # if the future is set with TimeoutError, connector will not + # trigger secondary address. + self.assert_pending() + self.assertEqual(len(conn.streams), 1) + self.assert_connector_streams_closed(conn) + self.assertRaises(TimeoutError, future.result) -- 2.47.2