]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Make HTTPHeaders a subclass of MutableMapping ABC instead of dict.
authorBen Darnell <ben@bendarnell.com>
Sun, 13 Sep 2015 17:15:06 +0000 (13:15 -0400)
committerBen Darnell <ben@bendarnell.com>
Sun, 13 Sep 2015 17:20:58 +0000 (13:20 -0400)
This simplifies the implementation since MutableMapping is designed
for subclassing while dict has many special cases that need to be
overridden. In particular, this change fixes the setdefault()
method.

Fixes #1500.

tornado/httputil.py
tornado/test/httputil_test.py

index 747dfc400c98298df98c25e3e6b02860b0de2828..471df54f967c8792ce5aae9d058533db4ebc285a 100644 (file)
@@ -98,7 +98,7 @@ class _NormalizedHeaderCache(dict):
 _normalized_headers = _NormalizedHeaderCache(1000)
 
 
-class HTTPHeaders(dict):
+class HTTPHeaders(collections.MutableMapping):
     """A dictionary that maintains ``Http-Header-Case`` for all keys.
 
     Supports multiple values per key via a pair of new methods,
@@ -127,9 +127,7 @@ class HTTPHeaders(dict):
     Set-Cookie: C=D
     """
     def __init__(self, *args, **kwargs):
-        # Don't pass args or kwargs to dict.__init__, as it will bypass
-        # our __setitem__
-        dict.__init__(self)
+        self._dict = {}
         self._as_list = {}
         self._last_key = None
         if (len(args) == 1 and len(kwargs) == 0 and
@@ -148,10 +146,8 @@ class HTTPHeaders(dict):
         norm_name = _normalized_headers[name]
         self._last_key = norm_name
         if norm_name in self:
-            # bypass our override of __setitem__ since it modifies _as_list
-            dict.__setitem__(self, norm_name,
-                             native_str(self[norm_name]) + ',' +
-                             native_str(value))
+            self._dict[norm_name] = (native_str(self[norm_name]) + ',' +
+                                     native_str(value))
             self._as_list[norm_name].append(value)
         else:
             self[norm_name] = value
@@ -183,8 +179,7 @@ class HTTPHeaders(dict):
             # continuation of a multi-line header
             new_part = ' ' + line.lstrip()
             self._as_list[self._last_key][-1] += new_part
-            dict.__setitem__(self, self._last_key,
-                             self[self._last_key] + new_part)
+            self._dict[self._last_key] += new_part
         else:
             name, value = line.split(":", 1)
             self.add(name, value.strip())
@@ -203,54 +198,36 @@ class HTTPHeaders(dict):
                 h.parse_line(line)
         return h
 
-    # dict implementation overrides
+    # MutableMapping abstract method implementations.
 
     def __setitem__(self, name, value):
         norm_name = _normalized_headers[name]
-        dict.__setitem__(self, norm_name, value)
+        self._dict[norm_name] = value
         self._as_list[norm_name] = [value]
 
     def __getitem__(self, name):
-        return dict.__getitem__(self, _normalized_headers[name])
+        return self._dict[_normalized_headers[name]]
 
     def __delitem__(self, name):
         norm_name = _normalized_headers[name]
-        dict.__delitem__(self, norm_name)
+        del self._dict[norm_name]
         del self._as_list[norm_name]
 
-    def __contains__(self, name):
-        norm_name = _normalized_headers[name]
-        return dict.__contains__(self, norm_name)
-
-    def get(self, name, default=None):
-        return dict.get(self, _normalized_headers[name], default)
+    def __len__(self):
+        return len(self._dict)
 
-    def update(self, *args, **kwargs):
-        # dict.update bypasses our __setitem__
-        for k, v in dict(*args, **kwargs).items():
-            self[k] = v
+    def __iter__(self):
+        return iter(self._dict)
 
     def copy(self):
-        # default implementation returns dict(self), not the subclass
+        # defined in dict but not in MutableMapping.
         return HTTPHeaders(self)
 
     # Use our overridden copy method for the copy.copy module.
+    # This makes shallow copies one level deeper, but preserves
+    # the appearance that HTTPHeaders is a single container.
     __copy__ = copy
 
-    def __deepcopy__(self, memo_dict):
-        # Our values are immutable strings, so our standard copy is
-        # effectively a deep copy.
-        return self.copy()
-
-    def __reduce_ex__(self, v):
-        # We must override dict.__reduce_ex__ to pickle ourselves
-        # correctly.
-        return HTTPHeaders, (), list(self.get_all())
-
-    def __setstate__(self, state):
-        for k, v in state:
-            self.add(k, v)
-
 
 class HTTPServerRequest(object):
     """A single HTTP request.
index ca60b45ef7f560b868a5c2fa81193630fc77656b..b74fdcbf1404305e6a5345e7ca71a13c33cb2fc0 100644 (file)
@@ -308,6 +308,16 @@ Foo: even
         self.assertEqual(sorted(headers.get_all()), sorted(unpickled.get_all()))
         self.assertEqual(sorted(headers.items()), sorted(unpickled.items()))
 
+    def test_setdefault(self):
+        headers = HTTPHeaders()
+        headers['foo'] = 'bar'
+        # If a value is present, setdefault returns it without changes.
+        self.assertEqual(headers.setdefault('foo', 'baz'), 'bar')
+        self.assertEqual(headers['foo'], 'bar')
+        # If a value is not present, setdefault sets it for future use.
+        self.assertEqual(headers.setdefault('quux', 'xyzzy'), 'xyzzy')
+        self.assertEqual(headers['quux'], 'xyzzy')
+        self.assertEqual(sorted(headers.get_all()), [('Foo', 'bar'), ('Quux', 'xyzzy')])
 
 class FormatTimestampTest(unittest.TestCase):
     # Make sure that all the input types are supported.