]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
simple_httpclient: handle connect_timeout or request_timeout of 0
authorPierce Lopez <pierce.lopez@gmail.com>
Tue, 22 Sep 2020 19:43:43 +0000 (15:43 -0400)
committerPierce Lopez <pierce.lopez@gmail.com>
Fri, 25 Sep 2020 20:16:48 +0000 (16:16 -0400)
Using a connect_timeout or request_timeout of 0 was effectively
invalid for simple_httpclient: it would skip the actual request
entirely (because the bulk of the logic was inside "if timeout:").
This was not checked for or raised as an error, it just behaved
unexpectedly.

Change simple_httpclient to always assert these timeouts are not None
and to support the 0 value similar to curl (where request_timeout=0
means no timeout, and connect_timeout=0 means curl default 300 seconds
which is very very long for a tcp connection).

tornado/simple_httpclient.py

index c977aeefdee6ec58561b7b099f190043fa98b37a..f99f391fdc213e453415612769870593cd54befa 100644 (file)
@@ -167,16 +167,20 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
     ) -> None:
         key = object()
         self.queue.append((key, request, callback))
-        if not len(self.active) < self.max_clients:
-            assert request.connect_timeout is not None
-            assert request.request_timeout is not None
-            timeout_handle = self.io_loop.add_timeout(
-                self.io_loop.time()
-                + min(request.connect_timeout, request.request_timeout),
-                functools.partial(self._on_timeout, key, "in request queue"),
-            )
-        else:
-            timeout_handle = None
+        assert request.connect_timeout is not None
+        assert request.request_timeout is not None
+        timeout_handle = None
+        if len(self.active) >= self.max_clients:
+            timeout = (
+                min(request.connect_timeout, request.request_timeout)
+                or request.connect_timeout
+                or request.request_timeout
+            )  # min but skip zero
+            if timeout:
+                timeout_handle = self.io_loop.add_timeout(
+                    self.io_loop.time() + timeout,
+                    functools.partial(self._on_timeout, key, "in request queue"),
+                )
         self.waiting[key] = (request, callback, timeout_handle)
         self._process_queue()
         if self.queue:
@@ -321,123 +325,123 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
                         % (self.request.network_interface,)
                     )
 
-            timeout = min(self.request.connect_timeout, self.request.request_timeout)
+            timeout = (
+                min(self.request.connect_timeout, self.request.request_timeout)
+                or self.request.connect_timeout
+                or self.request.request_timeout
+            )  # min but skip zero
             if timeout:
                 self._timeout = self.io_loop.add_timeout(
                     self.start_time + timeout,
                     functools.partial(self._on_timeout, "while connecting"),
                 )
-                stream = await self.tcp_client.connect(
-                    host,
-                    port,
-                    af=af,
-                    ssl_options=ssl_options,
-                    max_buffer_size=self.max_buffer_size,
-                    source_ip=source_ip,
-                )
+            stream = await self.tcp_client.connect(
+                host,
+                port,
+                af=af,
+                ssl_options=ssl_options,
+                max_buffer_size=self.max_buffer_size,
+                source_ip=source_ip,
+            )
 
-                if self.final_callback is None:
-                    # final_callback is cleared if we've hit our timeout.
-                    stream.close()
-                    return
-                self.stream = stream
-                self.stream.set_close_callback(self.on_connection_close)
-                self._remove_timeout()
-                if self.final_callback is None:
-                    return
-                if self.request.request_timeout:
-                    self._timeout = self.io_loop.add_timeout(
-                        self.start_time + self.request.request_timeout,
-                        functools.partial(self._on_timeout, "during request"),
-                    )
-                if (
-                    self.request.method not in self._SUPPORTED_METHODS
-                    and not self.request.allow_nonstandard_methods
-                ):
-                    raise KeyError("unknown method %s" % self.request.method)
-                for key in (
-                    "proxy_host",
-                    "proxy_port",
-                    "proxy_username",
-                    "proxy_password",
-                    "proxy_auth_mode",
-                ):
-                    if getattr(self.request, key, None):
-                        raise NotImplementedError("%s not supported" % key)
-                if "Connection" not in self.request.headers:
-                    self.request.headers["Connection"] = "close"
-                if "Host" not in self.request.headers:
-                    if "@" in self.parsed.netloc:
-                        self.request.headers["Host"] = self.parsed.netloc.rpartition(
-                            "@"
-                        )[-1]
-                    else:
-                        self.request.headers["Host"] = self.parsed.netloc
-                username, password = None, None
-                if self.parsed.username is not None:
-                    username, password = self.parsed.username, self.parsed.password
-                elif self.request.auth_username is not None:
-                    username = self.request.auth_username
-                    password = self.request.auth_password or ""
-                if username is not None:
-                    assert password is not None
-                    if self.request.auth_mode not in (None, "basic"):
-                        raise ValueError(
-                            "unsupported auth_mode %s", self.request.auth_mode
-                        )
-                    self.request.headers["Authorization"] = "Basic " + _unicode(
-                        base64.b64encode(
-                            httputil.encode_username_password(username, password)
-                        )
-                    )
-                if self.request.user_agent:
-                    self.request.headers["User-Agent"] = self.request.user_agent
-                elif self.request.headers.get("User-Agent") is None:
-                    self.request.headers["User-Agent"] = "Tornado/{}".format(version)
-                if not self.request.allow_nonstandard_methods:
-                    # Some HTTP methods nearly always have bodies while others
-                    # almost never do. Fail in this case unless the user has
-                    # opted out of sanity checks with allow_nonstandard_methods.
-                    body_expected = self.request.method in ("POST", "PATCH", "PUT")
-                    body_present = (
-                        self.request.body is not None
-                        or self.request.body_producer is not None
+            if self.final_callback is None:
+                # final_callback is cleared if we've hit our timeout.
+                stream.close()
+                return
+            self.stream = stream
+            self.stream.set_close_callback(self.on_connection_close)
+            self._remove_timeout()
+            if self.final_callback is None:
+                return
+            if self.request.request_timeout:
+                self._timeout = self.io_loop.add_timeout(
+                    self.start_time + self.request.request_timeout,
+                    functools.partial(self._on_timeout, "during request"),
+                )
+            if (
+                self.request.method not in self._SUPPORTED_METHODS
+                and not self.request.allow_nonstandard_methods
+            ):
+                raise KeyError("unknown method %s" % self.request.method)
+            for key in (
+                "proxy_host",
+                "proxy_port",
+                "proxy_username",
+                "proxy_password",
+                "proxy_auth_mode",
+            ):
+                if getattr(self.request, key, None):
+                    raise NotImplementedError("%s not supported" % key)
+            if "Connection" not in self.request.headers:
+                self.request.headers["Connection"] = "close"
+            if "Host" not in self.request.headers:
+                if "@" in self.parsed.netloc:
+                    self.request.headers["Host"] = self.parsed.netloc.rpartition("@")[
+                        -1
+                    ]
+                else:
+                    self.request.headers["Host"] = self.parsed.netloc
+            username, password = None, None
+            if self.parsed.username is not None:
+                username, password = self.parsed.username, self.parsed.password
+            elif self.request.auth_username is not None:
+                username = self.request.auth_username
+                password = self.request.auth_password or ""
+            if username is not None:
+                assert password is not None
+                if self.request.auth_mode not in (None, "basic"):
+                    raise ValueError("unsupported auth_mode %s", self.request.auth_mode)
+                self.request.headers["Authorization"] = "Basic " + _unicode(
+                    base64.b64encode(
+                        httputil.encode_username_password(username, password)
                     )
-                    if (body_expected and not body_present) or (
-                        body_present and not body_expected
-                    ):
-                        raise ValueError(
-                            "Body must %sbe None for method %s (unless "
-                            "allow_nonstandard_methods is true)"
-                            % ("not " if body_expected else "", self.request.method)
-                        )
-                if self.request.expect_100_continue:
-                    self.request.headers["Expect"] = "100-continue"
-                if self.request.body is not None:
-                    # When body_producer is used the caller is responsible for
-                    # setting Content-Length (or else chunked encoding will be used).
-                    self.request.headers["Content-Length"] = str(len(self.request.body))
-                if (
-                    self.request.method == "POST"
-                    and "Content-Type" not in self.request.headers
-                ):
-                    self.request.headers[
-                        "Content-Type"
-                    ] = "application/x-www-form-urlencoded"
-                if self.request.decompress_response:
-                    self.request.headers["Accept-Encoding"] = "gzip"
-                req_path = (self.parsed.path or "/") + (
-                    ("?" + self.parsed.query) if self.parsed.query else ""
                 )
-                self.connection = self._create_connection(stream)
-                start_line = httputil.RequestStartLine(
-                    self.request.method, req_path, ""
+            if self.request.user_agent:
+                self.request.headers["User-Agent"] = self.request.user_agent
+            elif self.request.headers.get("User-Agent") is None:
+                self.request.headers["User-Agent"] = "Tornado/{}".format(version)
+            if not self.request.allow_nonstandard_methods:
+                # Some HTTP methods nearly always have bodies while others
+                # almost never do. Fail in this case unless the user has
+                # opted out of sanity checks with allow_nonstandard_methods.
+                body_expected = self.request.method in ("POST", "PATCH", "PUT")
+                body_present = (
+                    self.request.body is not None
+                    or self.request.body_producer is not None
                 )
-                self.connection.write_headers(start_line, self.request.headers)
-                if self.request.expect_100_continue:
-                    await self.connection.read_response(self)
-                else:
-                    await self._write_body(True)
+                if (body_expected and not body_present) or (
+                    body_present and not body_expected
+                ):
+                    raise ValueError(
+                        "Body must %sbe None for method %s (unless "
+                        "allow_nonstandard_methods is true)"
+                        % ("not " if body_expected else "", self.request.method)
+                    )
+            if self.request.expect_100_continue:
+                self.request.headers["Expect"] = "100-continue"
+            if self.request.body is not None:
+                # When body_producer is used the caller is responsible for
+                # setting Content-Length (or else chunked encoding will be used).
+                self.request.headers["Content-Length"] = str(len(self.request.body))
+            if (
+                self.request.method == "POST"
+                and "Content-Type" not in self.request.headers
+            ):
+                self.request.headers[
+                    "Content-Type"
+                ] = "application/x-www-form-urlencoded"
+            if self.request.decompress_response:
+                self.request.headers["Accept-Encoding"] = "gzip"
+            req_path = (self.parsed.path or "/") + (
+                ("?" + self.parsed.query) if self.parsed.query else ""
+            )
+            self.connection = self._create_connection(stream)
+            start_line = httputil.RequestStartLine(self.request.method, req_path, "")
+            self.connection.write_headers(start_line, self.request.headers)
+            if self.request.expect_100_continue:
+                await self.connection.read_response(self)
+            else:
+                await self._write_body(True)
         except Exception:
             if not self._handle_exception(*sys.exc_info()):
                 raise