]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add client-side support for Expect: 100-continue.
authorBen Darnell <ben@bendarnell.com>
Sat, 26 Apr 2014 21:07:41 +0000 (17:07 -0400)
committerBen Darnell <ben@bendarnell.com>
Sat, 26 Apr 2014 21:07:41 +0000 (17:07 -0400)
Don't send the 100-continue response if the header callback already sent
a response.

tornado/http1connection.py
tornado/httpclient.py
tornado/simple_httpclient.py
tornado/test/simple_httpclient_test.py

index 69e8328b71016eaf9277c64cd494b484ad2b5436..af8dd83a01dbbfdd9c0b1cfdba7ab83229682639 100644 (file)
@@ -139,7 +139,8 @@ class HTTP1Connection(object):
                     # in the case of a 100-continue.  Document or change?
                     yield self._read_message(delegate)
             else:
-                if headers.get("Expect") == "100-continue":
+                if (headers.get("Expect") == "100-continue" and
+                    not self._write_finished):
                     self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
             if not skip_body:
                 body_future = self._read_body(headers, delegate)
index ac3d92b3e91f29fb43cf208e947894271711f27a..5a57b3f57193989ce315de667ffc17e9a2020558 100644 (file)
@@ -259,7 +259,8 @@ class HTTPRequest(object):
                  proxy_password=None, allow_nonstandard_methods=None,
                  validate_cert=None, ca_certs=None,
                  allow_ipv6=None,
-                 client_key=None, client_cert=None, body_producer=None):
+                 client_key=None, client_cert=None, body_producer=None,
+                 expect_100_continue=False):
         r"""All parameters except ``url`` are optional.
 
         :arg string url: URL to fetch
@@ -328,6 +329,11 @@ class HTTPRequest(object):
            note below when used with ``curl_httpclient``.
         :arg string client_cert: Filename for client SSL certificate, if any.
            See note below when used with ``curl_httpclient``.
+        :arg bool expect_100_continue: If true, send the
+           ``Expect: 100-continue`` header and wait for a continue response
+           before sending the request body.  Only supported with
+           simple_httpclient.
+
 
         .. note::
 
@@ -377,6 +383,7 @@ class HTTPRequest(object):
         self.allow_ipv6 = allow_ipv6
         self.client_key = client_key
         self.client_cert = client_cert
+        self.expect_100_continue = expect_100_continue
         self.start_time = time.time()
 
     @property
index 417498eb7361c9ba9f82399dd7fe2b1a6d96163f..b22da62bfa1d03edc2f74cc6a0834d162e7fee95 100644 (file)
@@ -319,6 +319,8 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
                     raise AssertionError(
                         'Body must be empty for "%s" request'
                         % self.request.method)
+        if self.request.expect_100_continue:
+            self.request.headers["Expect"] = "100-continue"
         if self.request.body is not None:
             # When body_producer is used the caller is responsible for
             # setting Content-Length (or else chunked encoding will be used).
@@ -345,6 +347,12 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
             start_line, self.request.headers,
             has_body=(self.request.body is not None or
                       self.request.body_producer is not None))
+        if self.request.expect_100_continue:
+            self._read_response()
+        else:
+            self._write_body(True)
+
+    def _write_body(self, start_read):
         if self.request.body is not None:
             self.connection.write(self.request.body)
             self.connection.finish()
@@ -354,11 +362,14 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
                 def on_body_written(fut):
                     fut.result()
                     self.connection.finish()
-                    self._read_response()
+                    if start_read:
+                        self._read_response()
                 self.io_loop.add_future(fut, on_body_written)
                 return
             self.connection.finish()
-        self._read_response()
+        if start_read:
+            self._read_response()
+
 
     def _read_response(self):
         # Ensure that any exception raised in read_response ends up in our
@@ -410,6 +421,9 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
             raise HTTPError(599, message)
 
     def headers_received(self, first_line, headers):
+        if self.request.expect_100_continue and first_line.code == 100:
+            self._write_body(False)
+            return
         self.headers = headers
         self.code = first_line.code
         self.reason = first_line.reason
index 5ca7955998fc0ee70f1de5b447c8f396f8285a01..4d852bb289dbe3552a2def2b617230b1e58d48e7 100644 (file)
@@ -21,7 +21,7 @@ from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWo
 from tornado.test import httpclient_test
 from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog
 from tornado.test.util import unittest, skipOnTravis
-from tornado.web import RequestHandler, Application, asynchronous, url
+from tornado.web import RequestHandler, Application, asynchronous, url, stream_request_body
 
 
 class SimpleHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
@@ -113,6 +113,13 @@ class EchoPostHandler(RequestHandler):
         self.write(self.request.body)
 
 
+@stream_request_body
+class RespondInPrepareHandler(RequestHandler):
+    def prepare(self):
+        self.set_status(403)
+        self.finish("forbidden")
+
+
 class SimpleHTTPClientTestMixin(object):
     def get_app(self):
         # callable objects to finish pending /trigger requests
@@ -133,6 +140,7 @@ class SimpleHTTPClientTestMixin(object):
             url("/host_echo", HostEchoHandler),
             url("/no_content_length", NoContentLengthHandler),
             url("/echo_post", EchoPostHandler),
+            url("/respond_in_prepare", RespondInPrepareHandler),
         ], gzip=True)
 
     def test_singleton(self):
@@ -376,6 +384,20 @@ class SimpleHTTPClientTestMixin(object):
         response.rethrow()
         self.assertEqual(response.body, b"12345678")
 
+    def test_100_continue(self):
+        response = self.fetch("/echo_post", method="POST",
+                              body=b"1234",
+                              expect_100_continue=True)
+        self.assertEqual(response.body, b"1234")
+
+    def test_100_continue_early_response(self):
+        def body_producer(write):
+            raise Exception("should not be called")
+        response = self.fetch("/respond_in_prepare", method="POST",
+                              body_producer=body_producer,
+                              expect_100_continue=True)
+        self.assertEqual(response.code, 403)
+
 
 class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase):
     def setUp(self):