From dc354c57adac0c6f141677a4400e2a748d9c0e92 Mon Sep 17 00:00:00 2001 From: Ben Darnell Date: Sat, 2 Feb 2019 13:44:56 -0500 Subject: [PATCH] websocket: Catch errors in async open() correctly Previously if open() was a coroutine and raised an error, the connection would be left open. Fixes #2570 --- tornado/test/websocket_test.py | 27 +++++++++++++++++++++++++++ tornado/websocket.py | 14 +++++++++----- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index 1a5a27281..715ecf1e1 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -187,6 +187,17 @@ class OpenCoroutineHandler(TestWebSocketHandler): self.write_message("ok") +class ErrorInOpenHandler(TestWebSocketHandler): + def open(self): + raise Exception("boom") + + +class ErrorInAsyncOpenHandler(TestWebSocketHandler): + async def open(self): + await asyncio.sleep(0) + raise Exception("boom") + + class WebSocketBaseTestCase(AsyncHTTPTestCase): @gen.coroutine def ws_connect(self, path, **kwargs): @@ -245,6 +256,8 @@ class WebSocketTest(WebSocketBaseTestCase): OpenCoroutineHandler, dict(close_future=self.close_future, test=self), ), + ("/error_in_open", ErrorInOpenHandler), + ("/error_in_async_open", ErrorInAsyncOpenHandler), ], template_loader=DictLoader({"message.html": "{{ message }}"}), ) @@ -535,6 +548,20 @@ class WebSocketTest(WebSocketBaseTestCase): res = yield ws.read_message() self.assertEqual(res, "ok") + @gen_test + def test_error_in_open(self): + with ExpectLog(app_log, "Uncaught exception"): + ws = yield self.ws_connect("/error_in_open") + res = yield ws.read_message() + self.assertIsNone(res) + + @gen_test + def test_error_in_async_open(self): + with ExpectLog(app_log, "Uncaught exception"): + ws = yield self.ws_connect("/error_in_async_open") + res = yield ws.read_message() + self.assertIsNone(res) + class NativeCoroutineOnMessageHandler(TestWebSocketHandler): def initialize(self, **kwargs): diff --git a/tornado/websocket.py b/tornado/websocket.py index 50665c694..ef13751ae 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -948,11 +948,15 @@ class WebSocketProtocol13(WebSocketProtocol): self.stream = handler._detach_stream() self.start_pinging() - open_result = self._run_callback( - handler.open, *handler.open_args, **handler.open_kwargs - ) - if open_result is not None: - await open_result + try: + open_result = handler.open(*handler.open_args, **handler.open_kwargs) + if open_result is not None: + await open_result + except Exception: + handler.log_exception(*sys.exc_info()) + self._abort() + return + await self._receive_frame_loop() def _parse_extensions_header( -- 2.47.2