From: Ben Darnell Date: Sat, 2 Feb 2019 18:44:56 +0000 (-0500) Subject: websocket: Catch errors in async open() correctly X-Git-Tag: v6.0.0b1~4^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=dc354c57adac;p=thirdparty%2Ftornado.git websocket: Catch errors in async open() correctly Previously if open() was a coroutine and raised an error, the connection would be left open. Fixes #2570 --- diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index 1a5a27281..715ecf1e1 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -187,6 +187,17 @@ class OpenCoroutineHandler(TestWebSocketHandler): self.write_message("ok") +class ErrorInOpenHandler(TestWebSocketHandler): + def open(self): + raise Exception("boom") + + +class ErrorInAsyncOpenHandler(TestWebSocketHandler): + async def open(self): + await asyncio.sleep(0) + raise Exception("boom") + + class WebSocketBaseTestCase(AsyncHTTPTestCase): @gen.coroutine def ws_connect(self, path, **kwargs): @@ -245,6 +256,8 @@ class WebSocketTest(WebSocketBaseTestCase): OpenCoroutineHandler, dict(close_future=self.close_future, test=self), ), + ("/error_in_open", ErrorInOpenHandler), + ("/error_in_async_open", ErrorInAsyncOpenHandler), ], template_loader=DictLoader({"message.html": "{{ message }}"}), ) @@ -535,6 +548,20 @@ class WebSocketTest(WebSocketBaseTestCase): res = yield ws.read_message() self.assertEqual(res, "ok") + @gen_test + def test_error_in_open(self): + with ExpectLog(app_log, "Uncaught exception"): + ws = yield self.ws_connect("/error_in_open") + res = yield ws.read_message() + self.assertIsNone(res) + + @gen_test + def test_error_in_async_open(self): + with ExpectLog(app_log, "Uncaught exception"): + ws = yield self.ws_connect("/error_in_async_open") + res = yield ws.read_message() + self.assertIsNone(res) + class NativeCoroutineOnMessageHandler(TestWebSocketHandler): def initialize(self, **kwargs): diff --git a/tornado/websocket.py b/tornado/websocket.py index 50665c694..ef13751ae 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -948,11 +948,15 @@ class WebSocketProtocol13(WebSocketProtocol): self.stream = handler._detach_stream() self.start_pinging() - open_result = self._run_callback( - handler.open, *handler.open_args, **handler.open_kwargs - ) - if open_result is not None: - await open_result + try: + open_result = handler.open(*handler.open_args, **handler.open_kwargs) + if open_result is not None: + await open_result + except Exception: + handler.log_exception(*sys.exc_info()) + self._abort() + return + await self._receive_frame_loop() def _parse_extensions_header(