]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
websockets: fix ping_timeout (#3376)
authorOliver Sanders <oliver.sanders@metoffice.gov.uk>
Tue, 22 Apr 2025 17:19:00 +0000 (18:19 +0100)
committerGitHub <noreply@github.com>
Tue, 22 Apr 2025 17:19:00 +0000 (13:19 -0400)
* websockets: fix ping_timeout

* Closes #3258
* Closes #2905
* Closes #2655
* Fixes an issue with the calculation of ping timeout interval that
  could cause connections to be erroneously timed out and closed
  from the server end.

* websocket: Fix lint, remove hard-coded 30s default timeout

* websocket_test: Improve assertion error messages

* websocket_test: Allow a little slack in ping timing

Appears to be necessary on windows.

---------

Co-authored-by: Ben Darnell <ben@bendarnell.com>
tornado/test/websocket_test.py
tornado/websocket.py

index 6d95084bd9908046f8b17933c2b36296bca59ed6..097dca85a14c60fdcf863b2592e33aecc664a706 100644 (file)
@@ -810,7 +810,11 @@ class ServerPeriodicPingTest(WebSocketBaseTestCase):
             def on_pong(self, data):
                 self.write_message("got pong")
 
-        return Application([("/", PingHandler)], websocket_ping_interval=0.01)
+        return Application(
+            [("/", PingHandler)],
+            websocket_ping_interval=0.01,
+            websocket_ping_timeout=0,
+        )
 
     @gen_test
     def test_server_ping(self):
@@ -831,14 +835,82 @@ class ClientPeriodicPingTest(WebSocketBaseTestCase):
 
     @gen_test
     def test_client_ping(self):
-        ws = yield self.ws_connect("/", ping_interval=0.01)
+        ws = yield self.ws_connect("/", ping_interval=0.01, ping_timeout=0)
         for i in range(3):
             response = yield ws.read_message()
             self.assertEqual(response, "got ping")
-        # TODO: test that the connection gets closed if ping responses stop.
         ws.close()
 
 
+class ServerPingTimeoutTest(WebSocketBaseTestCase):
+    def get_app(self):
+        self.handlers: list[WebSocketHandler] = []
+        test = self
+
+        class PingHandler(TestWebSocketHandler):
+            def initialize(self, close_future=None, compression_options=None):
+                self.handlers = test.handlers
+                # capture the handler instance so we can interrogate it later
+                self.handlers.append(self)
+                return super().initialize(
+                    close_future=close_future, compression_options=compression_options
+                )
+
+        app = Application([("/", PingHandler)])
+        return app
+
+    @staticmethod
+    def suppress_pong(ws):
+        """Suppress the client's "pong" response."""
+
+        def wrapper(fcn):
+            def _inner(oppcode: int, data: bytes):
+                if oppcode == 0xA:  # NOTE: 0x9=ping, 0xA=pong
+                    # prevent pong responses
+                    return
+                # leave all other responses unchanged
+                return fcn(oppcode, data)
+
+            return _inner
+
+        ws.protocol._handle_message = wrapper(ws.protocol._handle_message)
+
+    @gen_test
+    def test_client_ping_timeout(self):
+        # websocket client
+        interval = 0.2
+        ws = yield self.ws_connect(
+            "/", ping_interval=interval, ping_timeout=interval / 4
+        )
+
+        # websocket handler (server side)
+        handler = self.handlers[0]
+
+        for _ in range(5):
+            # wait for the ping period
+            yield gen.sleep(0.2)
+
+            # connection should still be open from the server end
+            self.assertIsNone(handler.close_code)
+            self.assertIsNone(handler.close_reason)
+
+            # connection should still be open from the client end
+            assert ws.protocol.close_code is None
+
+        # suppress the pong response message
+        self.suppress_pong(ws)
+
+        # give the server time to register this
+        yield gen.sleep(interval * 1.5)
+
+        # connection should be closed from the server side
+        self.assertEqual(handler.close_code, 1000)
+        self.assertEqual(handler.close_reason, "ping timed out")
+
+        # client should have received a close operation
+        self.assertEqual(ws.protocol.close_code, 1000)
+
+
 class ManualPingTest(WebSocketBaseTestCase):
     def get_app(self):
         class PingHandler(TestWebSocketHandler):
index 1e0161e1b8be5526fb8202a8202cb7b470863ba2..4fbb2da12ab30e46c29e5308f46a45433be0000b 100644 (file)
@@ -14,7 +14,9 @@ defined in `RFC 6455 <http://tools.ietf.org/html/rfc6455>`_.
 import abc
 import asyncio
 import base64
+import functools
 import hashlib
+import logging
 import os
 import sys
 import struct
@@ -26,7 +28,7 @@ import zlib
 from tornado.concurrent import Future, future_set_result_unless_cancelled
 from tornado.escape import utf8, native_str, to_unicode
 from tornado import gen, httpclient, httputil
-from tornado.ioloop import IOLoop, PeriodicCallback
+from tornado.ioloop import IOLoop
 from tornado.iostream import StreamClosedError, IOStream
 from tornado.log import gen_log, app_log
 from tornado.netutil import Resolver
@@ -97,6 +99,9 @@ if TYPE_CHECKING:
 
 _default_max_message_size = 10 * 1024 * 1024
 
+# log to "gen_log" but suppress duplicate log messages
+de_dupe_gen_log = functools.lru_cache(gen_log.log)
+
 
 class WebSocketError(Exception):
     pass
@@ -274,17 +279,41 @@ class WebSocketHandler(tornado.web.RequestHandler):
 
     @property
     def ping_interval(self) -> Optional[float]:
-        """The interval for websocket keep-alive pings.
+        """The interval for sending websocket pings.
+
+        If this is non-zero, the websocket will send a ping every
+        ping_interval seconds.
+        The client will respond with a "pong". The connection can be configured
+        to timeout on late pong delivery using ``websocket_ping_timeout``.
 
-        Set websocket_ping_interval = 0 to disable pings.
+        Set ``websocket_ping_interval = 0`` to disable pings.
+
+        Default: ``0``
         """
         return self.settings.get("websocket_ping_interval", None)
 
     @property
     def ping_timeout(self) -> Optional[float]:
-        """If no ping is received in this many seconds,
-        close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
-        Default is max of 3 pings or 30 seconds.
+        """Timeout if no pong is received in this many seconds.
+
+        To be used in combination with ``websocket_ping_interval > 0``.
+        If a ping response (a "pong") is not received within
+        ``websocket_ping_timeout`` seconds, then the websocket connection
+        will be closed.
+
+        This can help to clean up clients which have disconnected without
+        cleanly closing the websocket connection.
+
+        Note, the ping timeout cannot be longer than the ping interval.
+
+        Set ``websocket_ping_timeout = 0`` to disable the ping timeout.
+
+        Default: ``min(ping_interval, 30)``
+
+        .. versionchanged:: 6.5.0
+           Default changed from the max of 3 pings or 30 seconds.
+           The ping timeout can no longer be configured longer than the
+           ping interval.
         """
         return self.settings.get("websocket_ping_timeout", None)
 
@@ -831,11 +860,10 @@ class WebSocketProtocol13(WebSocketProtocol):
         # the effect of compression, frame overhead, and control frames.
         self._wire_bytes_in = 0
         self._wire_bytes_out = 0
-        self.ping_callback = None  # type: Optional[PeriodicCallback]
-        self.last_ping = 0.0
-        self.last_pong = 0.0
+        self._received_pong = False  # type: bool
         self.close_code = None  # type: Optional[int]
         self.close_reason = None  # type: Optional[str]
+        self._ping_coroutine = None  # type: Optional[asyncio.Task]
 
     # Use a property for this to satisfy the abc.
     @property
@@ -1232,7 +1260,7 @@ class WebSocketProtocol13(WebSocketProtocol):
             self._run_callback(self.handler.on_ping, data)
         elif opcode == 0xA:
             # Pong
-            self.last_pong = IOLoop.current().time()
+            self._received_pong = True
             return self._run_callback(self.handler.on_pong, data)
         else:
             self._abort()
@@ -1266,9 +1294,9 @@ class WebSocketProtocol13(WebSocketProtocol):
             self._waiting = self.stream.io_loop.add_timeout(
                 self.stream.io_loop.time() + 5, self._abort
             )
-        if self.ping_callback:
-            self.ping_callback.stop()
-            self.ping_callback = None
+        if self._ping_coroutine:
+            self._ping_coroutine.cancel()
+            self._ping_coroutine = None
 
     def is_closing(self) -> bool:
         """Return ``True`` if this connection is closing.
@@ -1279,60 +1307,69 @@ class WebSocketProtocol13(WebSocketProtocol):
         """
         return self.stream.closed() or self.client_terminated or self.server_terminated
 
+    def set_nodelay(self, x: bool) -> None:
+        self.stream.set_nodelay(x)
+
     @property
-    def ping_interval(self) -> Optional[float]:
+    def ping_interval(self) -> float:
         interval = self.params.ping_interval
         if interval is not None:
             return interval
         return 0
 
     @property
-    def ping_timeout(self) -> Optional[float]:
+    def ping_timeout(self) -> float:
         timeout = self.params.ping_timeout
         if timeout is not None:
+            if self.ping_interval and timeout > self.ping_interval:
+                de_dupe_gen_log(
+                    # Note: using de_dupe_gen_log to prevent this message from
+                    # being duplicated for each connection
+                    logging.WARNING,
+                    f"The websocket_ping_timeout ({timeout}) cannot be longer"
+                    f" than the websocket_ping_interval ({self.ping_interval})."
+                    f"\nSetting websocket_ping_timeout={self.ping_interval}",
+                )
+                return self.ping_interval
             return timeout
-        assert self.ping_interval is not None
-        return max(3 * self.ping_interval, 30)
+        return self.ping_interval
 
     def start_pinging(self) -> None:
         """Start sending periodic pings to keep the connection alive"""
-        assert self.ping_interval is not None
-        if self.ping_interval > 0:
-            self.last_ping = self.last_pong = IOLoop.current().time()
-            self.ping_callback = PeriodicCallback(
-                self.periodic_ping, self.ping_interval * 1000
-            )
-            self.ping_callback.start()
+        if (
+            # prevent multiple ping coroutines being run in parallel
+            not self._ping_coroutine
+            # only run the ping coroutine if a ping interval is configured
+            and self.ping_interval > 0
+        ):
+            self._ping_coroutine = asyncio.create_task(self.periodic_ping())
 
-    def periodic_ping(self) -> None:
-        """Send a ping to keep the websocket alive
+    async def periodic_ping(self) -> None:
+        """Send a ping and wait for a pong if ping_timeout is configured.
 
         Called periodically if the websocket_ping_interval is set and non-zero.
         """
-        if self.is_closing() and self.ping_callback is not None:
-            self.ping_callback.stop()
-            return
+        interval = self.ping_interval
+        timeout = self.ping_timeout
 
-        # Check for timeout on pong. Make sure that we really have
-        # sent a recent ping in case the machine with both server and
-        # client has been suspended since the last ping.
-        now = IOLoop.current().time()
-        since_last_pong = now - self.last_pong
-        since_last_ping = now - self.last_ping
-        assert self.ping_interval is not None
-        assert self.ping_timeout is not None
-        if (
-            since_last_ping < 2 * self.ping_interval
-            and since_last_pong > self.ping_timeout
-        ):
-            self.close()
-            return
+        await asyncio.sleep(interval)
 
-        self.write_ping(b"")
-        self.last_ping = now
+        while True:
+            # send a ping
+            self._received_pong = False
+            ping_time = IOLoop.current().time()
+            self.write_ping(b"")
 
-    def set_nodelay(self, x: bool) -> None:
-        self.stream.set_nodelay(x)
+            # wait until the ping timeout
+            await asyncio.sleep(timeout)
+
+            # make sure we received a pong within the timeout
+            if timeout > 0 and not self._received_pong:
+                self.close(reason="ping timed out")
+                return
+
+            # wait until the next scheduled ping
+            await asyncio.sleep(IOLoop.current().time() - ping_time + interval)
 
 
 class WebSocketClientConnection(simple_httpclient._HTTPConnection):