From: Ben Darnell Date: Tue, 21 Feb 2017 00:32:54 +0000 (-0500) Subject: websocket: Support periodic pings from client side X-Git-Tag: v4.5.0~30^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fpull%2F1957%2Fhead;p=thirdparty%2Ftornado.git websocket: Support periodic pings from client side --- diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index b32f7e06d..91f6692a9 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -94,10 +94,10 @@ class PathArgsHandler(TestWebSocketHandler): class WebSocketBaseTestCase(AsyncHTTPTestCase): @gen.coroutine - def ws_connect(self, path, compression_options=None): + def ws_connect(self, path, **kwargs): ws = yield websocket_connect( 'ws://127.0.0.1:%d%s' % (self.get_http_port(), path), - compression_options=compression_options) + **kwargs) raise gen.Return(ws) @gen.coroutine @@ -450,3 +450,24 @@ class ServerPeriodicPingTest(WebSocketBaseTestCase): self.assertEqual(response, "got pong") yield self.close(ws) # TODO: test that the connection gets closed if ping responses stop. + + +class ClientPeriodicPingTest(WebSocketBaseTestCase): + def get_app(self): + class PingHandler(TestWebSocketHandler): + def on_ping(self, data): + self.write_message("got ping") + + self.close_future = Future() + return Application([ + ('/', PingHandler, dict(close_future=self.close_future)), + ]) + + @gen_test + def test_client_ping(self): + ws = yield self.ws_connect('/', ping_interval=0.01) + for i in range(3): + response = yield ws.read_message() + self.assertEqual(response, "got ping") + yield self.close(ws) + # TODO: test that the connection gets closed if ping responses stop. diff --git a/tornado/websocket.py b/tornado/websocket.py index dcd7e73a3..e1a0421f7 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -197,7 +197,7 @@ class WebSocketHandler(tornado.web.RequestHandler): Set ws_ping_interval = 0 to disable pings. """ - return self.settings.get('websocket_ping_interval', 0) + return self.settings.get('websocket_ping_interval', None) @property def ping_timeout(self): @@ -205,9 +205,7 @@ class WebSocketHandler(tornado.web.RequestHandler): close the websocket connection (VPNs, etc. can fail to cleanly close ws connections). Default is max of 3 pings or 30 seconds. """ - return self.settings.get('websocket_ping_timeout', - max(3 * self.ping_interval, 30) - ) + return self.settings.get('websocket_ping_timeout', None) def write_message(self, message, binary=False): """Sends the given message to the client of this Web Socket. @@ -914,12 +912,26 @@ class WebSocketProtocol13(WebSocketProtocol): self._waiting = self.stream.io_loop.add_timeout( self.stream.io_loop.time() + 5, self._abort) + @property + def ping_interval(self): + interval = self.handler.ping_interval + if interval is not None: + return interval + return 0 + + @property + def ping_timeout(self): + timeout = self.handler.ping_timeout + if timeout is not None: + return timeout + return max(3 * self.ping_interval, 30) + def start_pinging(self): """Start sending periodic pings to keep the connection alive""" - if self.handler.ping_interval > 0: + if self.ping_interval > 0: self.last_ping = self.last_pong = IOLoop.current().time() self.ping_callback = PeriodicCallback( - self.periodic_ping, self.handler.ping_interval*1000) + self.periodic_ping, self.ping_interval*1000) self.ping_callback.start() def periodic_ping(self): @@ -937,8 +949,8 @@ class WebSocketProtocol13(WebSocketProtocol): now = IOLoop.current().time() since_last_pong = now - self.last_pong since_last_ping = now - self.last_ping - if (since_last_ping < 2*self.handler.ping_interval and - since_last_pong > self.handler.ping_timeout): + if (since_last_ping < 2*self.ping_interval and + since_last_pong > self.ping_timeout): self.close() return @@ -953,7 +965,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): `websocket_connect` function instead. """ def __init__(self, io_loop, request, on_message_callback=None, - compression_options=None): + compression_options=None, ping_interval=None, ping_timeout=None): self.compression_options = compression_options self.connect_future = TracebackFuture() self.protocol = None @@ -962,6 +974,8 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): self.key = base64.b64encode(os.urandom(16)) self._on_message_callback = on_message_callback self.close_code = self.close_reason = None + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout scheme, sep, rest = request.url.partition(':') scheme = {'ws': 'http', 'wss': 'https'}[scheme] @@ -1025,6 +1039,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): self.headers = headers self.protocol = self.get_websocket_protocol() self.protocol._process_server_headers(self.key, self.headers) + self.protocol.start_pinging() self.protocol._receive_frame() if self._timeout is not None: @@ -1087,7 +1102,8 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None, - on_message_callback=None, compression_options=None): + on_message_callback=None, compression_options=None, + ping_interval=None, ping_timeout=None): """Client-side websocket support. Takes a url and returns a Future whose result is a @@ -1131,7 +1147,9 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None, request, httpclient.HTTPRequest._DEFAULTS) conn = WebSocketClientConnection(io_loop, request, on_message_callback=on_message_callback, - compression_options=compression_options) + compression_options=compression_options, + ping_interval=ping_interval, + ping_timeout=ping_timeout) if callback is not None: io_loop.add_future(conn.connect_future, callback) return conn.connect_future