]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Type checks for httpserver.HTTPRequest fields
authorBen Darnell <ben@bendarnell.com>
Mon, 30 May 2011 01:39:23 +0000 (18:39 -0700)
committerBen Darnell <ben@bendarnell.com>
Mon, 30 May 2011 01:39:23 +0000 (18:39 -0700)
tornado/httpserver.py
tornado/test/httpserver_test.py

index 05f43f46579fb492d6c8fb987f4436e6232de63a..bcfdc78fa6ef0e1b02484eaba489f526ffad83eb 100644 (file)
@@ -398,9 +398,8 @@ class HTTPConnection(object):
         content_type = self._request.headers.get("Content-Type", "")
         if self._request.method in ("POST", "PUT"):
             if content_type.startswith("application/x-www-form-urlencoded"):
-                arguments = parse_qs(self._request.body)
+                arguments = parse_qs(native_str(self._request.body))
                 for name, values in arguments.iteritems():
-                    name = name.decode('utf-8')
                     values = [v for v in values if v]
                     if values:
                         self._request.arguments.setdefault(name, []).extend(
index 952161287f0806918a16fede1353d9731ee477fe..4600f77f330d7bbf411c25a97cd536f5189ecde0 100644 (file)
@@ -5,7 +5,7 @@ from tornado.escape import json_decode, utf8, _unicode
 from tornado.iostream import IOStream
 from tornado.simple_httpclient import SimpleAsyncHTTPClient
 from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase
-from tornado.util import b
+from tornado.util import b, bytes_type
 from tornado.web import Application, RequestHandler
 import logging
 import os
@@ -140,11 +140,58 @@ class EchoHandler(RequestHandler):
     def get(self):
         self.write(self.request.arguments)
 
+class TypeCheckHandler(RequestHandler):
+    def prepare(self):
+        self.errors = {}
+        fields = [
+            ('method', str),
+            ('uri', str),
+            ('version', str),
+            ('remote_ip', str),
+            ('protocol', str),
+            ('host', str),
+            ('path', str),
+            ('query', str),
+            ]
+        for field, expected_type in fields:
+            self.check_type(field, getattr(self.request, field), expected_type)
+
+        self.check_type('header_key', self.request.headers.keys()[0], str)
+        self.check_type('header_value', self.request.headers.values()[0], str)
+
+        self.check_type('arg_key', self.request.arguments.keys()[0], str)
+        self.check_type('arg_value', self.request.arguments.values()[0][0], str)
+
+    def post(self):
+        self.check_type('body', self.request.body, bytes_type)
+        self.write(self.errors)
+
+    def get(self):
+        self.write(self.errors)
+
+    def check_type(self, name, obj, expected_type):
+        actual_type = type(obj)
+        if expected_type != actual_type:
+            self.errors[name] = "expected %s, got %s" % (expected_type, 
+                                                         actual_type)
+
 class HTTPServerTest(AsyncHTTPTestCase, LogTrapTestCase):
     def get_app(self):
-        return Application([("/echo", EchoHandler)])
+        return Application([("/echo", EchoHandler),
+                            ("/typecheck", TypeCheckHandler),
+                            ])
 
     def test_query_string_encoding(self):
         response = self.fetch("/echo?foo=%C3%A9")
         data = json_decode(response.body)
         self.assertEqual(data, {u"foo": [u"\u00e9"]})
+
+    def test_types(self):
+        response = self.fetch("/typecheck?foo=bar")
+        data = json_decode(response.body)
+        self.assertEqual(data, {})
+
+        response = self.fetch("/typecheck", method="POST", body="foo=bar")
+        data = json_decode(response.body)
+        self.assertEqual(data, {})
+