]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
websocket_test: Remove most manual closes
authorBen Darnell <ben@bendarnell.com>
Mon, 10 Dec 2018 02:14:50 +0000 (21:14 -0500)
committerBen Darnell <ben@bendarnell.com>
Mon, 10 Dec 2018 02:14:50 +0000 (21:14 -0500)
At one time this was necessary to prevent spurious warnings at
shutdown, but not any more (and I intend to address warnings like this
with a more general solution).

tornado/test/websocket_test.py

index 704828e3ed28245d031ae69042dcb384e70314e2..d63a665af8105c20346d3b8e83e39b3531f39757 100644 (file)
@@ -39,10 +39,12 @@ except ImportError:
 class TestWebSocketHandler(WebSocketHandler):
     """Base class for testing handlers that exposes the on_close event.
 
-    This allows for deterministic cleanup of the associated socket.
+    This allows for tests to see the close code and reason on the
+    server side.
+
     """
 
-    def initialize(self, close_future, compression_options=None):
+    def initialize(self, close_future=None, compression_options=None):
         self.close_future = close_future
         self.compression_options = compression_options
 
@@ -50,7 +52,8 @@ class TestWebSocketHandler(WebSocketHandler):
         return self.compression_options
 
     def on_close(self):
-        self.close_future.set_result((self.close_code, self.close_reason))
+        if self.close_future is not None:
+            self.close_future.set_result((self.close_code, self.close_reason))
 
 
 class EchoHandler(TestWebSocketHandler):
@@ -125,10 +128,8 @@ class PathArgsHandler(TestWebSocketHandler):
 
 
 class CoroutineOnMessageHandler(TestWebSocketHandler):
-    def initialize(self, close_future, compression_options=None):
-        super(CoroutineOnMessageHandler, self).initialize(
-            close_future, compression_options
-        )
+    def initialize(self, **kwargs):
+        super(CoroutineOnMessageHandler, self).initialize(**kwargs)
         self.sleeping = 0
 
     @gen.coroutine
@@ -191,16 +192,6 @@ class WebSocketBaseTestCase(AsyncHTTPTestCase):
         )
         raise gen.Return(ws)
 
-    @gen.coroutine
-    def close(self, ws):
-        """Close a websocket connection and wait for the server side.
-
-        If we don't wait here, there are sometimes leak warnings in the
-        tests.
-        """
-        ws.close()
-        yield self.close_future
-
 
 class WebSocketTest(WebSocketBaseTestCase):
     def get_app(self):
@@ -296,7 +287,6 @@ class WebSocketTest(WebSocketBaseTestCase):
         yield ws.write_message("hello")
         response = yield ws.read_message()
         self.assertEqual(response, "hello")
-        yield self.close(ws)
 
     def test_websocket_callbacks(self):
         websocket_connect(
@@ -317,7 +307,6 @@ class WebSocketTest(WebSocketBaseTestCase):
         ws.write_message(b"hello \xe9", binary=True)
         response = yield ws.read_message()
         self.assertEqual(response, b"hello \xe9")
-        yield self.close(ws)
 
     @gen_test
     def test_unicode_message(self):
@@ -325,7 +314,6 @@ class WebSocketTest(WebSocketBaseTestCase):
         ws.write_message(u"hello \u00e9")
         response = yield ws.read_message()
         self.assertEqual(response, u"hello \u00e9")
-        yield self.close(ws)
 
     @gen_test
     def test_render_message(self):
@@ -333,7 +321,6 @@ class WebSocketTest(WebSocketBaseTestCase):
         ws.write_message("hello")
         response = yield ws.read_message()
         self.assertEqual(response, "<b>hello</b>")
-        yield self.close(ws)
 
     @gen_test
     def test_error_in_on_message(self):
@@ -342,7 +329,6 @@ class WebSocketTest(WebSocketBaseTestCase):
         with ExpectLog(app_log, "Uncaught exception"):
             response = yield ws.read_message()
         self.assertIs(response, None)
-        yield self.close(ws)
 
     @gen_test
     def test_websocket_http_fail(self):
@@ -372,7 +358,6 @@ class WebSocketTest(WebSocketBaseTestCase):
         ws.write_message("world")
         # Close the underlying stream.
         ws.stream.close()
-        yield self.close_future
 
     @gen_test
     def test_websocket_headers(self):
@@ -385,7 +370,6 @@ class WebSocketTest(WebSocketBaseTestCase):
         )
         response = yield ws.read_message()
         self.assertEqual(response, "hello")
-        yield self.close(ws)
 
     @gen_test
     def test_websocket_header_echo(self):
@@ -402,7 +386,6 @@ class WebSocketTest(WebSocketBaseTestCase):
         self.assertEqual(
             ws.headers.get("X-Extra-Response-Header"), "Extra-Response-Value"
         )
-        yield self.close(ws)
 
     @gen_test
     def test_server_close_reason(self):
@@ -472,7 +455,6 @@ class WebSocketTest(WebSocketBaseTestCase):
         ws.write_message("hello")
         response = yield ws.read_message()
         self.assertEqual(response, "hello")
-        yield self.close(ws)
 
     @gen_test
     def test_check_origin_valid_with_path(self):
@@ -485,7 +467,6 @@ class WebSocketTest(WebSocketBaseTestCase):
         ws.write_message("hello")
         response = yield ws.read_message()
         self.assertEqual(response, "hello")
-        yield self.close(ws)
 
     @gen_test
     def test_check_origin_invalid_partial_url(self):
@@ -534,7 +515,6 @@ class WebSocketTest(WebSocketBaseTestCase):
         self.assertEqual(ws.selected_subprotocol, "goodproto")
         res = yield ws.read_message()
         self.assertEqual(res, "subprotocol=goodproto")
-        yield self.close(ws)
 
     @gen_test
     def test_subprotocols_not_offered(self):
@@ -542,7 +522,6 @@ class WebSocketTest(WebSocketBaseTestCase):
         self.assertIs(ws.selected_subprotocol, None)
         res = yield ws.read_message()
         self.assertEqual(res, "subprotocol=None")
-        yield self.close(ws)
 
     @gen_test
     def test_open_coroutine(self):
@@ -552,12 +531,11 @@ class WebSocketTest(WebSocketBaseTestCase):
         self.message_sent.set()
         res = yield ws.read_message()
         self.assertEqual(res, "ok")
-        yield self.close(ws)
 
 
 class NativeCoroutineOnMessageHandler(TestWebSocketHandler):
-    def initialize(self, close_future, compression_options=None):
-        super().initialize(close_future, compression_options)
+    def initialize(self, **kwargs):
+        super().initialize(**kwargs)
         self.sleeping = 0
 
     async def on_message(self, message):
@@ -571,16 +549,7 @@ class NativeCoroutineOnMessageHandler(TestWebSocketHandler):
 
 class WebSocketNativeCoroutineTest(WebSocketBaseTestCase):
     def get_app(self):
-        self.close_future = Future()  # type: Future[None]
-        return Application(
-            [
-                (
-                    "/native",
-                    NativeCoroutineOnMessageHandler,
-                    dict(close_future=self.close_future),
-                )
-            ]
-        )
+        return Application([("/native", NativeCoroutineOnMessageHandler)])
 
     @gen_test
     def test_native_coroutine(self):
@@ -598,8 +567,6 @@ class CompressionTestMixin(object):
     MESSAGE = "Hello world. Testing 123 123"
 
     def get_app(self):
-        self.close_future = Future()  # type: Future[None]
-
         class LimitedHandler(TestWebSocketHandler):
             @property
             def max_message_size(self):
@@ -613,18 +580,12 @@ class CompressionTestMixin(object):
                 (
                     "/echo",
                     EchoHandler,
-                    dict(
-                        close_future=self.close_future,
-                        compression_options=self.get_server_compression_options(),
-                    ),
+                    dict(compression_options=self.get_server_compression_options()),
                 ),
                 (
                     "/limited",
                     LimitedHandler,
-                    dict(
-                        close_future=self.close_future,
-                        compression_options=self.get_server_compression_options(),
-                    ),
+                    dict(compression_options=self.get_server_compression_options()),
                 ),
             ]
         )
@@ -649,7 +610,6 @@ class CompressionTestMixin(object):
         self.assertEqual(ws.protocol._message_bytes_out, len(self.MESSAGE) * 3)
         self.assertEqual(ws.protocol._message_bytes_in, len(self.MESSAGE) * 3)
         self.verify_wire_bytes(ws.protocol._wire_bytes_in, ws.protocol._wire_bytes_out)
-        yield self.close(ws)
 
     @gen_test
     def test_size_limit(self):
@@ -665,7 +625,6 @@ class CompressionTestMixin(object):
         ws.write_message("a" * 2048)
         response = yield ws.read_message()
         self.assertIsNone(response)
-        yield self.close(ws)
 
 
 class UncompressedTestMixin(CompressionTestMixin):
@@ -743,11 +702,7 @@ class ServerPeriodicPingTest(WebSocketBaseTestCase):
             def on_pong(self, data):
                 self.write_message("got pong")
 
-        self.close_future = Future()  # type: Future[None]
-        return Application(
-            [("/", PingHandler, dict(close_future=self.close_future))],
-            websocket_ping_interval=0.01,
-        )
+        return Application([("/", PingHandler)], websocket_ping_interval=0.01)
 
     @gen_test
     def test_server_ping(self):
@@ -755,7 +710,6 @@ class ServerPeriodicPingTest(WebSocketBaseTestCase):
         for i in range(3):
             response = yield ws.read_message()
             self.assertEqual(response, "got pong")
-        yield self.close(ws)
         # TODO: test that the connection gets closed if ping responses stop.
 
 
@@ -765,8 +719,7 @@ class ClientPeriodicPingTest(WebSocketBaseTestCase):
             def on_ping(self, data):
                 self.write_message("got ping")
 
-        self.close_future = Future()  # type: Future[None]
-        return Application([("/", PingHandler, dict(close_future=self.close_future))])
+        return Application([("/", PingHandler)])
 
     @gen_test
     def test_client_ping(self):
@@ -774,7 +727,6 @@ class ClientPeriodicPingTest(WebSocketBaseTestCase):
         for i in range(3):
             response = yield ws.read_message()
             self.assertEqual(response, "got ping")
-        yield self.close(ws)
         # TODO: test that the connection gets closed if ping responses stop.
 
 
@@ -784,8 +736,7 @@ class ManualPingTest(WebSocketBaseTestCase):
             def on_ping(self, data):
                 self.write_message(data, binary=isinstance(data, bytes))
 
-        self.close_future = Future()  # type: Future[None]
-        return Application([("/", PingHandler, dict(close_future=self.close_future))])
+        return Application([("/", PingHandler)])
 
     @gen_test
     def test_manual_ping(self):
@@ -801,16 +752,11 @@ class ManualPingTest(WebSocketBaseTestCase):
         ws.ping(b"binary hello")
         resp = yield ws.read_message()
         self.assertEqual(resp, b"binary hello")
-        yield self.close(ws)
 
 
 class MaxMessageSizeTest(WebSocketBaseTestCase):
     def get_app(self):
-        self.close_future = Future()  # type: Future[None]
-        return Application(
-            [("/", EchoHandler, dict(close_future=self.close_future))],
-            websocket_max_message_size=1024,
-        )
+        return Application([("/", EchoHandler)], websocket_max_message_size=1024)
 
     @gen_test
     def test_large_message(self):