]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
websocket: Support periodic pings from client side 1957/head
authorBen Darnell <ben@bendarnell.com>
Tue, 21 Feb 2017 00:32:54 +0000 (19:32 -0500)
committerBen Darnell <ben@bendarnell.com>
Tue, 21 Feb 2017 00:32:54 +0000 (19:32 -0500)
tornado/test/websocket_test.py
tornado/websocket.py

index b32f7e06da60eb400c14852f8bb7981f7120a09c..91f6692a9b304d60fe769a50aa0206e2f67c562a 100644 (file)
@@ -94,10 +94,10 @@ class PathArgsHandler(TestWebSocketHandler):
 
 class WebSocketBaseTestCase(AsyncHTTPTestCase):
     @gen.coroutine
-    def ws_connect(self, path, compression_options=None):
+    def ws_connect(self, path, **kwargs):
         ws = yield websocket_connect(
             'ws://127.0.0.1:%d%s' % (self.get_http_port(), path),
-            compression_options=compression_options)
+            **kwargs)
         raise gen.Return(ws)
 
     @gen.coroutine
@@ -450,3 +450,24 @@ class ServerPeriodicPingTest(WebSocketBaseTestCase):
             self.assertEqual(response, "got pong")
         yield self.close(ws)
         # TODO: test that the connection gets closed if ping responses stop.
+
+
+class ClientPeriodicPingTest(WebSocketBaseTestCase):
+    def get_app(self):
+        class PingHandler(TestWebSocketHandler):
+            def on_ping(self, data):
+                self.write_message("got ping")
+
+        self.close_future = Future()
+        return Application([
+            ('/', PingHandler, dict(close_future=self.close_future)),
+        ])
+
+    @gen_test
+    def test_client_ping(self):
+        ws = yield self.ws_connect('/', ping_interval=0.01)
+        for i in range(3):
+            response = yield ws.read_message()
+            self.assertEqual(response, "got ping")
+        yield self.close(ws)
+        # TODO: test that the connection gets closed if ping responses stop.
index dcd7e73a382a1d501af6775ce1c1bba39e8011a1..e1a0421f79c84ef5687e9b718d013d2625404171 100644 (file)
@@ -197,7 +197,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
 
         Set ws_ping_interval = 0 to disable pings.
         """
-        return self.settings.get('websocket_ping_interval', 0)
+        return self.settings.get('websocket_ping_interval', None)
 
     @property
     def ping_timeout(self):
@@ -205,9 +205,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
         close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
         Default is max of 3 pings or 30 seconds.
         """
-        return self.settings.get('websocket_ping_timeout',
-            max(3 * self.ping_interval, 30)
-        )
+        return self.settings.get('websocket_ping_timeout', None)
 
     def write_message(self, message, binary=False):
         """Sends the given message to the client of this Web Socket.
@@ -914,12 +912,26 @@ class WebSocketProtocol13(WebSocketProtocol):
             self._waiting = self.stream.io_loop.add_timeout(
                 self.stream.io_loop.time() + 5, self._abort)
 
+    @property
+    def ping_interval(self):
+        interval = self.handler.ping_interval
+        if interval is not None:
+            return interval
+        return 0
+
+    @property
+    def ping_timeout(self):
+        timeout = self.handler.ping_timeout
+        if timeout is not None:
+            return timeout
+        return max(3 * self.ping_interval, 30)
+
     def start_pinging(self):
         """Start sending periodic pings to keep the connection alive"""
-        if self.handler.ping_interval > 0:
+        if self.ping_interval > 0:
             self.last_ping = self.last_pong = IOLoop.current().time()
             self.ping_callback = PeriodicCallback(
-                self.periodic_ping, self.handler.ping_interval*1000)
+                self.periodic_ping, self.ping_interval*1000)
             self.ping_callback.start()
 
     def periodic_ping(self):
@@ -937,8 +949,8 @@ class WebSocketProtocol13(WebSocketProtocol):
         now = IOLoop.current().time()
         since_last_pong = now - self.last_pong
         since_last_ping = now - self.last_ping
-        if (since_last_ping < 2*self.handler.ping_interval and
-                since_last_pong > self.handler.ping_timeout):
+        if (since_last_ping < 2*self.ping_interval and
+                since_last_pong > self.ping_timeout):
             self.close()
             return
 
@@ -953,7 +965,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
     `websocket_connect` function instead.
     """
     def __init__(self, io_loop, request, on_message_callback=None,
-                 compression_options=None):
+                 compression_options=None, ping_interval=None, ping_timeout=None):
         self.compression_options = compression_options
         self.connect_future = TracebackFuture()
         self.protocol = None
@@ -962,6 +974,8 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
         self.key = base64.b64encode(os.urandom(16))
         self._on_message_callback = on_message_callback
         self.close_code = self.close_reason = None
+        self.ping_interval = ping_interval
+        self.ping_timeout = ping_timeout
 
         scheme, sep, rest = request.url.partition(':')
         scheme = {'ws': 'http', 'wss': 'https'}[scheme]
@@ -1025,6 +1039,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
         self.headers = headers
         self.protocol = self.get_websocket_protocol()
         self.protocol._process_server_headers(self.key, self.headers)
+        self.protocol.start_pinging()
         self.protocol._receive_frame()
 
         if self._timeout is not None:
@@ -1087,7 +1102,8 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
 
 
 def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None,
-                      on_message_callback=None, compression_options=None):
+                      on_message_callback=None, compression_options=None,
+                      ping_interval=None, ping_timeout=None):
     """Client-side websocket support.
 
     Takes a url and returns a Future whose result is a
@@ -1131,7 +1147,9 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None,
         request, httpclient.HTTPRequest._DEFAULTS)
     conn = WebSocketClientConnection(io_loop, request,
                                      on_message_callback=on_message_callback,
-                                     compression_options=compression_options)
+                                     compression_options=compression_options,
+                                     ping_interval=ping_interval,
+                                     ping_timeout=ping_timeout)
     if callback is not None:
         io_loop.add_future(conn.connect_future, callback)
     return conn.connect_future