]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
- Plural query/body arg methods.
authorTravis Beauvais <tbeauvais@gmail.com>
Sat, 5 Oct 2013 16:45:45 +0000 (09:45 -0700)
committerTravis Beauvais <tbeauvais@gmail.com>
Sat, 5 Oct 2013 16:45:45 +0000 (09:45 -0700)
- Updated Tests.
- Don't change signature of public methods.

docs/web.rst
tornado/escape.py [changed mode: 0644->0755]
tornado/httpserver.py
tornado/httputil.py
tornado/test/httputil_test.py
tornado/test/runtests.py [changed mode: 0644->0755]
tornado/test/web_test.py
tornado/web.py
tornado/wsgi.py

index 14a1f9dea48aa15766d140c0b65eaa0d1f815d70..d9a43c7c2991f5a173521fe072e39aaa72ccbc5a 100644 (file)
 
    .. automethod:: RequestHandler.get_argument
    .. automethod:: RequestHandler.get_arguments
+   .. automethod:: RequestHandler.get_query_argument
+   .. automethod:: RequestHandler.get_query_arguments
+   .. automethod:: RequestHandler.get_body_argument
+   .. automethod:: RequestHandler.get_body_arguments
    .. automethod:: RequestHandler.decode_argument
    .. attribute:: RequestHandler.request
 
old mode 100644 (file)
new mode 100755 (executable)
index eba184a5fa8c7426e66023db2ba17a1ced80c478..50c9074160d8a2e6ed6be0a3bc18e2edee678334 100755 (executable)
@@ -337,7 +337,10 @@ class HTTPConnection(object):
         if self._request.method in ("POST", "PATCH", "PUT"):
             httputil.parse_body_arguments(
                 self._request.headers.get("Content-Type", ""), data,
-                self._request.arguments, self._request.body_arguments, self._request.files)
+                self._request.body_arguments, self._request.files)
+
+            for k, v in self._request.body_arguments.iteritems():
+                self._request.arguments.setdefault(k, []).extend(v)
         self.request_callback(self._request)
 
 
index afd1bb2430404fed440eb81971a52c6156b3c675..3e7337d99818dc9c6b3145452374a4fe2f5ec754 100755 (executable)
@@ -310,7 +310,7 @@ def _int_or_none(val):
     return int(val)
 
 
-def parse_body_arguments(content_type, body, arguments, body_arguments, files):
+def parse_body_arguments(content_type, body, arguments, files):
     """Parses a form request body.
 
     Supports ``application/x-www-form-urlencoded`` and
@@ -324,19 +324,18 @@ def parse_body_arguments(content_type, body, arguments, body_arguments, files):
         for name, values in uri_arguments.items():
             if values:
                 arguments.setdefault(name, []).extend(values)
-                body_arguments.setdefault(name, []).extend(values)
     elif content_type.startswith("multipart/form-data"):
         fields = content_type.split(";")
         for field in fields:
             k, sep, v = field.strip().partition("=")
             if k == "boundary" and v:
-                parse_multipart_form_data(utf8(v), body, arguments, body_arguments, files)
+                parse_multipart_form_data(utf8(v), body, arguments, files)
                 break
         else:
             gen_log.warning("Invalid multipart/form-data")
 
 
-def parse_multipart_form_data(boundary, data, arguments, body_arguments, files):
+def parse_multipart_form_data(boundary, data, arguments, files):
     """Parses a ``multipart/form-data`` body.
 
     The ``boundary`` and ``data`` parameters are both byte strings.
@@ -380,7 +379,6 @@ def parse_multipart_form_data(boundary, data, arguments, body_arguments, files):
                 content_type=ctype))
         else:
             arguments.setdefault(name, []).append(value)
-            body_arguments.setdefault(name, []).append(value)
 
 
 def format_timestamp(ts):
index 7a7f75542c53df51147bbfd040de52e381030488..1e84da76f0b8c74c61d7fee6bebfbd376894c468 100755 (executable)
@@ -74,9 +74,8 @@ Content-Disposition: form-data; name="files"; filename="ab.txt"
 Foo
 --1234--""".replace(b"\n", b"\r\n")
         args = {}
-        body_args = {}
         files = {}
-        parse_multipart_form_data(b"1234", data, args, body_args, files)
+        parse_multipart_form_data(b"1234", data, args, files)
         file = files["files"][0]
         self.assertEqual(file["filename"], "ab.txt")
         self.assertEqual(file["body"], b"Foo")
@@ -90,9 +89,8 @@ Content-Disposition: form-data; name=files; filename=ab.txt
 Foo
 --1234--""".replace(b"\n", b"\r\n")
         args = {}
-        body_args = {}
         files = {}
-        parse_multipart_form_data(b"1234", data, args, body_args, files)
+        parse_multipart_form_data(b"1234", data, args, files)
         file = files["files"][0]
         self.assertEqual(file["filename"], "ab.txt")
         self.assertEqual(file["body"], b"Foo")
@@ -116,9 +114,8 @@ Foo
 --1234--""" % filename.replace('\\', '\\\\').replace('"', '\\"')
             data = utf8(data.replace("\n", "\r\n"))
             args = {}
-            body_args = {}
             files = {}
-            parse_multipart_form_data(b"1234", data, args, body_args, files)
+            parse_multipart_form_data(b"1234", data, args, files)
             file = files["files"][0]
             self.assertEqual(file["filename"], filename)
             self.assertEqual(file["body"], b"Foo")
@@ -131,9 +128,8 @@ Content-Disposition: form-data; name="files"; filename="ab.txt"
 Foo
 --1234--'''.replace(b"\n", b"\r\n")
         args = {}
-        body_args = {}
         files = {}
-        parse_multipart_form_data(b'"1234"', data, args, body_args, files)
+        parse_multipart_form_data(b'"1234"', data, args, files)
         file = files["files"][0]
         self.assertEqual(file["filename"], "ab.txt")
         self.assertEqual(file["body"], b"Foo")
@@ -145,10 +141,9 @@ Foo
 Foo
 --1234--'''.replace(b"\n", b"\r\n")
         args = {}
-        body_args = {}
         files = {}
         with ExpectLog(gen_log, "multipart/form-data missing headers"):
-            parse_multipart_form_data(b"1234", data, args, body_args, files)
+            parse_multipart_form_data(b"1234", data, args, files)
         self.assertEqual(files, {})
 
     def test_invalid_content_disposition(self):
@@ -159,10 +154,9 @@ Content-Disposition: invalid; name="files"; filename="ab.txt"
 Foo
 --1234--'''.replace(b"\n", b"\r\n")
         args = {}
-        body_args = {}
         files = {}
         with ExpectLog(gen_log, "Invalid multipart/form-data"):
-            parse_multipart_form_data(b"1234", data, args, body_args, files)
+            parse_multipart_form_data(b"1234", data, args, files)
         self.assertEqual(files, {})
 
     def test_line_does_not_end_with_correct_line_break(self):
@@ -172,10 +166,9 @@ Content-Disposition: form-data; name="files"; filename="ab.txt"
 
 Foo--1234--'''.replace(b"\n", b"\r\n")
         args = {}
-        body_args = {}
         files = {}
         with ExpectLog(gen_log, "Invalid multipart/form-data"):
-            parse_multipart_form_data(b"1234", data, args, body_args, files)
+            parse_multipart_form_data(b"1234", data, args, files)
         self.assertEqual(files, {})
 
     def test_content_disposition_header_without_name_parameter(self):
@@ -186,10 +179,9 @@ Content-Disposition: form-data; filename="ab.txt"
 Foo
 --1234--""".replace(b"\n", b"\r\n")
         args = {}
-        body_args = {}
         files = {}
         with ExpectLog(gen_log, "multipart/form-data value missing name"):
-            parse_multipart_form_data(b"1234", data, args, body_args, files)
+            parse_multipart_form_data(b"1234", data, args, files)
         self.assertEqual(files, {})
 
     def test_data_after_final_boundary(self):
@@ -204,9 +196,8 @@ Foo
 --1234--
 """.replace(b"\n", b"\r\n")
         args = {}
-        body_args = {}
         files = {}
-        parse_multipart_form_data(b"1234", data, args, body_args, files)
+        parse_multipart_form_data(b"1234", data, args, files)
         file = files["files"][0]
         self.assertEqual(file["filename"], "ab.txt")
         self.assertEqual(file["body"], b"Foo")
old mode 100644 (file)
new mode 100755 (executable)
index f7e06d709c098c2802fcd62ff8382ed398487946..56fc0d83e70b45bbd85e61c111bc59ec1d14aab3 100644 (file)
@@ -20,6 +20,11 @@ import re
 import socket
 import sys
 
+try:
+    import urllib.parse as urllib_parse  # py3
+except ImportError:
+    import urllib as urllib_parse  # py2
+
 wsgi_safe_tests = []
 
 relpath = lambda *a: os.path.join(os.path.dirname(__file__), *a)
@@ -288,7 +293,7 @@ class EchoHandler(RequestHandler):
                 if type(value) != bytes_type:
                     raise Exception("incorrect type for value: %r" %
                                     type(value))
-            for value in self.get_arguments(key, self.request.arguments):
+            for value in self.get_arguments(key):
                 if type(value) != unicode_type:
                     raise Exception("incorrect type for value: %r" %
                                     type(value))
@@ -482,6 +487,19 @@ class GetArgumentHandler(RequestHandler):
     def get(self):
         self.write(self.get_argument("foo", "default"))
 
+    def post(self):
+        self.write(self.get_argument("foo", "default"))
+
+
+class GetQueryArgumentHandler(RequestHandler):
+    def post(self):
+        self.write(self.get_query_argument("foo", "default"))
+
+
+class GetBodyArgumentHandler(RequestHandler):
+    def post(self):
+        self.write(self.get_body_argument("foo", "default"))
+
 
 # This test is shared with wsgi_test.py
 @wsgi_safe
@@ -521,6 +539,8 @@ class WSGISafeWebTest(WebTestCase):
             url("/redirect", RedirectHandler),
             url("/header_injection", HeaderInjectionHandler),
             url("/get_argument", GetArgumentHandler),
+            url("/get_query_argument", GetQueryArgumentHandler),
+            url("/get_body_argument", GetBodyArgumentHandler),
         ]
         return urls
 
@@ -647,6 +667,36 @@ js_embed()
         response = self.fetch("/get_argument")
         self.assertEqual(response.body, b"default")
 
+        # test merging of query and body arguments
+        # body arguments overwrite query arguments
+        body = urllib_parse.urlencode(dict(foo="hello"))
+        response = self.fetch("/get_argument?foo=bar", method="POST", body=body)
+        self.assertEqual(response.body, b"hello")
+
+    def test_get_query_arguments(self):
+        # send as a post so we can ensure the separation between query
+        # string and body arguments.
+        body = urllib_parse.urlencode(dict(foo="hello"))
+        response = self.fetch("/get_query_argument?foo=bar", method="POST", body=body)
+        self.assertEqual(response.body, b"bar")
+        response = self.fetch("/get_query_argument?foo=", method="POST", body=body)
+        self.assertEqual(response.body, b"")
+        response = self.fetch("/get_query_argument", method="POST", body=body)
+        self.assertEqual(response.body, b"default")
+
+    def test_get_body_arguments(self):
+        body = urllib_parse.urlencode(dict(foo="bar"))
+        response = self.fetch("/get_body_argument?foo=hello", method="POST", body=body)
+        self.assertEqual(response.body, b"bar")
+
+        body = urllib_parse.urlencode(dict(foo=""))
+        response = self.fetch("/get_body_argument?foo=hello", method="POST", body=body)
+        self.assertEqual(response.body, b"")
+
+        body = urllib_parse.urlencode(dict())
+        response = self.fetch("/get_body_argument?foo=hello", method="POST", body=body)
+        self.assertEqual(response.body, b"default")
+
     def test_no_gzip(self):
         response = self.fetch('/get_argument')
         self.assertNotIn('Accept-Encoding', response.headers.get('Vary', ''))
index 52e9137ddabd4c5dd8bef62b8eb775be0c6b392d..6c88ca4ec2e0e436ed38ea6a10ef34dd99749c41 100755 (executable)
@@ -348,7 +348,16 @@ class RequestHandler(object):
 
         The returned value is always unicode.
         """
-        return self._get_argument(name, self.request.arguments, default, strip)
+        return self._get_argument(name, default, self.request.arguments, strip)
+
+    def get_arguments(self, name, strip=True):
+        """Returns a list of the arguments with the given name.
+
+        If the argument is not present, returns an empty list.
+
+        The returned values are always unicode.
+        """
+        return self._get_arguments(name, self.request.arguments, strip)
 
     def get_body_argument(self, name, default=_ARG_DEFAULT, strip=True):
         """Returns the value of the argument with the given name
@@ -362,7 +371,16 @@ class RequestHandler(object):
 
         The returned value is always unicode.
         """
-        return self._get_argument(name, self.request.body_arguments, default, strip)
+        return self._get_argument(name, default, self.request.body_arguments, strip)
+
+    def get_body_arguments(self, name, strip=True):
+        """Returns a list of the body arguments with the given name.
+
+        If the argument is not present, returns an empty list.
+
+        The returned values are always unicode.
+        """
+        return self._get_arguments(name, self.request.body_arguments, strip)
 
     def get_query_argument(self, name, default=_ARG_DEFAULT, strip=True):
         """Returns the value of the argument with the given name
@@ -376,24 +394,26 @@ class RequestHandler(object):
 
         The returned value is always unicode.
         """
-        return self._get_argument(name, self.request.query_arguments, default, strip)
+        return self._get_argument(name, default, self.request.query_arguments, strip)
 
-    def _get_argument(self, name, source, default=_ARG_DEFAULT, strip=True):
-        args = self.get_arguments(name, source, strip=strip)
-        if not args:
-            if default is self._ARG_DEFAULT:
-                raise MissingArgumentError(name)
-            return default
-        return args[-1]
-
-    def get_arguments(self, name, source, strip=True):
-        """Returns a list of the arguments with the given name.
+    def get_query_arguments(self, name, strip=True):
+        """Returns a list of the query arguments with the given name.
 
         If the argument is not present, returns an empty list.
 
         The returned values are always unicode.
         """
+        return self._get_arguments(name, self.request.query_arguments, strip)
+
+    def _get_argument(self, name, default, source, strip=True):
+        args = self._get_arguments(name, source, strip=strip)
+        if not args:
+            if default is self._ARG_DEFAULT:
+                raise MissingArgumentError(name)
+            return default
+        return args[-1]
 
+    def _get_arguments(self, name, source, strip=True):
         values = []
         for v in source.get(name, []):
             v = self.decode_argument(v, name=name)
index f85e86cf9cd3b4821819b3e134ccdb7fd7f7c602..2c0e38d6d294e824d8fd0a44eb6f7d9fe8aa2364 100755 (executable)
@@ -175,7 +175,7 @@ class HTTPRequest(object):
         # Parse request body
         self.files = {}
         httputil.parse_body_arguments(self.headers.get("Content-Type", ""),
-                                      self.body, self.arguments, self.body_arguments, self.files)
+                                      self.body, self.arguments, self.files)
 
         self._start_time = time.time()
         self._finish_time = None