]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Consolidate the various HTTP header dictionary classes into one,
authorBen Darnell <bdarnell@beaker.local>
Fri, 9 Jul 2010 00:06:03 +0000 (17:06 -0700)
committerBen Darnell <bdarnell@beaker.local>
Fri, 9 Jul 2010 00:14:20 +0000 (17:14 -0700)
which includes better handling of headers with repeated values
(e.g. Set-Cookie)

tornado/httpclient.py
tornado/httpserver.py
tornado/httputil.py [new file with mode: 0755]
tornado/wsgi.py

index 4d29da66f5c7abbb2ffde27fb28575edb1c2448e..842950db0695ae4e8e1c1f43035c31373807d875 100644 (file)
@@ -24,6 +24,7 @@ import errno
 import escape
 import functools
 import httplib
+import httputil
 import ioloop
 import logging
 import pycurl
@@ -60,7 +61,7 @@ class HTTPClient(object):
         if not isinstance(request, HTTPRequest):
            request = HTTPRequest(url=request, **kwargs)
         buffer = cStringIO.StringIO()
-        headers = {}
+        headers = httputil.HTTPHeaders()
         try:
             _curl_setup_request(self._curl, request, buffer, headers)
             self._curl.perform()
@@ -254,7 +255,7 @@ class AsyncHTTPClient(object):
                 curl = self._free_list.pop()
                 (request, callback) = self._requests.popleft()
                 curl.info = {
-                    "headers": {},
+                    "headers": httputil.HTTPHeaders(),
                     "buffer": cStringIO.StringIO(),
                     "request": request,
                     "callback": callback,
@@ -462,7 +463,7 @@ class AsyncHTTPClient2(object):
                 curl = self._free_list.pop()
                 (request, callback) = self._requests.popleft()
                 curl.info = {
-                    "headers": {},
+                    "headers": httputil.HTTPHeaders(),
                     "buffer": cStringIO.StringIO(),
                     "request": request,
                     "callback": callback,
@@ -505,7 +506,7 @@ class AsyncHTTPClient2(object):
 
 
 class HTTPRequest(object):
-    def __init__(self, url, method="GET", headers={}, body=None,
+    def __init__(self, url, method="GET", headers=None, body=None,
                  auth_username=None, auth_password=None,
                  connect_timeout=20.0, request_timeout=20.0,
                  if_modified_since=None, follow_redirects=True,
@@ -513,6 +514,8 @@ class HTTPRequest(object):
                  network_interface=None, streaming_callback=None,
                  header_callback=None, prepare_curl_callback=None,
                  allow_nonstandard_methods=False):
+        if headers is None:
+            headers = httputil.HTTPHeaders()
         if if_modified_since:
             timestamp = calendar.timegm(if_modified_since.utctimetuple())
             headers["If-Modified-Since"] = email.utils.formatdate(
@@ -618,8 +621,13 @@ def _curl_create(max_simultaneous_connections=None):
 
 def _curl_setup_request(curl, request, buffer, headers):
     curl.setopt(pycurl.URL, request.url)
-    curl.setopt(pycurl.HTTPHEADER,
-                [_utf8("%s: %s" % i) for i in request.headers.iteritems()])
+    # Request headers may be either a regular dict or HTTPHeaders object
+    if isinstance(request.headers, httputil.HTTPHeaders):
+      curl.setopt(pycurl.HTTPHEADER,
+                  [_utf8("%s: %s" % i) for i in request.headers.get_all()])
+    else:
+        curl.setopt(pycurl.HTTPHEADER,
+                    [_utf8("%s: %s" % i) for i in request.headers.iteritems()])
     if request.header_callback:
         curl.setopt(pycurl.HEADERFUNCTION, request.header_callback)
     else:
@@ -695,17 +703,7 @@ def _curl_header_callback(headers, header_line):
         return
     if header_line == "\r\n":
         return
-    parts = header_line.split(":", 1)
-    if len(parts) != 2:
-        logging.warning("Invalid HTTP response header line %r", header_line)
-        return
-    name = parts[0].strip()
-    value = parts[1].strip()
-    if name in headers:
-        headers[name] = headers[name] + ',' + value
-    else:
-        headers[name] = value
-
+    headers.parse_line(header_line)
 
 def _curl_debug(debug_type, debug_msg):
     debug_types = ('I', '<', '>', '<', '>')
index 63131070fdf4d9b46fa102631159efcc0d169f4c..ad7ab077aec6124396921d0a0298f0cd6928bf9f 100644 (file)
@@ -19,6 +19,7 @@
 import cgi
 import errno
 import functools
+import httputil
 import ioloop
 import iostream
 import logging
@@ -277,7 +278,7 @@ class HTTPConnection(object):
         method, uri, version = start_line.split(" ")
         if not version.startswith("HTTP/"):
             raise Exception("Malformed HTTP version in HTTP Request-Line")
-        headers = HTTPHeaders.parse(data[eol:])
+        headers = httputil.HTTPHeaders.parse(data[eol:])
         self._request = HTTPRequest(
             connection=self, method=method, uri=uri, version=version,
             headers=headers, remote_ip=self.address[0])
@@ -332,7 +333,7 @@ class HTTPConnection(object):
             if eoh == -1:
                 logging.warning("multipart/form-data missing headers")
                 continue
-            headers = HTTPHeaders.parse(part[:eoh])
+            headers = httputil.HTTPHeaders.parse(part[:eoh])
             name_header = headers.get("Content-Disposition", "")
             if not name_header.startswith("form-data;") or \
                not part.endswith("\r\n"):
@@ -380,7 +381,7 @@ class HTTPRequest(object):
         self.method = method
         self.uri = uri
         self.version = version
-        self.headers = headers or HTTPHeaders()
+        self.headers = headers or httputil.HTTPHeaders()
         self.body = body or ""
         if connection and connection.xheaders:
             # Squid uses X-Forwarded-For, others use X-Real-Ip
@@ -437,23 +438,3 @@ class HTTPRequest(object):
         return "%s(%s, headers=%s)" % (
             self.__class__.__name__, args, dict(self.headers))
 
-
-class HTTPHeaders(dict):
-    """A dictionary that maintains Http-Header-Case for all keys."""
-    def __setitem__(self, name, value):
-        dict.__setitem__(self, self._normalize_name(name), value)
-
-    def __getitem__(self, name):
-        return dict.__getitem__(self, self._normalize_name(name))
-
-    def _normalize_name(self, name):
-        return "-".join([w.capitalize() for w in name.split("-")])
-
-    @classmethod
-    def parse(cls, headers_string):
-        headers = cls()
-        for line in headers_string.splitlines():
-            if line:
-                name, value = line.split(":", 1)
-                headers[name] = value.strip()
-        return headers
diff --git a/tornado/httputil.py b/tornado/httputil.py
new file mode 100755 (executable)
index 0000000..5e563e8
--- /dev/null
@@ -0,0 +1,140 @@
+#!/usr/bin/env python
+#
+# Copyright 2009 Facebook
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""HTTP utility code shared by clients and servers."""
+
+class HTTPHeaders(dict):
+    """A dictionary that maintains Http-Header-Case for all keys.
+
+    Supports multiple values per key via a pair of new methods,
+    add() and get_list().  The regular dictionary interface returns a single
+    value per key, with multiple values joined by a comma.
+
+    >>> h = HTTPHeaders({"content-type": "text/html"})
+    >>> h.keys()
+    ['Content-Type']
+    >>> h["Content-Type"]
+    'text/html'
+
+    >>> h.add("Set-Cookie", "A=B")
+    >>> h.add("Set-Cookie", "C=D")
+    >>> h["set-cookie"]
+    'A=B,C=D'
+    >>> h.get_list("set-cookie")
+    ['A=B', 'C=D']
+
+    >>> for (k,v) in sorted(h.get_all()):
+    ...    print '%s: %s' % (k,v)
+    ...
+    Content-Type: text/html
+    Set-Cookie: A=B
+    Set-Cookie: C=D
+    """
+    def __init__(self, *args, **kwargs):
+        # Don't pass args or kwargs to dict.__init__, as it will bypass
+        # our __setitem__
+        dict.__init__(self)
+        self._as_list = {}
+        self.update(*args, **kwargs)
+
+    # new public methods
+
+    def add(self, name, value):
+        """Adds a new value for the given key."""
+        norm_name = HTTPHeaders._normalize_name(name)
+        if norm_name in self:
+            # bypass our override of __setitem__ since it modifies _as_list
+            dict.__setitem__(self, norm_name, self[norm_name] + ',' + value)
+            self._as_list[norm_name].append(value)
+        else:
+            self[norm_name] = value
+
+    def get_list(self, name):
+        """Returns all values for the given header as a list."""
+        norm_name = HTTPHeaders._normalize_name(name)
+        return self._as_list.get(norm_name, [])
+
+    def get_all(self):
+        """Returns an iterable of all (name, value) pairs.
+
+        If a header has multiple values, multiple pairs will be
+        returned with the same name.
+        """
+        for name, list in self._as_list.iteritems():
+            for value in list:
+                yield (name, value)
+
+    def parse_line(self, line):
+        """Updates the dictionary with a single header line.
+
+        >>> h = HTTPHeaders()
+        >>> h.parse_line("Content-Type: text/html")
+        >>> h.get('content-type')
+        'text/html'
+        """
+        name, value = line.split(":", 1)
+        self.add(name, value.strip())
+
+    @classmethod
+    def parse(cls, headers):
+        """Returns a dictionary from HTTP header text.
+
+        >>> h = HTTPHeaders.parse("Content-Type: text/html\\r\\nContent-Length: 42\\r\\n")
+        >>> sorted(h.iteritems())
+        [('Content-Length', '42'), ('Content-Type', 'text/html')]
+        """
+        h = cls()
+        for line in headers.splitlines():
+            if line:
+                h.parse_line(line)
+        return h
+
+    # dict implementation overrides
+
+    def __setitem__(self, name, value):
+        norm_name = HTTPHeaders._normalize_name(name)
+        dict.__setitem__(self, norm_name, value)
+        self._as_list[norm_name] = [value]
+
+    def __getitem__(self, name):
+        return dict.__getitem__(self, HTTPHeaders._normalize_name(name))
+
+    def __delitem__(self, name):
+        norm_name = HTTPHeaders._normalize_name(name)
+        dict.__delitem__(self, norm_name)
+        del self._as_list[norm_name]
+
+    def get(self, name, default=None):
+        return dict.get(self, HTTPHeaders._normalize_name(name), default)
+
+    def update(self, *args, **kwargs):
+        # dict.update bypasses our __setitem__
+        for k, v in dict(*args, **kwargs).iteritems():
+            self[k] = v
+
+    @staticmethod
+    def _normalize_name(name):
+        """Converts a name to Http-Header-Case.
+
+        >>> HTTPHeaders._normalize_name("coNtent-TYPE")
+        'Content-Type'
+        """
+        return "-".join([w.capitalize() for w in name.split("-")])
+
+
+if __name__ == "__main__":
+    import doctest
+    doctest.testmod()
index 4aaa5fb73490f0190de0583ace51f02ad252d90f..de35669673af938e4629e01f245bd1e2bc329333 100644 (file)
@@ -54,6 +54,7 @@ import cgi
 import cStringIO
 import escape
 import httplib
+import httputil
 import logging
 import sys
 import time
@@ -100,7 +101,7 @@ class HTTPRequest(object):
                 values = [v for v in values if v]
                 if values: self.arguments[name] = values
         self.version = "HTTP/1.1"
-        self.headers = HTTPHeaders()
+        self.headers = httputil.HTTPHeaders()
         if environ.get("CONTENT_TYPE"):
             self.headers["Content-Type"] = environ["CONTENT_TYPE"]
         if environ.get("CONTENT_LENGTH"):
@@ -164,7 +165,7 @@ class HTTPRequest(object):
             if eoh == -1:
                 logging.warning("multipart/form-data missing headers")
                 continue
-            headers = HTTPHeaders.parse(part[:eoh])
+            headers = httputil.HTTPHeaders.parse(part[:eoh])
             name_header = headers.get("Content-Disposition", "")
             if not name_header.startswith("form-data;") or \
                not part.endswith("\r\n"):
@@ -293,23 +294,3 @@ class WSGIContainer(object):
             request.remote_ip + ")"
         log_method("%d %s %.2fms", status_code, summary, request_time)
 
-
-class HTTPHeaders(dict):
-    """A dictionary that maintains Http-Header-Case for all keys."""
-    def __setitem__(self, name, value):
-        dict.__setitem__(self, self._normalize_name(name), value)
-
-    def __getitem__(self, name):
-        return dict.__getitem__(self, self._normalize_name(name))
-
-    def _normalize_name(self, name):
-        return "-".join([w.capitalize() for w in name.split("-")])
-
-    @classmethod
-    def parse(cls, headers_string):
-        headers = cls()
-        for line in headers_string.splitlines():
-            if line:
-                name, value = line.split(":", 1)
-                headers[name] = value.strip()
-        return headers