]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
simple_httpclient: Initial refactoring into coroutines
authorBen Darnell <ben@bendarnell.com>
Fri, 27 Apr 2018 15:52:51 +0000 (11:52 -0400)
committerBen Darnell <ben@bendarnell.com>
Fri, 27 Apr 2018 16:57:48 +0000 (12:57 -0400)
Eliminates the use of ExceptionStackContext.

tornado/simple_httpclient.py

index 74cceaaca916022b35f4f489bb80410e75efdf3c..4df4898a9705404aa6782fc5c0288975309912fa 100644 (file)
@@ -230,7 +230,11 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
         # Timeout handle returned by IOLoop.add_timeout
         self._timeout = None
         self._sockaddr = None
-        with stack_context.ExceptionStackContext(self._handle_exception):
+        IOLoop.current().add_callback(self.run)
+
+    @gen.coroutine
+    def run(self):
+        try:
             self.parsed = urlparse.urlsplit(_unicode(self.request.url))
             if self.parsed.scheme not in ("http", "https"):
                 raise ValueError("Unsupported url scheme: %s" %
@@ -248,7 +252,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
                 host = host[1:-1]
             self.parsed_hostname = host  # save final host for _on_connect
 
-            if request.allow_ipv6 is False:
+            if self.request.allow_ipv6 is False:
                 af = socket.AF_INET
             else:
                 af = socket.AF_UNSPEC
@@ -260,92 +264,93 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
                 self._timeout = self.io_loop.add_timeout(
                     self.start_time + timeout,
                     stack_context.wrap(functools.partial(self._on_timeout, "while connecting")))
-            fut = self.tcp_client.connect(host, port, af=af,
-                                          ssl_options=ssl_options,
-                                          max_buffer_size=self.max_buffer_size)
-            fut.add_done_callback(stack_context.wrap(self._on_connect))
-
-    def _on_connect(self, stream_fut):
-        stream = stream_fut.result()
-        if self.final_callback is None:
-            # final_callback is cleared if we've hit our timeout.
-            stream.close()
-            return
-        self.stream = stream
-        self.stream.set_close_callback(self.on_connection_close)
-        self._remove_timeout()
-        if self.final_callback is None:
-            return
-        if self.request.request_timeout:
-            self._timeout = self.io_loop.add_timeout(
-                self.start_time + self.request.request_timeout,
-                stack_context.wrap(functools.partial(self._on_timeout, "during request")))
-        if (self.request.method not in self._SUPPORTED_METHODS and
-                not self.request.allow_nonstandard_methods):
-            raise KeyError("unknown method %s" % self.request.method)
-        for key in ('network_interface',
-                    'proxy_host', 'proxy_port',
-                    'proxy_username', 'proxy_password',
-                    'proxy_auth_mode'):
-            if getattr(self.request, key, None):
-                raise NotImplementedError('%s not supported' % key)
-        if "Connection" not in self.request.headers:
-            self.request.headers["Connection"] = "close"
-        if "Host" not in self.request.headers:
-            if '@' in self.parsed.netloc:
-                self.request.headers["Host"] = self.parsed.netloc.rpartition('@')[-1]
-            else:
-                self.request.headers["Host"] = self.parsed.netloc
-        username, password = None, None
-        if self.parsed.username is not None:
-            username, password = self.parsed.username, self.parsed.password
-        elif self.request.auth_username is not None:
-            username = self.request.auth_username
-            password = self.request.auth_password or ''
-        if username is not None:
-            if self.request.auth_mode not in (None, "basic"):
-                raise ValueError("unsupported auth_mode %s",
-                                 self.request.auth_mode)
-            auth = utf8(username) + b":" + utf8(password)
-            self.request.headers["Authorization"] = (b"Basic " +
-                                                     base64.b64encode(auth))
-        if self.request.user_agent:
-            self.request.headers["User-Agent"] = self.request.user_agent
-        if not self.request.allow_nonstandard_methods:
-            # Some HTTP methods nearly always have bodies while others
-            # almost never do. Fail in this case unless the user has
-            # opted out of sanity checks with allow_nonstandard_methods.
-            body_expected = self.request.method in ("POST", "PATCH", "PUT")
-            body_present = (self.request.body is not None or
-                            self.request.body_producer is not None)
-            if ((body_expected and not body_present) or
-                    (body_present and not body_expected)):
-                raise ValueError(
-                    'Body must %sbe None for method %s (unless '
-                    'allow_nonstandard_methods is true)' %
-                    ('not ' if body_expected else '', self.request.method))
-        if self.request.expect_100_continue:
-            self.request.headers["Expect"] = "100-continue"
-        if self.request.body is not None:
-            # When body_producer is used the caller is responsible for
-            # setting Content-Length (or else chunked encoding will be used).
-            self.request.headers["Content-Length"] = str(len(
-                self.request.body))
-        if (self.request.method == "POST" and
-                "Content-Type" not in self.request.headers):
-            self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
-        if self.request.decompress_response:
-            self.request.headers["Accept-Encoding"] = "gzip"
-        req_path = ((self.parsed.path or '/') +
-                    (('?' + self.parsed.query) if self.parsed.query else ''))
-        self.connection = self._create_connection(stream)
-        start_line = httputil.RequestStartLine(self.request.method,
-                                               req_path, '')
-        self.connection.write_headers(start_line, self.request.headers)
-        if self.request.expect_100_continue:
-            self._read_response()
-        else:
-            self._write_body(True)
+                stream = yield self.tcp_client.connect(
+                    host, port, af=af,
+                    ssl_options=ssl_options,
+                    max_buffer_size=self.max_buffer_size)
+
+                if self.final_callback is None:
+                    # final_callback is cleared if we've hit our timeout.
+                    stream.close()
+                    return
+                self.stream = stream
+                self.stream.set_close_callback(self.on_connection_close)
+                self._remove_timeout()
+                if self.final_callback is None:
+                    return
+                if self.request.request_timeout:
+                    self._timeout = self.io_loop.add_timeout(
+                        self.start_time + self.request.request_timeout,
+                        stack_context.wrap(functools.partial(self._on_timeout, "during request")))
+                if (self.request.method not in self._SUPPORTED_METHODS and
+                        not self.request.allow_nonstandard_methods):
+                    raise KeyError("unknown method %s" % self.request.method)
+                for key in ('network_interface',
+                            'proxy_host', 'proxy_port',
+                            'proxy_username', 'proxy_password',
+                            'proxy_auth_mode'):
+                    if getattr(self.request, key, None):
+                        raise NotImplementedError('%s not supported' % key)
+                if "Connection" not in self.request.headers:
+                    self.request.headers["Connection"] = "close"
+                if "Host" not in self.request.headers:
+                    if '@' in self.parsed.netloc:
+                        self.request.headers["Host"] = self.parsed.netloc.rpartition('@')[-1]
+                    else:
+                        self.request.headers["Host"] = self.parsed.netloc
+                username, password = None, None
+                if self.parsed.username is not None:
+                    username, password = self.parsed.username, self.parsed.password
+                elif self.request.auth_username is not None:
+                    username = self.request.auth_username
+                    password = self.request.auth_password or ''
+                if username is not None:
+                    if self.request.auth_mode not in (None, "basic"):
+                        raise ValueError("unsupported auth_mode %s",
+                                         self.request.auth_mode)
+                    auth = utf8(username) + b":" + utf8(password)
+                    self.request.headers["Authorization"] = (b"Basic " +
+                                                             base64.b64encode(auth))
+                if self.request.user_agent:
+                    self.request.headers["User-Agent"] = self.request.user_agent
+                if not self.request.allow_nonstandard_methods:
+                    # Some HTTP methods nearly always have bodies while others
+                    # almost never do. Fail in this case unless the user has
+                    # opted out of sanity checks with allow_nonstandard_methods.
+                    body_expected = self.request.method in ("POST", "PATCH", "PUT")
+                    body_present = (self.request.body is not None or
+                                    self.request.body_producer is not None)
+                    if ((body_expected and not body_present) or
+                            (body_present and not body_expected)):
+                        raise ValueError(
+                            'Body must %sbe None for method %s (unless '
+                            'allow_nonstandard_methods is true)' %
+                            ('not ' if body_expected else '', self.request.method))
+                if self.request.expect_100_continue:
+                    self.request.headers["Expect"] = "100-continue"
+                if self.request.body is not None:
+                    # When body_producer is used the caller is responsible for
+                    # setting Content-Length (or else chunked encoding will be used).
+                    self.request.headers["Content-Length"] = str(len(
+                        self.request.body))
+                if (self.request.method == "POST" and
+                        "Content-Type" not in self.request.headers):
+                    self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
+                if self.request.decompress_response:
+                    self.request.headers["Accept-Encoding"] = "gzip"
+                req_path = ((self.parsed.path or '/') +
+                            (('?' + self.parsed.query) if self.parsed.query else ''))
+                self.connection = self._create_connection(stream)
+                start_line = httputil.RequestStartLine(self.request.method,
+                                                       req_path, '')
+                self.connection.write_headers(start_line, self.request.headers)
+                if self.request.expect_100_continue:
+                    yield self.connection.read_response(self)
+                else:
+                    yield self._write_body(True)
+        except Exception:
+            if not self._handle_exception(*sys.exc_info()):
+                raise
 
     def _get_ssl_options(self, scheme):
         if scheme == "https":
@@ -383,7 +388,8 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
         self._timeout = None
         error_message = "Timeout {0}".format(info) if info else "Timeout"
         if self.final_callback is not None:
-            raise HTTPTimeoutError(error_message)
+            self._handle_exception(HTTPTimeoutError, HTTPTimeoutError(error_message),
+                                   None)
 
     def _remove_timeout(self):
         if self._timeout is not None:
@@ -402,31 +408,21 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
             self._sockaddr)
         return connection
 
+    @gen.coroutine
     def _write_body(self, start_read):
         if self.request.body is not None:
             self.connection.write(self.request.body)
         elif self.request.body_producer is not None:
             fut = self.request.body_producer(self.connection.write)
             if fut is not None:
-                fut = gen.convert_yielded(fut)
-
-                def on_body_written(fut):
-                    fut.result()
-                    self.connection.finish()
-                    if start_read:
-                        self._read_response()
-                self.io_loop.add_future(fut, on_body_written)
-                return
+                yield fut
         self.connection.finish()
         if start_read:
-            self._read_response()
-
-    def _read_response(self):
-        # Ensure that any exception raised in read_response ends up in our
-        # stack context.
-        self.io_loop.add_future(
-            self.connection.read_response(self),
-            lambda f: f.result())
+            try:
+                yield self.connection.read_response(self)
+            except StreamClosedError:
+                if not self._handle_exception(*sys.exc_info()):
+                    raise
 
     def _release(self):
         if self.release_callback is not None: