from tornado.log import app_log, gen_log
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.template import DictLoader
-from tornado.testing import AsyncHTTPTestCase, ExpectLog
+from tornado.testing import AsyncHTTPTestCase, ExpectLog, gen_test
from tornado.test.util import unittest
from tornado.util import u, bytes_type, ObjectDict, unicode_type
- from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature, create_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body
-from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError
++from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body
import binascii
+import contextlib
import datetime
import email.utils
import logging
self.assertEqual(resp.body, b'hello')
+class StreamingRequestBodyTest(WebTestCase):
+ def get_handlers(self):
+ @stream_request_body
+ class StreamingBodyHandler(RequestHandler):
+ def initialize(self, test):
+ self.test = test
+
+ def prepare(self):
+ self.test.prepared.set_result(None)
+
+ def data_received(self, data):
+ self.test.data.set_result(data)
+
+ def get(self):
+ self.test.finished.set_result(None)
+ self.write({})
+
+ @stream_request_body
+ class EarlyReturnHandler(RequestHandler):
+ def prepare(self):
+ # If we finish the response in prepare, it won't continue to
+ # the (non-existent) data_received.
+ raise HTTPError(401)
+
+ @stream_request_body
+ class CloseDetectionHandler(RequestHandler):
+ def initialize(self, test):
+ self.test = test
+
+ def on_connection_close(self):
+ super(CloseDetectionHandler, self).on_connection_close()
+ self.test.close_future.set_result(None)
+
+ return [('/stream_body', StreamingBodyHandler, dict(test=self)),
+ ('/early_return', EarlyReturnHandler),
+ ('/close_detection', CloseDetectionHandler, dict(test=self))]
+
+ def connect(self, url, connection_close):
+ # Use a raw connection so we can control the sending of data.
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+ s.connect(("localhost", self.get_http_port()))
+ stream = IOStream(s, io_loop=self.io_loop)
+ stream.write(b"GET " + url + b" HTTP/1.1\r\n")
+ if connection_close:
+ stream.write(b"Connection: close\r\n")
+ stream.write(b"Transfer-Encoding: chunked\r\n\r\n")
+ return stream
+
+ @gen_test
+ def test_streaming_body(self):
+ self.prepared = Future()
+ self.data = Future()
+ self.finished = Future()
+
+ stream = self.connect(b"/stream_body", connection_close=True)
+ yield self.prepared
+ stream.write(b"4\r\nasdf\r\n")
+ # Ensure the first chunk is received before we send the second.
+ data = yield self.data
+ self.assertEqual(data, b"asdf")
+ self.data = Future()
+ stream.write(b"4\r\nqwer\r\n")
+ data = yield self.data
+ self.assertEquals(data, b"qwer")
+ stream.write(b"0\r\n")
+ yield self.finished
+ data = yield gen.Task(stream.read_until_close)
+ # This would ideally use an HTTP1Connection to read the response.
+ self.assertTrue(data.endswith(b"{}"))
+ stream.close()
+
+ @gen_test
+ def test_early_return(self):
+ stream = self.connect(b"/early_return", connection_close=False)
+ data = yield gen.Task(stream.read_until_close)
+ self.assertTrue(data.startswith(b"HTTP/1.1 401"))
+
+ @gen_test
+ def test_early_return_with_data(self):
+ stream = self.connect(b"/early_return", connection_close=False)
+ stream.write(b"4\r\nasdf\r\n")
+ data = yield gen.Task(stream.read_until_close)
+ self.assertTrue(data.startswith(b"HTTP/1.1 401"))
+
+ @gen_test
+ def test_close_during_upload(self):
+ self.close_future = Future()
+ stream = self.connect(b"/close_detection", connection_close=False)
+ stream.close()
+ yield self.close_future
+
+
+class StreamingRequestFlowControlTest(WebTestCase):
+ def get_handlers(self):
+ from tornado.ioloop import IOLoop
+
+ # Each method in this handler returns a Future and yields to the
+ # IOLoop so the future is not immediately ready. Ensure that the
+ # Futures are respected and no method is called before the previous
+ # one has completed.
+ @stream_request_body
+ class FlowControlHandler(RequestHandler):
+ def initialize(self, test):
+ self.test = test
+ self.method = None
+ self.methods = []
+
+ @contextlib.contextmanager
+ def in_method(self, method):
+ if self.method is not None:
+ self.test.fail("entered method %s while in %s" %
+ (method, self.method))
+ self.method = method
+ self.methods.append(method)
+ try:
+ yield
+ finally:
+ self.method = None
+
+ @gen.coroutine
+ def prepare(self):
+ with self.in_method('prepare'):
+ yield gen.Task(IOLoop.current().add_callback)
+
+ @gen.coroutine
+ def data_received(self, data):
+ with self.in_method('data_received'):
+ yield gen.Task(IOLoop.current().add_callback)
+
+ @gen.coroutine
+ def post(self):
+ with self.in_method('post'):
+ yield gen.Task(IOLoop.current().add_callback)
+ self.write(dict(methods=self.methods))
+
+ return [('/', FlowControlHandler, dict(test=self))]
+
+ def get_httpserver_options(self):
+ # Use a small chunk size so flow control is relevant even though
+ # all the data arrives at once.
+ return dict(chunk_size=10)
+
+ def test_flow_control(self):
+ response = self.fetch('/', body='abcdefghijklmnopqrstuvwxyz',
+ method='POST')
+ response.rethrow()
+ self.assertEqual(json_decode(response.body),
+ dict(methods=['prepare', 'data_received',
+ 'data_received', 'data_received',
+ 'post']))
+
+
+@wsgi_safe
+class IncorrectContentLengthTest(SimpleHandlerTestCase):
+ def get_handlers(self):
+ test = self
+ self.server_error = None
+
+ # Manually set a content-length that doesn't match the actual content.
+ class TooHigh(RequestHandler):
+ def get(self):
+ self.set_header("Content-Length", "42")
+ try:
+ self.finish("ok")
+ except Exception as e:
+ test.server_error = e
+ raise
+
+ class TooLow(RequestHandler):
+ def get(self):
+ self.set_header("Content-Length", "2")
+ try:
+ self.finish("hello")
+ except Exception as e:
+ test.server_error = e
+ raise
+
+ return [('/high', TooHigh),
+ ('/low', TooLow)]
+
+ def test_content_length_too_high(self):
+ # When the content-length is too high, the connection is simply
+ # closed without completing the response. An error is logged on
+ # the server.
+ with ExpectLog(app_log, "Uncaught exception"):
+ with ExpectLog(gen_log,
+ "Cannot send error response after headers written"):
+ response = self.fetch("/high")
+ self.assertEqual(response.code, 599)
+ self.assertEqual(str(self.server_error),
+ "Tried to write 40 bytes less than Content-Length")
+
+ def test_content_length_too_low(self):
+ # When the content-length is too low, the connection is closed
+ # without writing the last chunk, so the client never sees the request
+ # complete (which would be a framing error).
+ with ExpectLog(app_log, "Uncaught exception"):
+ with ExpectLog(gen_log,
+ "Cannot send error response after headers written"):
+ response = self.fetch("/low")
+ self.assertEqual(response.code, 599)
+ self.assertEqual(str(self.server_error),
+ "Tried to write more data than Content-Length")
+
+
+class ClientCloseTest(SimpleHandlerTestCase):
+ class Handler(RequestHandler):
+ def get(self):
+ # Simulate a connection closed by the client during
+ # request processing. The client will see an error, but the
+ # server should respond gracefully (without logging errors
+ # because we were unable to write out as many bytes as
+ # Content-Length said we would)
+ self.request.connection.stream.close()
+ self.write('hello')
+
+ def test_client_close(self):
+ response = self.fetch('/')
+ self.assertEqual(response.code, 599)
++
++
+ class SignedValueTest(unittest.TestCase):
+ SECRET = "It's a secret to everybody"
+
+ def past(self):
+ return self.present() - 86400 * 32
+
+ def present(self):
+ return 1300000000
+
+ def test_known_values(self):
+ signed_v1 = create_signed_value(SignedValueTest.SECRET, "key", "value",
+ version=1, clock=self.present)
+ self.assertEqual(
+ signed_v1,
+ b"dmFsdWU=|1300000000|31c934969f53e48164c50768b40cbd7e2daaaa4f")
+
+ signed_v2 = create_signed_value(SignedValueTest.SECRET, "key", "value",
+ version=2, clock=self.present)
+ self.assertEqual(
+ signed_v2,
+ b"2|1:0|10:1300000000|3:key|8:dmFsdWU=|"
+ b"3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e152")
+
+ signed_default = create_signed_value(SignedValueTest.SECRET,
+ "key", "value", clock=self.present)
+ self.assertEqual(signed_default, signed_v2)
+
+ decoded_v1 = decode_signed_value(SignedValueTest.SECRET, "key",
+ signed_v1, min_version=1,
+ clock=self.present)
+ self.assertEqual(decoded_v1, b"value")
+
+ decoded_v2 = decode_signed_value(SignedValueTest.SECRET, "key",
+ signed_v2, min_version=2,
+ clock=self.present)
+ self.assertEqual(decoded_v2, b"value")
+
+ def test_name_swap(self):
+ signed1 = create_signed_value(SignedValueTest.SECRET, "key1", "value",
+ clock=self.present)
+ signed2 = create_signed_value(SignedValueTest.SECRET, "key2", "value",
+ clock=self.present)
+ # Try decoding each string with the other's "name"
+ decoded1 = decode_signed_value(SignedValueTest.SECRET, "key2", signed1,
+ clock=self.present)
+ self.assertIs(decoded1, None)
+ decoded2 = decode_signed_value(SignedValueTest.SECRET, "key1", signed2,
+ clock=self.present)
+ self.assertIs(decoded2, None)
+
+ def test_expired(self):
+ signed = create_signed_value(SignedValueTest.SECRET, "key1", "value",
+ clock=self.past)
+ decoded_past = decode_signed_value(SignedValueTest.SECRET, "key1",
+ signed, clock=self.past)
+ self.assertEqual(decoded_past, b"value")
+ decoded_present = decode_signed_value(SignedValueTest.SECRET, "key1",
+ signed, clock=self.present)
+ self.assertIs(decoded_present, None)
+
+ def test_payload_tampering(self):
+ # These cookies are variants of the one in test_known_values.
+ sig = "3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e152"
+ def validate(prefix):
+ return (b'value' ==
+ decode_signed_value(SignedValueTest.SECRET, "key",
+ prefix + sig, clock=self.present))
+ self.assertTrue(validate("2|1:0|10:1300000000|3:key|8:dmFsdWU=|"))
+ # Change key version
+ self.assertFalse(validate("2|1:1|10:1300000000|3:key|8:dmFsdWU=|"))
+ # length mismatch (field too short)
+ self.assertFalse(validate("2|1:0|10:130000000|3:key|8:dmFsdWU=|"))
+ # length mismatch (field too long)
+ self.assertFalse(validate("2|1:0|10:1300000000|3:keey|8:dmFsdWU=|"))
+
+ def test_signature_tampering(self):
+ prefix = "2|1:0|10:1300000000|3:key|8:dmFsdWU=|"
+ def validate(sig):
+ return (b'value' ==
+ decode_signed_value(SignedValueTest.SECRET, "key",
+ prefix + sig, clock=self.present))
+ self.assertTrue(validate(
+ "3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e152"))
+ # All zeros
+ self.assertFalse(validate("0" * 32))
+ # Change one character
+ self.assertFalse(validate(
+ "4d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e152"))
+ # Change another character
+ self.assertFalse(validate(
+ "3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e153"))
+ # Truncate
+ self.assertFalse(validate(
+ "3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e15"))
+ # Lengthen
+ self.assertFalse(validate(
+ "3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e1538"))
+
+ def test_non_ascii(self):
+ value = b"\xe9"
+ signed = create_signed_value(SignedValueTest.SECRET, "key", value,
+ clock=self.present)
+ decoded = decode_signed_value(SignedValueTest.SECRET, "key", signed,
+ clock=self.present)
+ self.assertEqual(value, decoded)