]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add testing.AsyncSSLTestCase 531/head
authorAlek Storm <alek.storm@gmail.com>
Fri, 8 Jun 2012 03:23:04 +0000 (23:23 -0400)
committerAlek Storm <alek.storm@gmail.com>
Fri, 8 Jun 2012 03:23:04 +0000 (23:23 -0400)
Allow subclasses of AsyncHTTPTestCase to provide their own http client
and server implementations.

tornado/test/httpclient_test.py
tornado/test/httpserver_test.py
tornado/testing.py

index 9ec967999dabdae1103194df472111f00195e835..f1ceed9819ddbdf7ede39293bba3e53ace4a9d13 100644 (file)
@@ -60,10 +60,6 @@ class EchoPostHandler(RequestHandler):
 
 
 class HTTPClientCommonTestCase(AsyncHTTPTestCase, LogTrapTestCase):
-    def get_http_client(self):
-        """Returns AsyncHTTPClient instance.  May be overridden in subclass."""
-        return AsyncHTTPClient(io_loop=self.io_loop)
-
     def get_app(self):
         return Application([
             url("/hello", HelloWorldHandler),
@@ -74,11 +70,6 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase, LogTrapTestCase):
             url("/echopost", EchoPostHandler),
             ], gzip=True)
 
-    def setUp(self):
-        super(HTTPClientCommonTestCase, self).setUp()
-        # replace the client defined in the parent class
-        self.http_client = self.get_http_client()
-
     def test_hello_world(self):
         response = self.fetch("/hello")
         self.assertEqual(response.code, 200)
index 3b49a5026acb01d45234ca9976d2954e7e554f0f..f761fd94976f78eb5402e7b7681ab4eeb46f15b0 100644 (file)
@@ -8,7 +8,7 @@ from tornado.httpserver import HTTPServer
 from tornado.httputil import HTTPHeaders
 from tornado.iostream import IOStream
 from tornado.simple_httpclient import SimpleAsyncHTTPClient
-from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, AsyncTestCase
+from tornado.testing import AsyncHTTPTestCase, AsyncSSLTestCase, AsyncTestCase, LogTrapTestCase
 from tornado.util import b, bytes_type
 from tornado.web import Application, RequestHandler
 import os
@@ -45,38 +45,11 @@ class HelloWorldRequestHandler(RequestHandler):
         self.finish("Got %d bytes in POST" % len(self.request.body))
 
 
-class BaseSSLTest(AsyncHTTPTestCase, LogTrapTestCase):
-    def get_ssl_version(self):
-        raise NotImplementedError()
-
-    def setUp(self):
-        super(BaseSSLTest, self).setUp()
-        # Replace the client defined in the parent class.
-        # Some versions of libcurl have deadlock bugs with ssl,
-        # so always run these tests with SimpleAsyncHTTPClient.
-        self.http_client = SimpleAsyncHTTPClient(io_loop=self.io_loop,
-                                                 force_instance=True)
-
+class BaseSSLTest(AsyncSSLTestCase, LogTrapTestCase):
     def get_app(self):
         return Application([('/', HelloWorldRequestHandler,
                              dict(protocol="https"))])
 
-    def get_httpserver_options(self):
-        # Testing keys were generated with:
-        # openssl req -new -keyout tornado/test/test.key -out tornado/test/test.crt -nodes -days 3650 -x509
-        test_dir = os.path.dirname(__file__)
-        return dict(ssl_options=dict(
-                certfile=os.path.join(test_dir, 'test.crt'),
-                keyfile=os.path.join(test_dir, 'test.key'),
-                ssl_version=self.get_ssl_version()))
-
-    def fetch(self, path, **kwargs):
-        self.http_client.fetch(self.get_url(path).replace('http', 'https'),
-                               self.stop,
-                               validate_cert=False,
-                               **kwargs)
-        return self.wait()
-
 
 class SSLTestMixin(object):
     def test_ssl(self):
@@ -119,38 +92,36 @@ class TLSv1Test(BaseSSLTest, SSLTestMixin):
     def get_ssl_version(self):
         return ssl.PROTOCOL_TLSv1
 
-if hasattr(ssl, 'PROTOCOL_SSLv2'):
-    class SSLv2Test(BaseSSLTest):
-        def get_ssl_version(self):
-            return ssl.PROTOCOL_SSLv2
-
-        def test_sslv2_fail(self):
-            # This is really more of a client test, but run it here since
-            # we've got all the other ssl version tests here.
-            # Clients should have SSLv2 disabled by default.
-            try:
-                # The server simply closes the connection when it gets
-                # an SSLv2 ClientHello packet.
-                # request_timeout is needed here because on some platforms
-                # (cygwin, but not native windows python), the close is not
-                # detected promptly.
-                response = self.fetch('/', request_timeout=1)
-            except ssl.SSLError:
-                # In some python/ssl builds the PROTOCOL_SSLv2 constant
-                # exists but SSLv2 support is still compiled out, which
-                # would result in an SSLError here (details vary depending
-                # on python version).  The important thing is that
-                # SSLv2 request's don't succeed, so we can just ignore
-                # the errors here.
-                return
-            self.assertEqual(response.code, 599)
+
+class SSLv2Test(BaseSSLTest):
+    def get_ssl_version(self):
+        return ssl.PROTOCOL_SSLv2
+
+    def test_sslv2_fail(self):
+        # This is really more of a client test, but run it here since
+        # we've got all the other ssl version tests here.
+        # Clients should have SSLv2 disabled by default.
+        try:
+            # The server simply closes the connection when it gets
+            # an SSLv2 ClientHello packet.
+            # request_timeout is needed here because on some platforms
+            # (cygwin, but not native windows python), the close is not
+            # detected promptly.
+            response = self.fetch('/', request_timeout=1)
+        except ssl.SSLError:
+            # In some python/ssl builds the PROTOCOL_SSLv2 constant
+            # exists but SSLv2 support is still compiled out, which
+            # would result in an SSLError here (details vary depending
+            # on python version).  The important thing is that
+            # SSLv2 request's don't succeed, so we can just ignore
+            # the errors here.
+            return
+        self.assertEqual(response.code, 599)
 
 if ssl is None:
     del BaseSSLTest
     del SSLv23Test
-    del SSLv3Test
-    del TLSv1Test
-elif getattr(ssl, 'OPENSSL_VERSION_INFO', (0, 0)) < (1, 0):
+if getattr(ssl, 'OPENSSL_VERSION_INFO', (0, 0)) < (1, 0):
     # In pre-1.0 versions of openssl, SSLv23 clients always send SSLv2
     # ClientHello messages, which are rejected by SSLv3 and TLSv1
     # servers.  Note that while the OPENSSL_VERSION_INFO was formally
@@ -158,6 +129,8 @@ elif getattr(ssl, 'OPENSSL_VERSION_INFO', (0, 0)) < (1, 0):
     # python 2.7
     del SSLv3Test
     del TLSv1Test
+if not hasattr(ssl, 'PROTOCOL_SSLv2'):
+    del SSLv2Test
 
 
 class MultipartTestHandler(RequestHandler):
index fccdb8610c601c18d7848aec19b75047132756ed..b31ec4465bc004c9e357a92b17d81daae9a5ff4e 100644 (file)
@@ -24,6 +24,7 @@ from cStringIO import StringIO
 try:
     from tornado.httpclient import AsyncHTTPClient
     from tornado.httpserver import HTTPServer
+    from tornado.simple_httpclient import SimpleAsyncHTTPClient
     from tornado.ioloop import IOLoop
 except ImportError:
     # These modules are not importable on app engine.  Parts of this module
@@ -31,10 +32,12 @@ except ImportError:
     AsyncHTTPClient = None
     HTTPServer = None
     IOLoop = None
+    SimpleAsyncHTTPClient = None
 from tornado.stack_context import StackContext, NullContext
 from tornado.util import raise_exc_info
 import contextlib
 import logging
+import os
 import signal
 import sys
 import time
@@ -232,12 +235,19 @@ class AsyncHTTPTestCase(AsyncTestCase):
         super(AsyncHTTPTestCase, self).setUp()
         self.__port = None
 
-        self.http_client = AsyncHTTPClient(io_loop=self.io_loop)
+        self.http_client = self.get_http_client()
         self._app = self.get_app()
-        self.http_server = HTTPServer(self._app, io_loop=self.io_loop,
-                                      **self.get_httpserver_options())
+        self.http_server = self.get_http_server()
         self.http_server.listen(self.get_http_port(), address="127.0.0.1")
 
+    def get_http_client(self):
+        return AsyncHTTPClient(io_loop=self.io_loop)
+
+    def get_http_server(self):
+        return HTTPServer(self._app, io_loop=self.io_loop,
+                          **self.get_httpserver_options())
+
+
     def get_app(self):
         """Should be overridden by subclasses to return a
         tornado.web.Application or other HTTPServer callback.
@@ -257,12 +267,12 @@ class AsyncHTTPTestCase(AsyncTestCase):
 
     def get_httpserver_options(self):
         """May be overridden by subclasses to return additional
-        keyword arguments for HTTPServer.
+        keyword arguments for the server.
         """
         return {}
 
     def get_http_port(self):
-        """Returns the port used by the HTTPServer.
+        """Returns the port used by the server.
 
         A new port is chosen for each test.
         """
@@ -270,9 +280,13 @@ class AsyncHTTPTestCase(AsyncTestCase):
             self.__port = get_unused_port()
         return self.__port
 
+    def get_protocol(self):
+        return 'http'
+
     def get_url(self, path):
         """Returns an absolute url for the given path on the test server."""
-        return 'http://localhost:%s%s' % (self.get_http_port(), path)
+        return '%s://localhost:%s%s' % (self.get_protocol(),
+                                        self.get_http_port(), path)
 
     def tearDown(self):
         self.http_server.stop()
@@ -280,6 +294,35 @@ class AsyncHTTPTestCase(AsyncTestCase):
         super(AsyncHTTPTestCase, self).tearDown()
 
 
+class AsyncSSLTestCase(AsyncHTTPTestCase):
+    def get_ssl_version(self):
+        raise NotImplementedError()
+
+    def get_http_client(self):
+        # Some versions of libcurl have deadlock bugs with ssl,
+        # so always run these tests with SimpleAsyncHTTPClient.
+        return SimpleAsyncHTTPClient(io_loop=self.io_loop, force_instance=True)
+
+    def get_httpserver_options(self):
+        return dict(ssl_options=self.get_ssl_options())
+
+    def get_ssl_options(self):
+        # Testing keys were generated with:
+        # openssl req -new -keyout tornado/test/test.key -out tornado/test/test.crt -nodes -days 3650 -x509
+        module_dir = os.path.dirname(__file__)
+        return dict(
+                certfile=os.path.join(module_dir, 'test', 'test.crt'),
+                keyfile=os.path.join(module_dir, 'test', 'test.key'),
+                ssl_version=self.get_ssl_version())
+
+    def get_protocol(self):
+        return 'https'
+
+    def fetch(self, path, **kwargs):
+        return AsyncHTTPTestCase.fetch(self, path, validate_cert=False,
+                   **kwargs)
+
+
 class LogTrapTestCase(unittest.TestCase):
     """A test case that captures and discards all logging output
     if the test passes.