]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
websocket: Don't swallow exceptions in _write_frame 2045/head
authorBen Darnell <ben@bendarnell.com>
Sat, 20 May 2017 16:09:58 +0000 (12:09 -0400)
committerBen Darnell <ben@bendarnell.com>
Sat, 20 May 2017 17:26:23 +0000 (13:26 -0400)
Swallowing the exception violated the method's interface (by returning
None instead of a Future), and differs from stream-closed behavior in
other contexts in Tornado.

Fixes #1980

tornado/test/websocket_test.py
tornado/websocket.py

index d47a74e651e86ab8ecd8ab76d6936c23edd06e05..e0b5573de311d757c6a2c39d3ed0d0fe8a89c726 100644 (file)
@@ -7,6 +7,7 @@ import traceback
 from tornado.concurrent import Future
 from tornado import gen
 from tornado.httpclient import HTTPError, HTTPRequest
+from tornado.iostream import StreamClosedError
 from tornado.log import gen_log, app_log
 from tornado.template import DictLoader
 from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
@@ -50,7 +51,10 @@ class TestWebSocketHandler(WebSocketHandler):
 
 class EchoHandler(TestWebSocketHandler):
     def on_message(self, message):
-        self.write_message(message, isinstance(message, bytes))
+        try:
+            self.write_message(message, isinstance(message, bytes))
+        except StreamClosedError:
+            pass
 
 
 class ErrorInOnMessageHandler(TestWebSocketHandler):
@@ -327,6 +331,14 @@ class WebSocketTest(WebSocketBaseTestCase):
         self.assertEqual(code, 1001)
         self.assertEqual(reason, 'goodbye')
 
+    @gen_test
+    def test_write_after_close(self):
+        ws = yield self.ws_connect('/close_reason')
+        msg = yield ws.read_message()
+        self.assertIs(msg, None)
+        with self.assertRaises(StreamClosedError):
+            ws.write_message('hello')
+
     @gen_test
     def test_async_prepare(self):
         # Previously, an async prepare method triggered a bug that would
index 69437ee4e3d9cc335b0eb1b785a3565fff14ab70..7600910c052f07b536137b13f71d244f3b9d58c1 100644 (file)
@@ -764,10 +764,7 @@ class WebSocketProtocol13(WebSocketProtocol):
             data = mask + _websocket_mask(mask, data)
         frame += data
         self._wire_bytes_out += len(frame)
-        try:
-            return self.stream.write(frame)
-        except StreamClosedError:
-            self._abort()
+        return self.stream.write(frame)
 
     def write_message(self, message, binary=False):
         """Sends the given message to the client of this Web Socket."""
@@ -951,7 +948,10 @@ class WebSocketProtocol13(WebSocketProtocol):
             self.close(self.handler.close_code)
         elif opcode == 0x9:
             # Ping
-            self._write_frame(True, 0xA, data)
+            try:
+                self._write_frame(True, 0xA, data)
+            except StreamClosedError:
+                self._abort()
             self._run_callback(self.handler.on_ping, data)
         elif opcode == 0xA:
             # Pong
@@ -972,7 +972,10 @@ class WebSocketProtocol13(WebSocketProtocol):
                     close_data = struct.pack('>H', code)
                 if reason is not None:
                     close_data += utf8(reason)
-                self._write_frame(True, 0x8, close_data)
+                try:
+                    self._write_frame(True, 0x8, close_data)
+                except StreamClosedError:
+                    self._abort()
             self.server_terminated = True
         if self.client_terminated:
             if self._waiting is not None: