]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
test: Adopt abstract_base_test in simple_httpclient_test.py
authorBen Darnell <ben@bendarnell.com>
Thu, 11 Jul 2024 18:49:17 +0000 (14:49 -0400)
committerBen Darnell <ben@bendarnell.com>
Thu, 11 Jul 2024 18:49:17 +0000 (14:49 -0400)
tornado/test/simple_httpclient_test.py

index 3bb66187648cabce2091a9bc02e7919a4482da66..07f04ae938cde4364913455e88926893faf2b974 100644 (file)
@@ -11,7 +11,8 @@ import typing  # noqa: F401
 
 from tornado.escape import to_unicode, utf8
 from tornado import gen, version
-from tornado.httpclient import AsyncHTTPClient
+from tornado.httpclient import AsyncHTTPClient, HTTPResponse
+from tornado.httpserver import HTTPServer
 from tornado.httputil import HTTPHeaders, ResponseStartLine
 from tornado.ioloop import IOLoop
 from tornado.iostream import UnsatisfiableReadError
@@ -38,7 +39,12 @@ from tornado.testing import (
     ExpectLog,
     gen_test,
 )
-from tornado.test.util import skipOnTravis, skipIfNoIPv6, refusing_port
+from tornado.test.util import (
+    abstract_base_test,
+    skipOnTravis,
+    skipIfNoIPv6,
+    refusing_port,
+)
 from tornado.web import RequestHandler, Application, url, stream_request_body
 
 
@@ -142,11 +148,31 @@ class RespondInPrepareHandler(RequestHandler):
         self.finish("forbidden")
 
 
-class SimpleHTTPClientTestMixin:
+@abstract_base_test
+class SimpleHTTPClientTestMixin(AsyncTestCase):
+    # See comments on TestIOStreamWebMixin
+    def get_http_port(self) -> int:
+        raise NotImplementedError()
+
+    def fetch(
+        self, path: str, raise_error: bool = False, **kwargs: typing.Any
+    ) -> HTTPResponse:
+        # To be filled in by mixing in AsyncHTTPTestCase or AsyncHTTPSTestCase
+        raise NotImplementedError()
+
+    def get_url(self, path: str) -> str:
+        raise NotImplementedError()
+
+    def get_protocol(self) -> str:
+        raise NotImplementedError()
+
+    def get_http_server(self) -> HTTPServer:
+        raise NotImplementedError()
+
     def create_client(self, **kwargs):
         raise NotImplementedError()
 
-    def get_app(self: typing.Any):
+    def mixin_get_app(self):
         # callable objects to finish pending /trigger requests
         self.triggers = (
             collections.deque()
@@ -177,7 +203,7 @@ class SimpleHTTPClientTestMixin:
             gzip=True,
         )
 
-    def test_singleton(self: typing.Any):
+    def test_singleton(self):
         # Class "constructor" reuses objects on the same IOLoop
         self.assertIs(SimpleAsyncHTTPClient(), SimpleAsyncHTTPClient())
         # unless force_instance is used
@@ -195,7 +221,7 @@ class SimpleHTTPClientTestMixin:
             client2 = io_loop2.run_sync(make_client)
             self.assertIsNot(client1, client2)
 
-    def test_connection_limit(self: typing.Any):
+    def test_connection_limit(self):
         with closing(self.create_client(max_clients=2)) as client:
             self.assertEqual(client.max_clients, 2)
             seen = []
@@ -226,13 +252,13 @@ class SimpleHTTPClientTestMixin:
             self.assertEqual(len(self.triggers), 0)
 
     @gen_test
-    def test_redirect_connection_limit(self: typing.Any):
+    def test_redirect_connection_limit(self):
         # following redirects should not consume additional connections
         with closing(self.create_client(max_clients=1)) as client:
             response = yield client.fetch(self.get_url("/countdown/3"), max_redirects=3)
             response.rethrow()
 
-    def test_max_redirects(self: typing.Any):
+    def test_max_redirects(self):
         response = self.fetch("/countdown/5", max_redirects=3)
         self.assertEqual(302, response.code)
         # We requested 5, followed three redirects for 4, 3, 2, then the last
@@ -241,19 +267,19 @@ class SimpleHTTPClientTestMixin:
         self.assertTrue(response.effective_url.endswith("/countdown/2"))
         self.assertTrue(response.headers["Location"].endswith("/countdown/1"))
 
-    def test_header_reuse(self: typing.Any):
+    def test_header_reuse(self):
         # Apps may reuse a headers object if they are only passing in constant
         # headers like user-agent.  The header object should not be modified.
         headers = HTTPHeaders({"User-Agent": "Foo"})
         self.fetch("/hello", headers=headers)
         self.assertEqual(list(headers.get_all()), [("User-Agent", "Foo")])
 
-    def test_default_user_agent(self: typing.Any):
+    def test_default_user_agent(self):
         response = self.fetch("/user_agent", method="GET")
         self.assertEqual(200, response.code)
         self.assertEqual(response.body.decode(), f"Tornado/{version}")
 
-    def test_see_other_redirect(self: typing.Any):
+    def test_see_other_redirect(self):
         for code in (302, 303):
             response = self.fetch("/see_other_post", method="POST", body="%d" % code)
             self.assertEqual(200, response.code)
@@ -264,7 +290,7 @@ class SimpleHTTPClientTestMixin:
 
     @skipOnTravis
     @gen_test
-    def test_connect_timeout(self: typing.Any):
+    def test_connect_timeout(self):
         timeout = 0.1
 
         cleanup_event = Event()
@@ -292,7 +318,7 @@ class SimpleHTTPClientTestMixin:
         yield gen.sleep(0.2)
 
     @skipOnTravis
-    def test_request_timeout(self: typing.Any):
+    def test_request_timeout(self):
         timeout = 0.1
         if os.name == "nt":
             timeout = 0.5
@@ -304,10 +330,10 @@ class SimpleHTTPClientTestMixin:
         self.io_loop.run_sync(lambda: gen.sleep(0))
 
     @skipIfNoIPv6
-    def test_ipv6(self: typing.Any):
+    def test_ipv6(self):
         [sock] = bind_sockets(0, "::1", family=socket.AF_INET6)
         port = sock.getsockname()[1]
-        self.http_server.add_socket(sock)
+        self.get_http_server().add_socket(sock)
         url = "%s://[::1]:%d/hello" % (self.get_protocol(), port)
 
         # ipv6 is currently enabled by default but can be disabled
@@ -317,7 +343,7 @@ class SimpleHTTPClientTestMixin:
         response = self.fetch(url)
         self.assertEqual(response.body, b"Hello world!")
 
-    def test_multiple_content_length_accepted(self: typing.Any):
+    def test_multiple_content_length_accepted(self):
         response = self.fetch("/content_length?value=2,2")
         self.assertEqual(response.body, b"ok")
         response = self.fetch("/content_length?value=2,%202,2")
@@ -331,20 +357,20 @@ class SimpleHTTPClientTestMixin:
             with self.assertRaises(HTTPStreamClosedError):
                 self.fetch("/content_length?value=2,%202,3", raise_error=True)
 
-    def test_head_request(self: typing.Any):
+    def test_head_request(self):
         response = self.fetch("/head", method="HEAD")
         self.assertEqual(response.code, 200)
         self.assertEqual(response.headers["content-length"], "7")
         self.assertFalse(response.body)
 
-    def test_options_request(self: typing.Any):
+    def test_options_request(self):
         response = self.fetch("/options", method="OPTIONS")
         self.assertEqual(response.code, 200)
         self.assertEqual(response.headers["content-length"], "2")
         self.assertEqual(response.headers["access-control-allow-origin"], "*")
         self.assertEqual(response.body, b"ok")
 
-    def test_no_content(self: typing.Any):
+    def test_no_content(self):
         response = self.fetch("/no_content")
         self.assertEqual(response.code, 204)
         # 204 status shouldn't have a content-length
@@ -353,7 +379,7 @@ class SimpleHTTPClientTestMixin:
         # in HTTP204NoContentTestCase.
         self.assertNotIn("Content-Length", response.headers)
 
-    def test_host_header(self: typing.Any):
+    def test_host_header(self):
         host_re = re.compile(b"^127.0.0.1:[0-9]+$")
         response = self.fetch("/host_echo")
         self.assertTrue(host_re.match(response.body))
@@ -362,7 +388,7 @@ class SimpleHTTPClientTestMixin:
         response = self.fetch(url)
         self.assertTrue(host_re.match(response.body), response.body)
 
-    def test_connection_refused(self: typing.Any):
+    def test_connection_refused(self):
         cleanup_func, port = refusing_port()
         self.addCleanup(cleanup_func)
         with ExpectLog(gen_log, ".*", required=False):
@@ -382,7 +408,7 @@ class SimpleHTTPClientTestMixin:
             expected_message = os.strerror(errno.ECONNREFUSED)
             self.assertTrue(expected_message in str(cm.exception), cm.exception)
 
-    def test_queue_timeout(self: typing.Any):
+    def test_queue_timeout(self):
         with closing(self.create_client(max_clients=1)) as client:
             # Wait for the trigger request to block, not complete.
             fut1 = client.fetch(self.get_url("/trigger"), request_timeout=10)
@@ -398,7 +424,7 @@ class SimpleHTTPClientTestMixin:
             self.triggers.popleft()()
             self.io_loop.run_sync(lambda: fut1)
 
-    def test_no_content_length(self: typing.Any):
+    def test_no_content_length(self):
         response = self.fetch("/no_content_length")
         if response.body == b"HTTP/1 required":
             self.skipTest("requires HTTP/1.x")
@@ -415,14 +441,14 @@ class SimpleHTTPClientTestMixin:
         yield gen.moment
         yield write(b"5678")
 
-    def test_sync_body_producer_chunked(self: typing.Any):
+    def test_sync_body_producer_chunked(self):
         response = self.fetch(
             "/echo_post", method="POST", body_producer=self.sync_body_producer
         )
         response.rethrow()
         self.assertEqual(response.body, b"12345678")
 
-    def test_sync_body_producer_content_length(self: typing.Any):
+    def test_sync_body_producer_content_length(self):
         response = self.fetch(
             "/echo_post",
             method="POST",
@@ -432,14 +458,14 @@ class SimpleHTTPClientTestMixin:
         response.rethrow()
         self.assertEqual(response.body, b"12345678")
 
-    def test_async_body_producer_chunked(self: typing.Any):
+    def test_async_body_producer_chunked(self):
         response = self.fetch(
             "/echo_post", method="POST", body_producer=self.async_body_producer
         )
         response.rethrow()
         self.assertEqual(response.body, b"12345678")
 
-    def test_async_body_producer_content_length(self: typing.Any):
+    def test_async_body_producer_content_length(self):
         response = self.fetch(
             "/echo_post",
             method="POST",
@@ -449,7 +475,7 @@ class SimpleHTTPClientTestMixin:
         response.rethrow()
         self.assertEqual(response.body, b"12345678")
 
-    def test_native_body_producer_chunked(self: typing.Any):
+    def test_native_body_producer_chunked(self):
         async def body_producer(write):
             await write(b"1234")
             import asyncio
@@ -461,7 +487,7 @@ class SimpleHTTPClientTestMixin:
         response.rethrow()
         self.assertEqual(response.body, b"12345678")
 
-    def test_native_body_producer_content_length(self: typing.Any):
+    def test_native_body_producer_content_length(self):
         async def body_producer(write):
             await write(b"1234")
             import asyncio
@@ -478,13 +504,13 @@ class SimpleHTTPClientTestMixin:
         response.rethrow()
         self.assertEqual(response.body, b"12345678")
 
-    def test_100_continue(self: typing.Any):
+    def test_100_continue(self):
         response = self.fetch(
             "/echo_post", method="POST", body=b"1234", expect_100_continue=True
         )
         self.assertEqual(response.body, b"1234")
 
-    def test_100_continue_early_response(self: typing.Any):
+    def test_100_continue_early_response(self):
         def body_producer(write):
             raise Exception("should not be called")
 
@@ -496,7 +522,7 @@ class SimpleHTTPClientTestMixin:
         )
         self.assertEqual(response.code, 403)
 
-    def test_streaming_follow_redirects(self: typing.Any):
+    def test_streaming_follow_redirects(self):
         # When following redirects, header and streaming callbacks
         # should only be called for the final result.
         # TODO(bdarnell): this test belongs in httpclient_test instead of
@@ -517,20 +543,26 @@ class SimpleHTTPClientTestMixin:
         self.assertEqual(num_start_lines, 1)
 
 
-class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase):
+class SimpleHTTPClientTestCase(AsyncHTTPTestCase, SimpleHTTPClientTestMixin):
     def setUp(self):
         super().setUp()
         self.http_client = self.create_client()
 
+    def get_app(self):
+        return self.mixin_get_app()
+
     def create_client(self, **kwargs):
         return SimpleAsyncHTTPClient(force_instance=True, **kwargs)
 
 
-class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase):
+class SimpleHTTPSClientTestCase(AsyncHTTPSTestCase, SimpleHTTPClientTestMixin):
     def setUp(self):
         super().setUp()
         self.http_client = self.create_client()
 
+    def get_app(self):
+        return self.mixin_get_app()
+
     def create_client(self, **kwargs):
         return SimpleAsyncHTTPClient(
             force_instance=True, defaults=dict(validate_cert=False), **kwargs