+from tornado.concurrent import Future
+from tornado import gen
from tornado.httpclient import HTTPError
from tornado.log import gen_log
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
class EchoHandler(WebSocketHandler):
+ def initialize(self, close_future):
+ self.close_future = close_future
+
def on_message(self, message):
self.write_message(message, isinstance(message, bytes))
+ def on_close(self):
+ self.close_future.set_result(None)
+
class NonWebSocketHandler(RequestHandler):
def get(self):
class WebSocketTest(AsyncHTTPTestCase):
def get_app(self):
+ self.close_future = Future()
return Application([
- ('/echo', EchoHandler),
+ ('/echo', EchoHandler, dict(close_future=self.close_future)),
('/non_ws', NonWebSocketHandler),
])
io_loop=self.io_loop,
connect_timeout=0.01)
self.assertEqual(cm.exception.code, 599)
+
+ @gen_test
+ def test_websocket_close_buffered_data(self):
+ ws = yield websocket_connect(
+ 'ws://localhost:%d/echo' % self.get_http_port())
+ ws.write_message('hello')
+ ws.write_message('world')
+ ws.stream.close()
+ yield self.close_future
from tornado.escape import utf8, native_str
from tornado import httpclient
from tornado.ioloop import IOLoop
+from tornado.iostream import StreamClosedError
from tornado.log import gen_log, app_log
from tornado.netutil import Resolver
from tornado import simple_httpclient
opcode = 0x1
message = tornado.escape.utf8(message)
assert isinstance(message, bytes_type)
- self._write_frame(True, opcode, message)
+ try:
+ self._write_frame(True, opcode, message)
+ except StreamClosedError:
+ self._abort()
def write_ping(self, data):
"""Send ping frame."""
self._write_frame(True, 0x9, data)
def _receive_frame(self):
- self.stream.read_bytes(2, self._on_frame_start)
+ try:
+ self.stream.read_bytes(2, self._on_frame_start)
+ except StreamClosedError:
+ self._abort()
def _on_frame_start(self, data):
header, payloadlen = struct.unpack("BB", data)
# control frames must have payload < 126
self._abort()
return
- if payloadlen < 126:
- self._frame_length = payloadlen
+ try:
+ if payloadlen < 126:
+ self._frame_length = payloadlen
+ if self._masked_frame:
+ self.stream.read_bytes(4, self._on_masking_key)
+ else:
+ self.stream.read_bytes(self._frame_length, self._on_frame_data)
+ elif payloadlen == 126:
+ self.stream.read_bytes(2, self._on_frame_length_16)
+ elif payloadlen == 127:
+ self.stream.read_bytes(8, self._on_frame_length_64)
+ except StreamClosedError:
+ self._abort()
+
+ def _on_frame_length_16(self, data):
+ self._frame_length = struct.unpack("!H", data)[0]
+ try:
if self._masked_frame:
self.stream.read_bytes(4, self._on_masking_key)
else:
self.stream.read_bytes(self._frame_length, self._on_frame_data)
- elif payloadlen == 126:
- self.stream.read_bytes(2, self._on_frame_length_16)
- elif payloadlen == 127:
- self.stream.read_bytes(8, self._on_frame_length_64)
-
- def _on_frame_length_16(self, data):
- self._frame_length = struct.unpack("!H", data)[0]
- if self._masked_frame:
- self.stream.read_bytes(4, self._on_masking_key)
- else:
- self.stream.read_bytes(self._frame_length, self._on_frame_data)
+ except StreamClosedError:
+ self._abort()
def _on_frame_length_64(self, data):
self._frame_length = struct.unpack("!Q", data)[0]
- if self._masked_frame:
- self.stream.read_bytes(4, self._on_masking_key)
- else:
- self.stream.read_bytes(self._frame_length, self._on_frame_data)
+ try:
+ if self._masked_frame:
+ self.stream.read_bytes(4, self._on_masking_key)
+ else:
+ self.stream.read_bytes(self._frame_length, self._on_frame_data)
+ except StreamClosedError:
+ self._abort()
def _on_masking_key(self, data):
self._frame_mask = data
- self.stream.read_bytes(self._frame_length, self._on_masked_frame_data)
+ try:
+ self.stream.read_bytes(self._frame_length, self._on_masked_frame_data)
+ except StreamClosedError:
+ self._abort()
def _apply_mask(self, mask, data):
mask = array.array("B", mask)