]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
websocket: Add warning if client connection isn't closed cleanly 3264/head
authorBen Darnell <ben@bendarnell.com>
Sun, 7 May 2023 21:03:33 +0000 (17:03 -0400)
committerBen Darnell <ben@bendarnell.com>
Mon, 8 May 2023 01:22:07 +0000 (21:22 -0400)
This gives a warning that is not dependent on GC for the issue
in #3257. This new warning covers all websocket client connections,
while the previous GC-dependent warning only affected those with
ping_interval set. This unfortunately introduces an effective
requirement to close all websocket clients explicitly for those
who are strict about warnings.

tornado/test/websocket_test.py
tornado/websocket.py

index 0a29ae64609e95bd6c6b57d248908ffb699cabc5..4d39f370467e8ced393a66e682d990314de74f02 100644 (file)
@@ -1,4 +1,5 @@
 import asyncio
+import contextlib
 import functools
 import socket
 import traceback
@@ -213,11 +214,21 @@ class NoDelayHandler(TestWebSocketHandler):
 
 
 class WebSocketBaseTestCase(AsyncHTTPTestCase):
+    def setUp(self):
+        super().setUp()
+        self.conns_to_close = []
+
+    def tearDown(self):
+        for conn in self.conns_to_close:
+            conn.close()
+        super().tearDown()
+
     @gen.coroutine
     def ws_connect(self, path, **kwargs):
         ws = yield websocket_connect(
             "ws://127.0.0.1:%d%s" % (self.get_http_port(), path), **kwargs
         )
+        self.conns_to_close.append(ws)
         raise gen.Return(ws)
 
 
@@ -397,39 +408,49 @@ class WebSocketTest(WebSocketBaseTestCase):
 
     @gen_test
     def test_websocket_close_buffered_data(self):
-        ws = yield websocket_connect("ws://127.0.0.1:%d/echo" % self.get_http_port())
-        ws.write_message("hello")
-        ws.write_message("world")
-        # Close the underlying stream.
-        ws.stream.close()
+        with contextlib.closing(
+            (yield websocket_connect("ws://127.0.0.1:%d/echo" % self.get_http_port()))
+        ) as ws:
+            ws.write_message("hello")
+            ws.write_message("world")
+            # Close the underlying stream.
+            ws.stream.close()
 
     @gen_test
     def test_websocket_headers(self):
         # Ensure that arbitrary headers can be passed through websocket_connect.
-        ws = yield websocket_connect(
-            HTTPRequest(
-                "ws://127.0.0.1:%d/header" % self.get_http_port(),
-                headers={"X-Test": "hello"},
+        with contextlib.closing(
+            (
+                yield websocket_connect(
+                    HTTPRequest(
+                        "ws://127.0.0.1:%d/header" % self.get_http_port(),
+                        headers={"X-Test": "hello"},
+                    )
+                )
             )
-        )
-        response = yield ws.read_message()
-        self.assertEqual(response, "hello")
+        ) as ws:
+            response = yield ws.read_message()
+            self.assertEqual(response, "hello")
 
     @gen_test
     def test_websocket_header_echo(self):
         # Ensure that headers can be returned in the response.
         # Specifically, that arbitrary headers passed through websocket_connect
         # can be returned.
-        ws = yield websocket_connect(
-            HTTPRequest(
-                "ws://127.0.0.1:%d/header_echo" % self.get_http_port(),
-                headers={"X-Test-Hello": "hello"},
+        with contextlib.closing(
+            (
+                yield websocket_connect(
+                    HTTPRequest(
+                        "ws://127.0.0.1:%d/header_echo" % self.get_http_port(),
+                        headers={"X-Test-Hello": "hello"},
+                    )
+                )
+            )
+        ) as ws:
+            self.assertEqual(ws.headers.get("X-Test-Hello"), "hello")
+            self.assertEqual(
+                ws.headers.get("X-Extra-Response-Header"), "Extra-Response-Value"
             )
-        )
-        self.assertEqual(ws.headers.get("X-Test-Hello"), "hello")
-        self.assertEqual(
-            ws.headers.get("X-Extra-Response-Header"), "Extra-Response-Value"
-        )
 
     @gen_test
     def test_server_close_reason(self):
@@ -495,10 +516,12 @@ class WebSocketTest(WebSocketBaseTestCase):
         url = "ws://127.0.0.1:%d/echo" % port
         headers = {"Origin": "http://127.0.0.1:%d" % port}
 
-        ws = yield websocket_connect(HTTPRequest(url, headers=headers))
-        ws.write_message("hello")
-        response = yield ws.read_message()
-        self.assertEqual(response, "hello")
+        with contextlib.closing(
+            (yield websocket_connect(HTTPRequest(url, headers=headers)))
+        ) as ws:
+            ws.write_message("hello")
+            response = yield ws.read_message()
+            self.assertEqual(response, "hello")
 
     @gen_test
     def test_check_origin_valid_with_path(self):
@@ -507,10 +530,12 @@ class WebSocketTest(WebSocketBaseTestCase):
         url = "ws://127.0.0.1:%d/echo" % port
         headers = {"Origin": "http://127.0.0.1:%d/something" % port}
 
-        ws = yield websocket_connect(HTTPRequest(url, headers=headers))
-        ws.write_message("hello")
-        response = yield ws.read_message()
-        self.assertEqual(response, "hello")
+        with contextlib.closing(
+            (yield websocket_connect(HTTPRequest(url, headers=headers)))
+        ) as ws:
+            ws.write_message("hello")
+            response = yield ws.read_message()
+            self.assertEqual(response, "hello")
 
     @gen_test
     def test_check_origin_invalid_partial_url(self):
index d0abd42595be07ab7ead5e9350857017e1fd42cf..165cc316d601625ab657c6be7bb6c8ff94d507ce 100644 (file)
@@ -20,6 +20,7 @@ import sys
 import struct
 import tornado
 from urllib.parse import urlparse
+import warnings
 import zlib
 
 from tornado.concurrent import Future, future_set_result_unless_cancelled
@@ -1410,6 +1411,15 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
             104857600,
         )
 
+    def __del__(self) -> None:
+        if self.protocol is not None:
+            # Unclosed client connections can sometimes log "task was destroyed but
+            # was pending" warnings if shutdown strikes at the wrong time (such as
+            # while a ping is being processed due to ping_interval). Log our own
+            # warning to make it a little more deterministic (although it's still
+            # dependent on GC timing).
+            warnings.warn("Unclosed WebSocketClientConnection", ResourceWarning)
+
     def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None:
         """Closes the websocket connection.