From: Ben Darnell Date: Sat, 20 May 2017 16:09:58 +0000 (-0400) Subject: websocket: Don't swallow exceptions in _write_frame X-Git-Tag: v5.0.0~86^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=73c48f328c247a40c444a9e0fdee65157231fff7;p=thirdparty%2Ftornado.git websocket: Don't swallow exceptions in _write_frame Swallowing the exception violated the method's interface (by returning None instead of a Future), and differs from stream-closed behavior in other contexts in Tornado. Fixes #1980 --- diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index d47a74e65..e0b5573de 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -7,6 +7,7 @@ import traceback from tornado.concurrent import Future from tornado import gen from tornado.httpclient import HTTPError, HTTPRequest +from tornado.iostream import StreamClosedError from tornado.log import gen_log, app_log from tornado.template import DictLoader from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog @@ -50,7 +51,10 @@ class TestWebSocketHandler(WebSocketHandler): class EchoHandler(TestWebSocketHandler): def on_message(self, message): - self.write_message(message, isinstance(message, bytes)) + try: + self.write_message(message, isinstance(message, bytes)) + except StreamClosedError: + pass class ErrorInOnMessageHandler(TestWebSocketHandler): @@ -327,6 +331,14 @@ class WebSocketTest(WebSocketBaseTestCase): self.assertEqual(code, 1001) self.assertEqual(reason, 'goodbye') + @gen_test + def test_write_after_close(self): + ws = yield self.ws_connect('/close_reason') + msg = yield ws.read_message() + self.assertIs(msg, None) + with self.assertRaises(StreamClosedError): + ws.write_message('hello') + @gen_test def test_async_prepare(self): # Previously, an async prepare method triggered a bug that would diff --git a/tornado/websocket.py b/tornado/websocket.py index 69437ee4e..7600910c0 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -764,10 +764,7 @@ class WebSocketProtocol13(WebSocketProtocol): data = mask + _websocket_mask(mask, data) frame += data self._wire_bytes_out += len(frame) - try: - return self.stream.write(frame) - except StreamClosedError: - self._abort() + return self.stream.write(frame) def write_message(self, message, binary=False): """Sends the given message to the client of this Web Socket.""" @@ -951,7 +948,10 @@ class WebSocketProtocol13(WebSocketProtocol): self.close(self.handler.close_code) elif opcode == 0x9: # Ping - self._write_frame(True, 0xA, data) + try: + self._write_frame(True, 0xA, data) + except StreamClosedError: + self._abort() self._run_callback(self.handler.on_ping, data) elif opcode == 0xA: # Pong @@ -972,7 +972,10 @@ class WebSocketProtocol13(WebSocketProtocol): close_data = struct.pack('>H', code) if reason is not None: close_data += utf8(reason) - self._write_frame(True, 0x8, close_data) + try: + self._write_frame(True, 0x8, close_data) + except StreamClosedError: + self._abort() self.server_terminated = True if self.client_terminated: if self._waiting is not None: