]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Test multipart/form-data parsing in wsgi and fix it for python3
authorBen Darnell <ben@bendarnell.com>
Sun, 5 Jun 2011 19:58:26 +0000 (12:58 -0700)
committerBen Darnell <ben@bendarnell.com>
Sun, 5 Jun 2011 19:58:26 +0000 (12:58 -0700)
tornado/test/wsgi_test.py
tornado/wsgi.py

index f2236cea6b339f6fa6bbfa9ae63b9bb9dda03cfd..377a36f9eddca575c084938653da2aae33f4735d 100644 (file)
@@ -1,5 +1,6 @@
 from wsgiref.validate import validator
 
+from tornado.escape import json_encode
 from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase
 from tornado.util import b
 from tornado.web import RequestHandler
@@ -35,3 +36,15 @@ class WSGIApplicationTest(AsyncHTTPTestCase, LogTrapTestCase):
     def test_simple(self):
         response = self.fetch("/")
         self.assertEqual(response.body, b("Hello world!"))
+
+# This is kind of hacky, but run some of the HTTPServer tests through
+# WSGIContainer and WSGIApplication to make sure everything survives
+# repeated disassembly and reassembly.
+from tornado.test.httpserver_test import HTTPConnectionTest, MultipartTestHandler
+
+class WSGIConnectionTest(HTTPConnectionTest):
+    def get_app(self):
+        return WSGIContainer(validator(WSGIApplication([
+                        ("/multipart", MultipartTestHandler)])))
+
+del HTTPConnectionTest
index 1a62924ab944fc20b3f045e5ae7e28c4c765138f..20ec8975c1d9c8c570cb24610aa8591993924cf2 100644 (file)
@@ -61,7 +61,7 @@ import urllib
 from tornado import escape
 from tornado import httputil
 from tornado import web
-from tornado.escape import native_str
+from tornado.escape import native_str, utf8
 from tornado.util import b
 
 try:
@@ -139,7 +139,7 @@ class HTTPRequest(object):
         elif content_type.startswith("multipart/form-data"):
             if 'boundary=' in content_type:
                 boundary = content_type.split('boundary=',1)[1]
-                if boundary: self._parse_mime_body(boundary)
+                if boundary: self._parse_mime_body(utf8(boundary))
             else:
                 logging.warning("Invalid multipart/form-data")
 
@@ -162,30 +162,30 @@ class HTTPRequest(object):
             return self._finish_time - self._start_time
 
     def _parse_mime_body(self, boundary):
-        if boundary.startswith('"') and boundary.endswith('"'):
+        if boundary.startswith(b('"')) and boundary.endswith(b('"')):
             boundary = boundary[1:-1]
-        if self.body.endswith("\r\n"):
+        if self.body.endswith(b("\r\n")):
             footer_length = len(boundary) + 6
         else:
             footer_length = len(boundary) + 4
-        parts = self.body[:-footer_length].split("--" + boundary + "\r\n")
+        parts = self.body[:-footer_length].split(b("--") + boundary + b("\r\n"))
         for part in parts:
             if not part: continue
-            eoh = part.find("\r\n\r\n")
+            eoh = part.find(b("\r\n\r\n"))
             if eoh == -1:
                 logging.warning("multipart/form-data missing headers")
                 continue
-            headers = httputil.HTTPHeaders.parse(part[:eoh])
+            headers = httputil.HTTPHeaders.parse(part[:eoh].decode("utf-8"))
             name_header = headers.get("Content-Disposition", "")
             if not name_header.startswith("form-data;") or \
-               not part.endswith("\r\n"):
+               not part.endswith(b("\r\n")):
                 logging.warning("Invalid multipart/form-data")
                 continue
             value = part[eoh + 4:-2]
             name_values = {}
             for name_part in name_header[10:].split(";"):
                 name, name_value = name_part.strip().split("=", 1)
-                name_values[name] = name_value.strip('"').decode("utf-8")
+                name_values[name] = name_value.strip('"')
             if not name_values.get("name"):
                 logging.warning("multipart/form-data value missing name")
                 continue
@@ -285,9 +285,9 @@ class WSGIContainer(object):
             "wsgi.run_once": False,
         }
         if "Content-Type" in request.headers:
-            environ["CONTENT_TYPE"] = request.headers["Content-Type"]
+            environ["CONTENT_TYPE"] = request.headers.pop("Content-Type")
         if "Content-Length" in request.headers:
-            environ["CONTENT_LENGTH"] = request.headers["Content-Length"]
+            environ["CONTENT_LENGTH"] = request.headers.pop("Content-Length")
         for key, value in request.headers.iteritems():
             environ["HTTP_" + key.replace("-", "_").upper()] = value
         return environ