]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Expand argument-origin tests to include the plural methods.
authorBen Darnell <ben@bendarnell.com>
Sun, 6 Oct 2013 00:26:14 +0000 (20:26 -0400)
committerBen Darnell <ben@bendarnell.com>
Sun, 6 Oct 2013 00:26:14 +0000 (20:26 -0400)
tornado/test/web_test.py

index 56fc0d83e70b45bbd85e61c111bc59ec1d14aab3..153c8ccbf7deb4548e2f76029bc2591ecfcccd6c 100644 (file)
@@ -484,21 +484,21 @@ class HeaderInjectionHandler(RequestHandler):
 
 
 class GetArgumentHandler(RequestHandler):
-    def get(self):
-        self.write(self.get_argument("foo", "default"))
-
-    def post(self):
-        self.write(self.get_argument("foo", "default"))
-
-
-class GetQueryArgumentHandler(RequestHandler):
-    def post(self):
-        self.write(self.get_query_argument("foo", "default"))
+    def prepare(self):
+        if self.get_argument('source', None) == 'query':
+            method = self.get_query_argument
+        elif self.get_argument('source', None) == 'body':
+            method = self.get_body_argument
+        else:
+            method = self.get_argument
+        self.finish(method("foo", "default"))
 
 
-class GetBodyArgumentHandler(RequestHandler):
-    def post(self):
-        self.write(self.get_body_argument("foo", "default"))
+class GetArgumentsHandler(RequestHandler):
+    def prepare(self):
+        self.finish(dict(default=self.get_arguments("foo"),
+                         query=self.get_query_arguments("foo"),
+                         body=self.get_body_arguments("foo")))
 
 
 # This test is shared with wsgi_test.py
@@ -539,8 +539,7 @@ class WSGISafeWebTest(WebTestCase):
             url("/redirect", RedirectHandler),
             url("/header_injection", HeaderInjectionHandler),
             url("/get_argument", GetArgumentHandler),
-            url("/get_query_argument", GetQueryArgumentHandler),
-            url("/get_body_argument", GetBodyArgumentHandler),
+            url("/get_arguments", GetArgumentsHandler),
         ]
         return urls
 
@@ -667,34 +666,47 @@ js_embed()
         response = self.fetch("/get_argument")
         self.assertEqual(response.body, b"default")
 
-        # test merging of query and body arguments
-        # body arguments overwrite query arguments
+        # Test merging of query and body arguments.
+        # In singular form, body arguments take precedence over query arguments.
         body = urllib_parse.urlencode(dict(foo="hello"))
         response = self.fetch("/get_argument?foo=bar", method="POST", body=body)
         self.assertEqual(response.body, b"hello")
+        # In plural methods they are merged.
+        response = self.fetch("/get_arguments?foo=bar",
+                              method="POST", body=body)
+        self.assertEqual(json_decode(response.body),
+                         dict(default=['bar', 'hello'],
+                              query=['bar'],
+                              body=['hello']))
 
     def test_get_query_arguments(self):
         # send as a post so we can ensure the separation between query
         # string and body arguments.
         body = urllib_parse.urlencode(dict(foo="hello"))
-        response = self.fetch("/get_query_argument?foo=bar", method="POST", body=body)
+        response = self.fetch("/get_argument?source=query&foo=bar",
+                              method="POST", body=body)
         self.assertEqual(response.body, b"bar")
-        response = self.fetch("/get_query_argument?foo=", method="POST", body=body)
+        response = self.fetch("/get_argument?source=query&foo=",
+                              method="POST", body=body)
         self.assertEqual(response.body, b"")
-        response = self.fetch("/get_query_argument", method="POST", body=body)
+        response = self.fetch("/get_argument?source=query",
+                              method="POST", body=body)
         self.assertEqual(response.body, b"default")
 
     def test_get_body_arguments(self):
         body = urllib_parse.urlencode(dict(foo="bar"))
-        response = self.fetch("/get_body_argument?foo=hello", method="POST", body=body)
+        response = self.fetch("/get_argument?source=body&foo=hello",
+                              method="POST", body=body)
         self.assertEqual(response.body, b"bar")
 
         body = urllib_parse.urlencode(dict(foo=""))
-        response = self.fetch("/get_body_argument?foo=hello", method="POST", body=body)
+        response = self.fetch("/get_argument?source=body&foo=hello",
+                              method="POST", body=body)
         self.assertEqual(response.body, b"")
 
         body = urllib_parse.urlencode(dict())
-        response = self.fetch("/get_body_argument?foo=hello", method="POST", body=body)
+        response = self.fetch("/get_argument?source=body&foo=hello",
+                              method="POST", body=body)
         self.assertEqual(response.body, b"default")
 
     def test_no_gzip(self):