]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add RequestHandler.add_header to allow headers to occur multiple times.
authorBen Darnell <ben@bendarnell.com>
Sat, 6 Aug 2011 20:06:54 +0000 (13:06 -0700)
committerBen Darnell <ben@bendarnell.com>
Sat, 6 Aug 2011 20:06:54 +0000 (13:06 -0700)
tornado/test/web_test.py
tornado/web.py

index 137516fb99023461f1167ad5f9d98a13df903e2f..61e938343e16db14f8924351c4d0c2ac83e29b87 100644 (file)
@@ -313,6 +313,13 @@ class FlowControlHandler(RequestHandler):
         self.write("3")
         self.finish()
 
+class MultiHeaderHandler(RequestHandler):
+    def get(self):
+        self.set_header("x-overwrite", "1")
+        self.set_header("x-overwrite", 2)
+        self.add_header("x-multi", 3)
+        self.add_header("x-multi", "4")
+
 class WebTest(AsyncHTTPTestCase, LogTrapTestCase):
     def get_app(self):
         loader = DictLoader({
@@ -335,6 +342,7 @@ class WebTest(AsyncHTTPTestCase, LogTrapTestCase):
             url("/uimodule_resources", UIModuleResourceHandler),
             url("/optional_path/(.+)?", OptionalPathHandler),
             url("/flow_control", FlowControlHandler),
+            url("/multi_header", MultiHeaderHandler),
             ]
         return Application(urls,
                            template_loader=loader,
@@ -415,6 +423,11 @@ js_embed()
     def test_flow_control(self):
         self.assertEqual(self.fetch("/flow_control").body, b("123"))
 
+    def test_multi_header(self):
+        response = self.fetch("/multi_header")
+        self.assertEqual(response.headers["x-overwrite"], "2")
+        self.assertEqual(response.headers.get_list("x-multi"), ["3", "4"])
+
 
 class ErrorResponseTest(AsyncHTTPTestCase, LogTrapTestCase):
     def get_app(self):
index 41f34e97519d151230a5bd897a0ef7a73d7127ed..4ea6dc79128666d2eb4b764a284eb4032b86a5a9 100644 (file)
@@ -62,6 +62,7 @@ import gzip
 import hashlib
 import hmac
 import httplib
+import itertools
 import logging
 import mimetypes
 import os.path
@@ -188,10 +189,16 @@ class RequestHandler(object):
 
     def clear(self):
         """Resets all headers and content for this response."""
+        # The performance cost of tornado.httputil.HTTPHeaders is significant
+        # (slowing down a benchmark with a trivial handler by more than 10%),
+        # and its case-normalization is not generally necessary for 
+        # headers we generate on the server side, so use a plain dict
+        # and list instead.
         self._headers = {
             "Server": "TornadoServer/%s" % tornado.version,
             "Content-Type": "text/html; charset=UTF-8",
         }
+        self._list_headers = []
         self.set_default_headers()
         if not self.request.supports_http_1_1():
             if self.request.headers.get("Connection") == "Keep-Alive":
@@ -225,6 +232,17 @@ class RequestHandler(object):
         HTTP specification. If the value is not a string, we convert it to
         a string. All header values are then encoded as UTF-8.
         """
+        self._headers[name] = self._convert_header_value(value)
+
+    def add_header(self, name, value):
+        """Adds the given response header and value.
+
+        Unlike `set_header`, `add_header` may be called multiple times
+        to return multiple values for the same header.
+        """
+        self._list_headers.append((name, self._convert_header_value(value)))
+
+    def _convert_header_value(self, value):
         if isinstance(value, (unicode, bytes_type)):
             value = utf8(value)
             # If \n is allowed into the header, it is possible to inject
@@ -240,7 +258,8 @@ class RequestHandler(object):
             value = str(value)
         else:
             raise TypeError("Unsupported header value %r" % value)
-        self._headers[name] = value
+        return value
+
 
     _ARG_DEFAULT = []
     def get_argument(self, name, default=_ARG_DEFAULT, strip=True):
@@ -992,7 +1011,8 @@ class RequestHandler(object):
         lines = [utf8(self.request.version + " " +
                       str(self._status_code) +
                       " " + httplib.responses[self._status_code])]
-        lines.extend([(utf8(n) + b(": ") + utf8(v)) for n, v in self._headers.iteritems()])
+        lines.extend([(utf8(n) + b(": ") + utf8(v)) for n, v in 
+                      itertools.chain(self._headers.iteritems(), self._list_headers)])
         for cookie_dict in getattr(self, "_new_cookies", []):
             for cookie in cookie_dict.values():
                 lines.append(utf8("Set-Cookie: " + cookie.OutputString(None)))