]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Allow the application to determine the encoding used for url parameters (previously...
authorBen Darnell <ben@bendarnell.com>
Mon, 30 May 2011 06:32:22 +0000 (23:32 -0700)
committerBen Darnell <ben@bendarnell.com>
Mon, 30 May 2011 06:32:22 +0000 (23:32 -0700)
HTTPRequest.arguments now maps from native strings to bytes.  That's slightly
inconsistent, but having byte strings as dictionary keys is awkward.

tornado/escape.py
tornado/httpserver.py
tornado/test/httpserver_test.py
tornado/test/web_test.py
tornado/web.py

index 645790964944b5c55dee0c9c49d3b1aa6f494f71..74a1caa7eb737d7068e370440a97c3bcd28c8478 100644 (file)
@@ -26,6 +26,11 @@ import urllib
 try: bytes
 except: bytes = str
 
+try:
+    from urlparse import parse_qs  # Python 2.6+
+except ImportError:
+    from cgi import parse_qs
+
 # json module is in the standard library as of python 2.6; fall back to
 # simplejson if present for older versions.
 try:
@@ -88,7 +93,8 @@ def url_escape(value):
     return urllib.quote_plus(utf8(value))
 
 # python 3 changed things around enough that we need two separate
-# implementations of url_unescape
+# implementations of url_unescape.  We also need our own implementation
+# of parse_qs since python 3's version insists on decoding everything.
 if sys.version_info[0] < 3:
     def url_unescape(value, encoding='utf-8'):
         """Decodes the given value from a URL.
@@ -102,6 +108,8 @@ if sys.version_info[0] < 3:
             return urllib.unquote_plus(utf8(value))
         else:
             return unicode(urllib.unquote_plus(utf8(value)), encoding)
+
+    parse_qs_bytes = parse_qs
 else:
     def url_unescape(value, encoding='utf-8'):
         """Decodes the given value from a URL.
@@ -116,6 +124,24 @@ else:
         else:
             return urllib.unquote_plus(native_str(value), encoding=encoding)
 
+    def parse_qs_bytes(qs, keep_blank_values=False, strict_parsing=False):
+        """Parses a query string like urlparse.parse_qs, but returns the
+        values as byte strings.
+
+        Keys still become type str (interpreted as latin1 in python3!)
+        because it's too painful to keep them as byte strings in
+        python3 and in practice they're nearly always ascii anyway.
+        """
+        # This is gross, but python3 doesn't give us another way.
+        # Latin1 is the universal donor of character encodings.
+        result = parse_qs(qs, keep_blank_values, strict_parsing,
+                          encoding='latin1', errors='strict')
+        encoded = {}
+        for k,v in result.iteritems():
+            encoded[k] = [i.encode('latin1') for i in v]
+        return encoded
+        
+
 
 _UTF8_TYPES = (bytes, type(None))
 def utf8(value):
@@ -153,6 +179,22 @@ else:
     native_str = utf8
 
 
+def recursive_unicode(obj):
+    """Walks a simple data structure, converting byte strings to unicode.
+
+    Supports lists, tuples, and dictionaries.
+    """
+    if isinstance(obj, dict):
+        return dict((recursive_unicode(k), recursive_unicode(v)) for (k,v) in obj.iteritems())
+    elif isinstance(obj, list):
+        return list(recursive_unicode(i) for i in obj)
+    elif isinstance(obj, tuple):
+        return tuple(recursive_unicode(i) for i in obj)
+    elif isinstance(obj, bytes):
+        return to_unicode(obj)
+    else:
+        return obj
+
 # I originally used the regex from 
 # http://daringfireball.net/2010/07/improved_regex_for_matching_urls
 # but it gets all exponential on certain patterns (such as too many trailing
index bcfdc78fa6ef0e1b02484eaba489f526ffad83eb..f5ab0e6c848bb4a9049b872e023599381efe68b0 100644 (file)
@@ -23,18 +23,13 @@ import socket
 import time
 import urlparse
 
-from tornado.escape import utf8, native_str
+from tornado.escape import utf8, native_str, parse_qs_bytes
 from tornado import httputil
 from tornado import ioloop
 from tornado import iostream
 from tornado import stack_context
 from tornado.util import b, bytes_type
 
-try:
-    from urlparse import parse_qs  # Python 2.6+
-except ImportError:
-    from cgi import parse_qs
-
 try:
     import fcntl
 except ImportError:
@@ -398,7 +393,7 @@ class HTTPConnection(object):
         content_type = self._request.headers.get("Content-Type", "")
         if self._request.method in ("POST", "PUT"):
             if content_type.startswith("application/x-www-form-urlencoded"):
-                arguments = parse_qs(native_str(self._request.body))
+                arguments = parse_qs_bytes(native_str(self._request.body))
                 for name, values in arguments.iteritems():
                     values = [v for v in values if v]
                     if values:
@@ -511,7 +506,7 @@ class HTTPRequest(object):
         scheme, netloc, path, query, fragment = urlparse.urlsplit(native_str(uri))
         self.path = path
         self.query = query
-        arguments = parse_qs(query)
+        arguments = parse_qs_bytes(query)
         self.arguments = {}
         for name, values in arguments.iteritems():
             values = [v for v in values if v]
index 4600f77f330d7bbf411c25a97cd536f5189ecde0..def4f0fd2f292c73ca2de261b42cbf90cc0b61e2 100644 (file)
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 
 from tornado import httpclient, simple_httpclient
-from tornado.escape import json_decode, utf8, _unicode
+from tornado.escape import json_decode, utf8, _unicode, recursive_unicode
 from tornado.iostream import IOStream
 from tornado.simple_httpclient import SimpleAsyncHTTPClient
 from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase
@@ -138,7 +138,7 @@ class HTTPConnectionTest(AsyncHTTPTestCase, LogTrapTestCase):
 
 class EchoHandler(RequestHandler):
     def get(self):
-        self.write(self.request.arguments)
+        self.write(recursive_unicode(self.request.arguments))
 
 class TypeCheckHandler(RequestHandler):
     def prepare(self):
@@ -160,7 +160,7 @@ class TypeCheckHandler(RequestHandler):
         self.check_type('header_value', self.request.headers.values()[0], str)
 
         self.check_type('arg_key', self.request.arguments.keys()[0], str)
-        self.check_type('arg_value', self.request.arguments.values()[0][0], str)
+        self.check_type('arg_value', self.request.arguments.values()[0][0], bytes_type)
 
     def post(self):
         self.check_type('body', self.request.body, bytes_type)
index 5e0d40ddc518cab5384873aa2efc5ff011a571da..70b2edb07dee78bf81712895a6648e3ef85a6601 100644 (file)
@@ -1,4 +1,4 @@
-from tornado.escape import json_decode, utf8
+from tornado.escape import json_decode, utf8, to_unicode, recursive_unicode, native_str
 from tornado.iostream import IOStream
 from tornado.testing import LogTrapTestCase, AsyncHTTPTestCase
 from tornado.util import b, bytes_type
@@ -148,18 +148,19 @@ class ConnectionCloseTest(AsyncHTTPTestCase, LogTrapTestCase):
 
 class EchoHandler(RequestHandler):
     def get(self, path):
-        # Type checks:  web.py interfaces convert arguments to unicode
-        # strings.  In httpserver.py (i.e. self.request.arguments),
-        # they're left as the native str type.
+        # Type checks: web.py interfaces convert argument values to
+        # unicode strings (by default, but see also decode_argument).
+        # In httpserver.py (i.e. self.request.arguments), they're left
+        # as bytes.  Keys are always native strings.
         for key in self.request.arguments:
-            assert type(key) == type(""), repr(key)
+            assert type(key) == str, repr(key)
             for value in self.request.arguments[key]:
-                assert type(value) == type(""), repr(value)
+                assert type(value) == bytes_type, repr(value)
             for value in self.get_arguments(key):
                 assert type(value) == unicode, repr(value)
         assert type(path) == unicode, repr(path)
         self.write(dict(path=path,
-                        args=self.request.arguments))
+                        args=recursive_unicode(self.request.arguments)))
 
 class RequestEncodingTest(AsyncHTTPTestCase, LogTrapTestCase):
     def get_app(self):
@@ -215,9 +216,33 @@ class TypeCheckHandler(RequestHandler):
             self.errors[name] = "expected %s, got %s" % (expected_type,
                                                          actual_type)
 
+class DecodeArgHandler(RequestHandler):
+    def decode_argument(self, value, name=None):
+        assert type(value) == bytes_type, repr(value)
+        # use self.request.arguments directly to avoid recursion
+        if 'encoding' in self.request.arguments:
+            return value.decode(to_unicode(self.request.arguments['encoding'][0]))
+        else:
+            return value
+
+    def get(self, arg):
+        def describe(s):
+            if type(s) == bytes_type:
+                return ["bytes", native_str(binascii.b2a_hex(s))]
+            elif type(s) == unicode:
+                return ["unicode", s]
+            raise Exception("unknown type")
+        self.write({'path': describe(arg),
+                    'query': describe(self.get_argument("foo")),
+                    })
+
 class WebTest(AsyncHTTPTestCase, LogTrapTestCase):
     def get_app(self):
-        return Application([url("/typecheck/(.*)", TypeCheckHandler, name='typecheck')])
+        return Application([
+                url("/typecheck/(.*)", TypeCheckHandler, name='typecheck'),
+                url("/decode_arg/(.*)", DecodeArgHandler),
+                url("/decode_arg_kw/(?P<arg>.*)", DecodeArgHandler),
+                ])
 
     def test_types(self):
         response = self.fetch("/typecheck/asdf?foo=bar",
@@ -228,3 +253,24 @@ class WebTest(AsyncHTTPTestCase, LogTrapTestCase):
         response = self.fetch("/typecheck/asdf?foo=bar", method="POST",
                               headers={"Cookie": "cook=ie"},
                               body="foo=bar")
+
+    def test_decode_argument(self):
+        # These urls all decode to the same thing
+        urls = ["/decode_arg/%C3%A9?foo=%C3%A9&encoding=utf-8",
+                "/decode_arg/%E9?foo=%E9&encoding=latin1",
+                "/decode_arg_kw/%E9?foo=%E9&encoding=latin1",
+                ]
+        for url in urls:
+            response = self.fetch(url)
+            response.rethrow()
+            data = json_decode(response.body)
+            self.assertEqual(data, {u'path': [u'unicode', u'\u00e9'],
+                                    u'query': [u'unicode', u'\u00e9'],
+                                    })
+
+        response = self.fetch("/decode_arg/%C3%A9?foo=%C3%A9")
+        response.rethrow()
+        data = json_decode(response.body)
+        self.assertEqual(data, {u'path': [u'bytes', u'c3a9'],
+                                u'query': [u'bytes', u'c3a9'],
+                                })
index d631b15bc390019822510c7b072404900c79d052..63e1ccf165d3dc89159a9cc8c3ce134cfc721c48 100644 (file)
@@ -249,14 +249,32 @@ class RequestHandler(object):
 
         The returned values are always unicode.
         """
-        values = self.request.arguments.get(name, [])
-        # Get rid of any weird control chars
-        values = [re.sub(r"[\x00-\x08\x0e-\x1f]", " ", _unicode(x)) 
-                  for x in values]
-        if strip:
-            values = [x.strip() for x in values]
+        values = []
+        for v in self.request.arguments.get(name, []):
+            v = self.decode_argument(v, name=name)
+            if isinstance(v, unicode):
+                # Get rid of any weird control chars (unless decoding gave
+                # us bytes, in which case leave it alone)
+                v = re.sub(r"[\x00-\x08\x0e-\x1f]", " ", v)
+            if strip:
+                v = v.strip()
+            values.append(v)
         return values
 
+    def decode_argument(self, value, name=None):
+        """Decodes an argument from the request.
+
+        The argument has been percent-decoded and is now a byte string.
+        By default, this method decodes the argument as utf-8 and returns
+        a unicode string, but this may be overridden in subclasses.
+
+        This method is used as a filter for both get_argument() and for
+        values extracted from the url and passed to get()/post()/etc.
+
+        The name of the argument is provided if known, but may be None
+        (e.g. for unnamed groups in the url regex).
+        """
+        return _unicode(value)
 
     @property
     def cookies(self):
@@ -881,6 +899,9 @@ class RequestHandler(object):
                 self.check_xsrf_cookie()
             self.prepare()
             if not self._finished:
+                args = [self.decode_argument(arg) for arg in args]
+                kwargs = dict((k, self.decode_argument(v, name=k))
+                              for (k,v) in kwargs.iteritems())
                 getattr(self, self.request.method.lower())(*args, **kwargs)
                 if self._auto_finish and not self._finished:
                     self.finish()
@@ -1198,15 +1219,17 @@ class Application(object):
             for spec in handlers:
                 match = spec.regex.match(request.path)
                 if match:
-                    # None-safe wrapper around urllib.unquote to handle
+                    # None-safe wrapper around url_unescape to handle
                     # unmatched optional groups correctly
                     def unquote(s):
                         if s is None: return s
-                        return _unicode(urllib.unquote(s))
+                        return escape.url_unescape(s, encoding=None)
                     handler = spec.handler_class(self, request, **spec.kwargs)
                     # Pass matched groups to the handler.  Since
                     # match.groups() includes both named and unnamed groups,
                     # we want to use either groups or groupdict but not both.
+                    # Note that args are passed as bytes so the handler can
+                    # decide what encoding to use.
                     kwargs = dict((k, unquote(v))
                                   for (k, v) in match.groupdict().iteritems())
                     if kwargs: