]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
tcpclient: Introduce connect timeout (#2094)
authorLancher <Steve.LiuShiHao@gmail.com>
Sun, 16 Jul 2017 03:13:00 +0000 (11:13 +0800)
committerBen Darnell <ben@bendarnell.com>
Sun, 16 Jul 2017 03:13:00 +0000 (23:13 -0400)
Fixes #1219

tornado/tcpclient.py
tornado/test/tcpclient_test.py

index f8e0d019cadeaf752df1bd3b502301f505f654ad..6d5355b3250c7bb22d8ea6e6044171b274ec3155 100644 (file)
@@ -20,6 +20,9 @@ from __future__ import absolute_import, division, print_function
 
 import functools
 import socket
+import time
+import numbers
+import datetime
 
 from tornado.concurrent import Future
 from tornado.ioloop import IOLoop
@@ -27,6 +30,8 @@ from tornado.iostream import IOStream
 from tornado import gen
 from tornado.netutil import Resolver
 from tornado.platform.auto import set_close_exec
+from tornado.gen import TimeoutError
+from tornado.util import timedelta_to_seconds
 
 _INITIAL_CONNECT_TIMEOUT = 0.3
 
@@ -54,9 +59,11 @@ class _Connector(object):
 
         self.future = Future()
         self.timeout = None
+        self.connect_timeout = None
         self.last_error = None
         self.remaining = len(addrinfo)
         self.primary_addrs, self.secondary_addrs = self.split(addrinfo)
+        self.streams = set()
 
     @staticmethod
     def split(addrinfo):
@@ -78,9 +85,11 @@ class _Connector(object):
                 secondary.append((af, addr))
         return primary, secondary
 
-    def start(self, timeout=_INITIAL_CONNECT_TIMEOUT):
+    def start(self, timeout=_INITIAL_CONNECT_TIMEOUT, connect_timeout=None):
         self.try_connect(iter(self.primary_addrs))
-        self.set_timout(timeout)
+        self.set_timeout(timeout)
+        if connect_timeout is not None:
+            self.set_connect_timeout(connect_timeout)
         return self.future
 
     def try_connect(self, addrs):
@@ -94,7 +103,8 @@ class _Connector(object):
                 self.future.set_exception(self.last_error or
                                           IOError("connection failed"))
             return
-        future = self.connect(af, addr)
+        stream, future = self.connect(af, addr)
+        self.streams.add(stream)
         future.add_done_callback(functools.partial(self.on_connect_done,
                                                    addrs, af, addr))
 
@@ -115,25 +125,47 @@ class _Connector(object):
                 self.io_loop.remove_timeout(self.timeout)
                 self.on_timeout()
             return
-        self.clear_timeout()
+        self.clear_timeouts()
         if self.future.done():
             # This is a late arrival; just drop it.
             stream.close()
         else:
+            self.streams.discard(stream)
             self.future.set_result((af, addr, stream))
+            self.close_streams()
 
-    def set_timout(self, timeout):
+    def set_timeout(self, timeout):
         self.timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout,
                                                 self.on_timeout)
 
     def on_timeout(self):
         self.timeout = None
-        self.try_connect(iter(self.secondary_addrs))
+        if not self.future.done():
+            self.try_connect(iter(self.secondary_addrs))
 
     def clear_timeout(self):
         if self.timeout is not None:
             self.io_loop.remove_timeout(self.timeout)
 
+    def set_connect_timeout(self, connect_timeout):
+        self.connect_timeout = self.io_loop.add_timeout(
+            connect_timeout, self.on_connect_timeout)
+
+    def on_connect_timeout(self):
+        if not self.future.done():
+            self.future.set_exception(TimeoutError())
+        self.close_streams()
+
+    def clear_timeouts(self):
+        if self.timeout is not None:
+            self.io_loop.remove_timeout(self.timeout)
+        if self.connect_timeout is not None:
+            self.io_loop.remove_timeout(self.connect_timeout)
+
+    def close_streams(self):
+        for stream in self.streams:
+            stream.close()
+
 
 class TCPClient(object):
     """A non-blocking TCP connection factory.
@@ -155,7 +187,8 @@ class TCPClient(object):
 
     @gen.coroutine
     def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None,
-                max_buffer_size=None, source_ip=None, source_port=None):
+                max_buffer_size=None, source_ip=None, source_port=None,
+                timeout=None):
         """Connect to the given host and port.
 
         Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
@@ -167,25 +200,45 @@ class TCPClient(object):
         use a specific interface, it has to be handled outside
         of Tornado as this depends very much on the platform.
 
+        Raises `TimeoutError` if the input future does not complete before
+        ``timeout``, which may be specified in any form allowed by
+        `.IOLoop.add_timeout` (i.e. a `datetime.timedelta` or an absolute time
+        relative to `.IOLoop.time`)
+
         Similarly, when the user requires a certain source port, it can
         be specified using the ``source_port`` arg.
 
         .. versionchanged:: 4.5
            Added the ``source_ip`` and ``source_port`` arguments.
         """
-        addrinfo = yield self.resolver.resolve(host, port, af)
+        if timeout is not None:
+            if isinstance(timeout, numbers.Real):
+                timeout = IOLoop.current().time() + timeout
+            elif isinstance(timeout, datetime.timedelta):
+                timeout = IOLoop.current().time() + timedelta_to_seconds(timeout)
+            else:
+                raise TypeError("Unsupported timeout %r" % timeout)
+        if timeout is not None:
+            addrinfo = yield gen.with_timeout(
+                timeout, self.resolver.resolve(host, port, af))
+        else:
+            addrinfo = yield self.resolver.resolve(host, port, af)
         connector = _Connector(
             addrinfo,
             functools.partial(self._create_stream, max_buffer_size,
                               source_ip=source_ip, source_port=source_port)
         )
-        af, addr, stream = yield connector.start()
+        af, addr, stream = yield connector.start(connect_timeout=timeout)
         # TODO: For better performance we could cache the (af, addr)
         # information here and re-use it on subsequent connections to
         # the same host. (http://tools.ietf.org/html/rfc6555#section-4.2)
         if ssl_options is not None:
-            stream = yield stream.start_tls(False, ssl_options=ssl_options,
-                                            server_hostname=host)
+            if timeout is not None:
+                stream = yield gen.with_timeout(timeout, stream.start_tls(
+                    False, ssl_options=ssl_options, server_hostname=host))
+            else:
+                stream = yield stream.start_tls(False, ssl_options=ssl_options,
+                                                server_hostname=host)
         raise gen.Return(stream)
 
     def _create_stream(self, max_buffer_size, af, addr, source_ip=None,
@@ -219,4 +272,4 @@ class TCPClient(object):
             fu.set_exception(e)
             return fu
         else:
-            return stream.connect(addr)
+            return stream, stream.connect(addr)
index 117f28de1319f4d199e77191f31ce4f7a0137ead..3c3abd1f172a058d238c2c37fcfe560de9ed0aad 100644 (file)
@@ -26,7 +26,8 @@ from tornado.queues import Queue
 from tornado.tcpclient import TCPClient, _Connector
 from tornado.tcpserver import TCPServer
 from tornado.testing import AsyncTestCase, gen_test
-from tornado.test.util import skipIfNoIPv6, unittest, refusing_port, skipIfNonUnix
+from tornado.test.util import skipIfNoIPv6, unittest, refusing_port, skipIfNonUnix, skipOnTravis
+from tornado.gen import TimeoutError
 
 # Fake address families for testing.  Used in place of AF_INET
 # and AF_INET6 because some installations do not have AF_INET6.
@@ -158,6 +159,17 @@ class TCPClientTest(AsyncTestCase):
                           '127.0.0.1',
                           source_port=1)
 
+    @gen_test
+    def test_connect_timeout(self):
+        timeout = 0.05
+
+        class TimeoutResolver(Resolver):
+            def resolve(self, *args, **kwargs):
+                return Future()  # never completes
+        with self.assertRaises(TimeoutError):
+            yield TCPClient(resolver=TimeoutResolver()).connect(
+                '1.2.3.4', 12345, timeout=timeout)
+
 
 class TestConnectorSplit(unittest.TestCase):
     def test_one_family(self):
@@ -202,9 +214,11 @@ class ConnectorTest(AsyncTestCase):
         super(ConnectorTest, self).tearDown()
 
     def create_stream(self, af, addr):
+        stream = ConnectorTest.FakeStream()
+        self.streams[addr] = stream
         future = Future()
         self.connect_futures[(af, addr)] = future
-        return future
+        return stream, future
 
     def assert_pending(self, *keys):
         self.assertEqual(sorted(self.connect_futures.keys()), sorted(keys))
@@ -212,15 +226,19 @@ class ConnectorTest(AsyncTestCase):
     def resolve_connect(self, af, addr, success):
         future = self.connect_futures.pop((af, addr))
         if success:
-            self.streams[addr] = ConnectorTest.FakeStream()
             future.set_result(self.streams[addr])
         else:
+            self.streams.pop(addr)
             future.set_exception(IOError())
 
+    def assert_connector_streams_closed(self, conn):
+        for stream in conn.streams:
+            self.assertTrue(stream.closed)
+
     def start_connect(self, addrinfo):
         conn = _Connector(addrinfo, self.create_stream)
         # Give it a huge timeout; we'll trigger timeouts manually.
-        future = conn.start(3600)
+        future = conn.start(3600, connect_timeout=self.io_loop.time() + 3600)
         return conn, future
 
     def test_immediate_success(self):
@@ -311,3 +329,101 @@ class ConnectorTest(AsyncTestCase):
         self.assertFalse(future.done())
         self.resolve_connect(AF1, 'b', False)
         self.assertRaises(IOError, future.result)
+
+    def test_one_family_timeout_after_connect_timeout(self):
+        conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
+        self.assert_pending((AF1, 'a'))
+        conn.on_connect_timeout()
+        # the connector will close all streams on connect timeout, we
+        # should explicitly pop the connect_future.
+        self.connect_futures.pop((AF1, 'a'))
+        self.assertTrue(self.streams.pop('a').closed)
+        conn.on_timeout()
+        # if the future is set with TimeoutError, we will not iterate next
+        # possible address.
+        self.assert_pending()
+        self.assertEqual(len(conn.streams), 1)
+        self.assert_connector_streams_closed(conn)
+        self.assertRaises(TimeoutError, future.result)
+
+    def test_one_family_success_before_connect_timeout(self):
+        conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
+        self.assert_pending((AF1, 'a'))
+        self.resolve_connect(AF1, 'a', True)
+        conn.on_connect_timeout()
+        self.assert_pending()
+        self.assertEqual(self.streams['a'].closed, False)
+        # success stream will be pop
+        self.assertEqual(len(conn.streams), 0)
+        # streams in connector should be closed after connect timeout
+        self.assert_connector_streams_closed(conn)
+        self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
+
+    def test_one_family_second_try_after_connect_timeout(self):
+        conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
+        self.assert_pending((AF1, 'a'))
+        self.resolve_connect(AF1, 'a', False)
+        self.assert_pending((AF1, 'b'))
+        conn.on_connect_timeout()
+        self.connect_futures.pop((AF1, 'b'))
+        self.assertTrue(self.streams.pop('b').closed)
+        self.assert_pending()
+        self.assertEqual(len(conn.streams), 2)
+        self.assert_connector_streams_closed(conn)
+        self.assertRaises(TimeoutError, future.result)
+
+    def test_one_family_second_try_failure_before_connect_timeout(self):
+        conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
+        self.assert_pending((AF1, 'a'))
+        self.resolve_connect(AF1, 'a', False)
+        self.assert_pending((AF1, 'b'))
+        self.resolve_connect(AF1, 'b', False)
+        conn.on_connect_timeout()
+        self.assert_pending()
+        self.assertEqual(len(conn.streams), 2)
+        self.assert_connector_streams_closed(conn)
+        self.assertRaises(IOError, future.result)
+
+    def test_two_family_timeout_before_connect_timeout(self):
+        conn, future = self.start_connect(self.addrinfo)
+        self.assert_pending((AF1, 'a'))
+        conn.on_timeout()
+        self.assert_pending((AF1, 'a'), (AF2, 'c'))
+        conn.on_connect_timeout()
+        self.connect_futures.pop((AF1, 'a'))
+        self.assertTrue(self.streams.pop('a').closed)
+        self.connect_futures.pop((AF2, 'c'))
+        self.assertTrue(self.streams.pop('c').closed)
+        self.assert_pending()
+        self.assertEqual(len(conn.streams), 2)
+        self.assert_connector_streams_closed(conn)
+        self.assertRaises(TimeoutError, future.result)
+
+    def test_two_family_success_after_timeout(self):
+        conn, future = self.start_connect(self.addrinfo)
+        self.assert_pending((AF1, 'a'))
+        conn.on_timeout()
+        self.assert_pending((AF1, 'a'), (AF2, 'c'))
+        self.resolve_connect(AF1, 'a', True)
+        # if one of streams succeed, connector will close all other streams
+        self.connect_futures.pop((AF2, 'c'))
+        self.assertTrue(self.streams.pop('c').closed)
+        self.assert_pending()
+        self.assertEqual(len(conn.streams), 1)
+        self.assert_connector_streams_closed(conn)
+        self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
+
+    def test_two_family_timeout_after_connect_timeout(self):
+        conn, future = self.start_connect(self.addrinfo)
+        self.assert_pending((AF1, 'a'))
+        conn.on_connect_timeout()
+        self.connect_futures.pop((AF1, 'a'))
+        self.assertTrue(self.streams.pop('a').closed)
+        self.assert_pending()
+        conn.on_timeout()
+        # if the future is set with TimeoutError, connector will not
+        # trigger secondary address.
+        self.assert_pending()
+        self.assertEqual(len(conn.streams), 1)
+        self.assert_connector_streams_closed(conn)
+        self.assertRaises(TimeoutError, future.result)