]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
websocket: Expand testing of next-ping calculation 3513/head
authorBen Darnell <ben@bendarnell.com>
Thu, 24 Jul 2025 20:37:48 +0000 (20:37 +0000)
committerBen Darnell <ben@bendarnell.com>
Thu, 24 Jul 2025 20:37:48 +0000 (20:37 +0000)
Includes end-to-end tests that the correct number of pings are sent
(piggybacking on an existing test) and a unit test for the
`ping_sleep_time` calculation.

tornado/test/websocket_test.py
tornado/websocket.py

index 1f317fb0d5ac6dfaaafbce447cab4fd10af0ec3b..494c4bf67ab5b01b2975a33ebea0b82e5d71db66 100644 (file)
@@ -1,5 +1,6 @@
 import asyncio
 import contextlib
+import datetime
 import functools
 import socket
 import traceback
@@ -861,16 +862,21 @@ class ServerPingTimeoutTest(WebSocketBaseTestCase):
         return app
 
     @staticmethod
-    def suppress_pong(ws):
-        """Suppress the client's "pong" response."""
+    def install_hook(ws):
+        """Optionally suppress the client's "pong" response."""
+
+        ws.drop_pongs = False
+        ws.pongs_received = 0
 
         def wrapper(fcn):
-            def _inner(oppcode: int, data: bytes):
-                if oppcode == 0xA:  # NOTE: 0x9=ping, 0xA=pong
-                    # prevent pong responses
-                    return
+            def _inner(opcode: int, data: bytes):
+                if opcode == 0xA:  # NOTE: 0x9=ping, 0xA=pong
+                    ws.pongs_received += 1
+                    if ws.drop_pongs:
+                        # prevent pong responses
+                        return
                 # leave all other responses unchanged
-                return fcn(oppcode, data)
+                return fcn(opcode, data)
 
             return _inner
 
@@ -883,13 +889,14 @@ class ServerPingTimeoutTest(WebSocketBaseTestCase):
         ws = yield self.ws_connect(
             "/", ping_interval=interval, ping_timeout=interval / 4
         )
+        self.install_hook(ws)
 
         # websocket handler (server side)
         handler = self.handlers[0]
 
         for _ in range(5):
             # wait for the ping period
-            yield gen.sleep(0.2)
+            yield gen.sleep(interval)
 
             # connection should still be open from the server end
             self.assertIsNone(handler.close_code)
@@ -898,8 +905,12 @@ class ServerPingTimeoutTest(WebSocketBaseTestCase):
             # connection should still be open from the client end
             assert ws.protocol.close_code is None
 
+        # Check that our hook is intercepting messages; allow for
+        # some variance in timing (due to e.g. cpu load)
+        self.assertGreaterEqual(ws.pongs_received, 4)
+
         # suppress the pong response message
-        self.suppress_pong(ws)
+        ws.drop_pongs = True
 
         # give the server time to register this
         yield gen.sleep(interval * 1.5)
@@ -912,6 +923,23 @@ class ServerPingTimeoutTest(WebSocketBaseTestCase):
         self.assertEqual(ws.protocol.close_code, 1000)
 
 
+class PingCalculationTest(unittest.TestCase):
+    def test_ping_sleep_time(self):
+        from tornado.websocket import WebSocketProtocol13
+
+        now = datetime.datetime(2025, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc)
+        interval = 10  # seconds
+        last_ping_time = datetime.datetime(
+            2025, 1, 1, 11, 59, 54, tzinfo=datetime.timezone.utc
+        )
+        sleep_time = WebSocketProtocol13.ping_sleep_time(
+            last_ping_time=last_ping_time.timestamp(),
+            interval=interval,
+            now=now.timestamp(),
+        )
+        self.assertEqual(sleep_time, 4)
+
+
 class ManualPingTest(WebSocketBaseTestCase):
     def get_app(self):
         class PingHandler(TestWebSocketHandler):
index 2e40636925ab15fa345a1035948b0c9b83513a50..c2e18d218d7fbe1ead6d3866ee2f382a6f23205b 100644 (file)
@@ -1346,6 +1346,11 @@ class WebSocketProtocol13(WebSocketProtocol):
         ):
             self._ping_coroutine = asyncio.create_task(self.periodic_ping())
 
+    @staticmethod
+    def ping_sleep_time(*, last_ping_time: float, interval: float, now: float) -> float:
+        """Calculate the sleep time until the next ping should be sent."""
+        return max(0, last_ping_time + interval - now)
+
     async def periodic_ping(self) -> None:
         """Send a ping and wait for a pong if ping_timeout is configured.
 
@@ -1371,7 +1376,13 @@ class WebSocketProtocol13(WebSocketProtocol):
                 return
 
             # wait until the next scheduled ping
-            await asyncio.sleep(ping_time + interval - IOLoop.current().time())
+            await asyncio.sleep(
+                self.ping_sleep_time(
+                    last_ping_time=ping_time,
+                    interval=interval,
+                    now=IOLoop.current().time(),
+                )
+            )
 
 
 class WebSocketClientConnection(simple_httpclient._HTTPConnection):