]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Run more of web_test in wsgi_test. Fix a bug with 304 in wsgi.
authorBen Darnell <ben@bendarnell.com>
Mon, 10 Sep 2012 04:34:58 +0000 (21:34 -0700)
committerBen Darnell <ben@bendarnell.com>
Mon, 10 Sep 2012 04:34:58 +0000 (21:34 -0700)
tornado/test/web_test.py
tornado/test/wsgi_test.py
tornado/wsgi.py

index 3ce700f649fb2c3ef15455c6f8ade0ea4b83ac7f..b5e1698e7590fb748e1f772ca819753981938558 100644 (file)
@@ -16,15 +16,31 @@ import re
 import socket
 import sys
 
+wsgi_safe = []
 
-class SimpleHandlerTestCase(AsyncHTTPTestCase):
+class WebTestCase(AsyncHTTPTestCase):
+    """Base class for web tests that also supports WSGI mode.
+
+    Override get_handlers and get_app_kwargs instead of get_app.
+    Append to wsgi_safe to have it run in wsgi_test as well.
+    """
+    def get_app(self):
+        self.app = Application(self.get_handlers(), **self.get_app_kwargs())
+        return self.app
+
+    def get_handlers(self):
+        raise NotImplementedError()
+
+    def get_app_kwargs(self):
+        return {}
+
+class SimpleHandlerTestCase(WebTestCase):
     """Simplified base class for tests that work with a single handler class.
 
     To use, define a nested class named ``Handler``.
     """
-    def get_app(self):
-        return Application([('/', self.Handler)],
-                           log_function=lambda x: None)
+    def get_handlers(self):
+        return [('/', self.Handler)]
 
 
 class CookieTestRequestHandler(RequestHandler):
@@ -82,8 +98,8 @@ class SecureCookieTest(unittest.TestCase):
         self.assertEqual(handler.get_secure_cookie('foo'), b('\xe9'))
 
 
-class CookieTest(AsyncHTTPTestCase):
-    def get_app(self):
+class CookieTest(WebTestCase):
+    def get_handlers(self):
         class SetCookieHandler(RequestHandler):
             def get(self):
                 # Try setting cookies with different argument types
@@ -117,13 +133,12 @@ class CookieTest(AsyncHTTPTestCase):
                 # Attributes from the first call are not carried over.
                 self.set_cookie("a", "e")
 
-        return Application([
-                ("/set", SetCookieHandler),
+        return [("/set", SetCookieHandler),
                 ("/get", GetCookieHandler),
                 ("/set_domain", SetCookieDomainHandler),
                 ("/special_char", SetCookieSpecialCharHandler),
                 ("/set_overwrite", SetCookieOverwriteHandler),
-                ])
+                ]
 
     def test_set_cookie(self):
         response = self.fetch("/set")
@@ -191,12 +206,12 @@ class AuthRedirectRequestHandler(RequestHandler):
         self.send_error(500)
 
 
-class AuthRedirectTest(AsyncHTTPTestCase):
-    def get_app(self):
-        return Application([('/relative', AuthRedirectRequestHandler,
-                             dict(login_url='/login')),
-                            ('/absolute', AuthRedirectRequestHandler,
-                             dict(login_url='http://example.com/login'))])
+class AuthRedirectTest(WebTestCase):
+    def get_handlers(self):
+        return [('/relative', AuthRedirectRequestHandler,
+                 dict(login_url='/login')),
+                ('/absolute', AuthRedirectRequestHandler,
+                 dict(login_url='http://example.com/login'))]
 
     def test_relative_auth_redirect(self):
         self.http_client.fetch(self.get_url('/relative'), self.stop,
@@ -227,9 +242,9 @@ class ConnectionCloseHandler(RequestHandler):
         self.test.on_connection_close()
 
 
-class ConnectionCloseTest(AsyncHTTPTestCase):
-    def get_app(self):
-        return Application([('/', ConnectionCloseHandler, dict(test=self))])
+class ConnectionCloseTest(WebTestCase):
+    def get_handlers(self):
+        return [('/', ConnectionCloseHandler, dict(test=self))]
 
     def test_connection_close(self):
         s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
@@ -272,12 +287,11 @@ class EchoHandler(RequestHandler):
                         args=recursive_unicode(self.request.arguments)))
 
 
-class RequestEncodingTest(AsyncHTTPTestCase):
-    def get_app(self):
-        return Application([
-                ("/group/(.*)", EchoHandler),
+class RequestEncodingTest(WebTestCase):
+    def get_handlers(self):
+        return [("/group/(.*)", EchoHandler),
                 ("/slashes/([^/]*)/([^/]*)", EchoHandler),
-                ])
+                ]
 
     def fetch_json(self, path):
         return json_decode(self.fetch(path).body)
@@ -457,13 +471,9 @@ class GetArgumentHandler(RequestHandler):
 
 
 # This test is shared with wsgi_test.py
-class WSGISafeWebTest(AsyncHTTPTestCase):
+class WSGISafeWebTest(WebTestCase):
     COOKIE_SECRET = "WebTest.COOKIE_SECRET"
 
-    def get_app(self):
-        self.app = Application(self.get_handlers(), **self.get_app_kwargs())
-        return self.app
-
     def get_app_kwargs(self):
         loader = DictLoader({
                 "linkify.html": "{% module linkify(message) %}",
@@ -604,15 +614,14 @@ js_embed()
         self.assertEqual(response.body, b(""))
         response = self.fetch("/get_argument")
         self.assertEqual(response.body, b("default"))
+wsgi_safe.append(WSGISafeWebTest)
 
 
-class NonWSGIWebTests(AsyncHTTPTestCase):
-    def get_app(self):
-        urls = [
-            ("/flow_control", FlowControlHandler),
-            ("/empty_flush", EmptyFlushCallbackHandler),
-            ]
-        return Application(urls)
+class NonWSGIWebTests(WebTestCase):
+    def get_handlers(self):
+        return [("/flow_control", FlowControlHandler),
+                ("/empty_flush", EmptyFlushCallbackHandler),
+                ]
 
     def test_flow_control(self):
         self.assertEqual(self.fetch("/flow_control").body, b("123"))
@@ -622,8 +631,8 @@ class NonWSGIWebTests(AsyncHTTPTestCase):
         self.assertEqual(response.body, b("ok"))
 
 
-class ErrorResponseTest(AsyncHTTPTestCase):
-    def get_app(self):
+class ErrorResponseTest(WebTestCase):
+    def get_handlers(self):
         class DefaultHandler(RequestHandler):
             def get(self):
                 if self.get_argument("status", None):
@@ -665,12 +674,11 @@ class ErrorResponseTest(AsyncHTTPTestCase):
             def write_error(self, status_code, **kwargs):
                 raise Exception("exception in write_error")
 
-        return Application([
-                url("/default", DefaultHandler),
+        return [url("/default", DefaultHandler),
                 url("/write_error", WriteErrorHandler),
                 url("/get_error_html", GetErrorHtmlHandler),
                 url("/failed_write_error", FailedWriteErrorHandler),
-                ])
+                ]
 
     def test_default(self):
         with ExpectLog(app_log, "Uncaught exception"):
@@ -707,10 +715,10 @@ class ErrorResponseTest(AsyncHTTPTestCase):
             response = self.fetch("/failed_write_error")
             self.assertEqual(response.code, 500)
             self.assertEqual(b(""), response.body)
+wsgi_safe.append(ErrorResponseTest)
 
-
-class StaticFileTest(AsyncHTTPTestCase):
-    def get_app(self):
+class StaticFileTest(WebTestCase):
+    def get_handlers(self):
         class StaticUrlHandler(RequestHandler):
             def get(self, path):
                 self.write(self.static_url(path))
@@ -742,10 +750,13 @@ class StaticFileTest(AsyncHTTPTestCase):
                     result = (check_override == -1 and check_regular == 0)
                 self.write(str(result))
 
-        return Application([('/static_url/(.*)', StaticUrlHandler),
-                            ('/abs_static_url/(.*)', AbsoluteStaticUrlHandler),
-                            ('/override_static_url/(.*)', OverrideStaticUrlHandler)],
-                           static_path=os.path.join(os.path.dirname(__file__), 'static'))
+        return [('/static_url/(.*)', StaticUrlHandler),
+                ('/abs_static_url/(.*)', AbsoluteStaticUrlHandler),
+                ('/override_static_url/(.*)', OverrideStaticUrlHandler)]
+
+    def get_app_kwargs(self):
+        return dict(static_path=os.path.join(os.path.dirname(__file__),
+                                             'static'))
 
     def test_static_files(self):
         response = self.fetch('/robots.txt')
@@ -779,10 +790,10 @@ class StaticFileTest(AsyncHTTPTestCase):
         self.assertEqual(response2.code, 304)
         self.assertTrue('Content-Length' not in response2.headers)
         self.assertTrue('Last-Modified' not in response2.headers)
+wsgi_safe.append(StaticFileTest)
 
-
-class CustomStaticFileTest(AsyncHTTPTestCase):
-    def get_app(self):
+class CustomStaticFileTest(WebTestCase):
+    def get_handlers(self):
         class MyStaticFileHandler(StaticFileHandler):
             def get(self, path):
                 path = self.parse_url_path(path)
@@ -809,28 +820,33 @@ class CustomStaticFileTest(AsyncHTTPTestCase):
             def get(self, path):
                 self.write(self.static_url(path))
 
-        return Application([("/static_url/(.*)", StaticUrlHandler)],
-                           static_path="dummy",
-                           static_handler_class=MyStaticFileHandler)
+        self.static_handler_class = MyStaticFileHandler
+
+        return [("/static_url/(.*)", StaticUrlHandler)]
+
+    def get_app_kwargs(self):
+        return dict(static_path="dummy",
+                    static_handler_class=self.static_handler_class)
 
     def test_serve(self):
         response = self.fetch("/static/foo.42.txt")
         self.assertEqual(response.body, b("bar"))
 
     def test_static_url(self):
-        with ExpectLog(gen_log, "Could not open static file"):
+        with ExpectLog(gen_log, "Could not open static file", required=False):
             response = self.fetch("/static_url/foo.txt")
             self.assertEqual(response.body, b("/static/foo.42.txt"))
+wsgi_safe.append(CustomStaticFileTest)
 
 
-class NamedURLSpecGroupsTest(AsyncHTTPTestCase):
-    def get_app(self):
+class NamedURLSpecGroupsTest(WebTestCase):
+    def get_handlers(self):
         class EchoHandler(RequestHandler):
             def get(self, path):
                 self.write(path)
 
-        return Application([("/str/(?P<path>.*)", EchoHandler),
-                            (u"/unicode/(?P<path>.*)", EchoHandler)])
+        return [("/str/(?P<path>.*)", EchoHandler),
+                (u"/unicode/(?P<path>.*)", EchoHandler)]
 
     def test_named_urlspec_groups(self):
         response = self.fetch("/str/foo")
@@ -838,6 +854,7 @@ class NamedURLSpecGroupsTest(AsyncHTTPTestCase):
 
         response = self.fetch("/unicode/bar")
         self.assertEqual(response.body, b("bar"))
+wsgi_safe.append(NamedURLSpecGroupsTest)
 
 
 class ClearHeaderTest(SimpleHandlerTestCase):
@@ -852,7 +869,7 @@ class ClearHeaderTest(SimpleHandlerTestCase):
         response = self.fetch("/")
         self.assertTrue("h1" not in response.headers)
         self.assertEqual(response.headers["h2"], "bar")
-
+wsgi_safe.append(ClearHeaderTest)
 
 class Header304Test(SimpleHandlerTestCase):
     class Handler(RequestHandler):
@@ -872,3 +889,4 @@ class Header304Test(SimpleHandlerTestCase):
         self.assertTrue("Content-Language" not in response2.headers)
         # Not an entity header, but should not be added to 304s by chunking
         self.assertTrue("Transfer-Encoding" not in response2.headers)
+wsgi_safe.append(Header304Test)
index a143248070d6b3d182715f318076af9309973c6b..c327ba4d1858112bd4b1f90df9b80b1f4e56563f 100644 (file)
@@ -74,7 +74,14 @@ class WSGIConnectionTest(httpserver_test.HTTPConnectionTest):
         return WSGIContainer(validator(WSGIApplication(self.get_handlers())))
 
 
-class WSGIWebTest(web_test.WSGISafeWebTest):
-    def get_app(self):
-        self.app = WSGIApplication(self.get_handlers(), **self.get_app_kwargs())
-        return WSGIContainer(validator(self.app))
+def wrap_web_tests():
+    result = {}
+    for cls in web_test.wsgi_safe:
+        class WSGIWrappedTest(cls):
+            def get_app(self):
+                self.app = WSGIApplication(self.get_handlers(),
+                                           **self.get_app_kwargs())
+                return WSGIContainer(validator(self.app))
+        result["WSGIWrapped_" + cls.__name__] = WSGIWrappedTest
+    return result
+globals().update(wrap_web_tests())
index 98422f56d21a6e6d34c0be29be4df9d6157a54d4..9d5a02c9d4fb87fd240acea4dec60efbff213f82 100644 (file)
@@ -245,10 +245,11 @@ class WSGIContainer(object):
         headers = data["headers"]
         header_set = set(k.lower() for (k, v) in headers)
         body = escape.utf8(body)
-        if "content-length" not in header_set:
-            headers.append(("Content-Length", str(len(body))))
-        if "content-type" not in header_set:
-            headers.append(("Content-Type", "text/html; charset=UTF-8"))
+        if status_code != 304:
+            if "content-length" not in header_set:
+                headers.append(("Content-Length", str(len(body))))
+            if "content-type" not in header_set:
+                headers.append(("Content-Type", "text/html; charset=UTF-8"))
         if "server" not in header_set:
             headers.append(("Server", "TornadoServer/%s" % tornado.version))