]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Catch StreamClosedErrors in WebSocketHandler and abort.
authorBen Darnell <ben@bendarnell.com>
Thu, 6 Jun 2013 01:43:16 +0000 (21:43 -0400)
committerBen Darnell <ben@bendarnell.com>
Thu, 6 Jun 2013 01:43:16 +0000 (21:43 -0400)
When the stream is closed with buffered data, the close callback won't
be run until all buffered data is consumed, but any attempt to write
to the stream will fail, as will reading past the end of the buffer.
This requires a try/except around each read or write, analogous to the
one introduced in HTTPServer in commit 3258726f.

Closes #604.
Closes #661.

tornado/test/websocket_test.py
tornado/websocket.py

index f416fea4b87e8a37d193b392e089766a75ec85b5..0c5a474790b6bdd708699d308c6593705750bc62 100644 (file)
@@ -1,3 +1,5 @@
+from tornado.concurrent import Future
+from tornado import gen
 from tornado.httpclient import HTTPError
 from tornado.log import gen_log
 from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
@@ -6,9 +8,15 @@ from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketErro
 
 
 class EchoHandler(WebSocketHandler):
+    def initialize(self, close_future):
+        self.close_future = close_future
+
     def on_message(self, message):
         self.write_message(message, isinstance(message, bytes))
 
+    def on_close(self):
+        self.close_future.set_result(None)
+
 
 class NonWebSocketHandler(RequestHandler):
     def get(self):
@@ -17,8 +25,9 @@ class NonWebSocketHandler(RequestHandler):
 
 class WebSocketTest(AsyncHTTPTestCase):
     def get_app(self):
+        self.close_future = Future()
         return Application([
-            ('/echo', EchoHandler),
+            ('/echo', EchoHandler, dict(close_future=self.close_future)),
             ('/non_ws', NonWebSocketHandler),
         ])
 
@@ -67,3 +76,12 @@ class WebSocketTest(AsyncHTTPTestCase):
                     io_loop=self.io_loop,
                     connect_timeout=0.01)
         self.assertEqual(cm.exception.code, 599)
+
+    @gen_test
+    def test_websocket_close_buffered_data(self):
+        ws = yield websocket_connect(
+            'ws://localhost:%d/echo' % self.get_http_port())
+        ws.write_message('hello')
+        ws.write_message('world')
+        ws.stream.close()
+        yield self.close_future
index 8435e28a23dfc53dafeb8ecbbe7d41467393079b..1eef4019b4c6e74c2f5847a311fb966762c40b63 100644 (file)
@@ -35,6 +35,7 @@ from tornado.concurrent import Future
 from tornado.escape import utf8, native_str
 from tornado import httpclient
 from tornado.ioloop import IOLoop
+from tornado.iostream import StreamClosedError
 from tornado.log import gen_log, app_log
 from tornado.netutil import Resolver
 from tornado import simple_httpclient
@@ -588,7 +589,10 @@ class WebSocketProtocol13(WebSocketProtocol):
             opcode = 0x1
         message = tornado.escape.utf8(message)
         assert isinstance(message, bytes_type)
-        self._write_frame(True, opcode, message)
+        try:
+            self._write_frame(True, opcode, message)
+        except StreamClosedError:
+            self._abort()
 
     def write_ping(self, data):
         """Send ping frame."""
@@ -596,7 +600,10 @@ class WebSocketProtocol13(WebSocketProtocol):
         self._write_frame(True, 0x9, data)
 
     def _receive_frame(self):
-        self.stream.read_bytes(2, self._on_frame_start)
+        try:
+            self.stream.read_bytes(2, self._on_frame_start)
+        except StreamClosedError:
+            self._abort()
 
     def _on_frame_start(self, data):
         header, payloadlen = struct.unpack("BB", data)
@@ -614,34 +621,46 @@ class WebSocketProtocol13(WebSocketProtocol):
             # control frames must have payload < 126
             self._abort()
             return
-        if payloadlen < 126:
-            self._frame_length = payloadlen
+        try:
+            if payloadlen < 126:
+                self._frame_length = payloadlen
+                if self._masked_frame:
+                    self.stream.read_bytes(4, self._on_masking_key)
+                else:
+                    self.stream.read_bytes(self._frame_length, self._on_frame_data)
+            elif payloadlen == 126:
+                self.stream.read_bytes(2, self._on_frame_length_16)
+            elif payloadlen == 127:
+                self.stream.read_bytes(8, self._on_frame_length_64)
+        except StreamClosedError:
+            self._abort()
+
+    def _on_frame_length_16(self, data):
+        self._frame_length = struct.unpack("!H", data)[0]
+        try:
             if self._masked_frame:
                 self.stream.read_bytes(4, self._on_masking_key)
             else:
                 self.stream.read_bytes(self._frame_length, self._on_frame_data)
-        elif payloadlen == 126:
-            self.stream.read_bytes(2, self._on_frame_length_16)
-        elif payloadlen == 127:
-            self.stream.read_bytes(8, self._on_frame_length_64)
-
-    def _on_frame_length_16(self, data):
-        self._frame_length = struct.unpack("!H", data)[0]
-        if self._masked_frame:
-            self.stream.read_bytes(4, self._on_masking_key)
-        else:
-            self.stream.read_bytes(self._frame_length, self._on_frame_data)
+        except StreamClosedError:
+            self._abort()
 
     def _on_frame_length_64(self, data):
         self._frame_length = struct.unpack("!Q", data)[0]
-        if self._masked_frame:
-            self.stream.read_bytes(4, self._on_masking_key)
-        else:
-            self.stream.read_bytes(self._frame_length, self._on_frame_data)
+        try:
+            if self._masked_frame:
+                self.stream.read_bytes(4, self._on_masking_key)
+            else:
+                self.stream.read_bytes(self._frame_length, self._on_frame_data)
+        except StreamClosedError:
+            self._abort()
 
     def _on_masking_key(self, data):
         self._frame_mask = data
-        self.stream.read_bytes(self._frame_length, self._on_masked_frame_data)
+        try:
+            self.stream.read_bytes(self._frame_length, self._on_masked_frame_data)
+        except StreamClosedError:
+            self._abort()
 
     def _apply_mask(self, mask, data):
         mask = array.array("B", mask)