From: Ben Darnell Date: Mon, 11 Feb 2019 02:02:25 +0000 (-0500) Subject: websocket: Catch errors in async open() correctly X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e38f2ae5e14ab5ea8dec22ec0bc306984b89c3ab;p=thirdparty%2Ftornado.git websocket: Catch errors in async open() correctly Previously if open() was a coroutine and raised an error, the connection would be left open. Fixes #2570 backport of dc354c57adac to 5.1 TestWebsocketHandler always requires close_future in this branch --- diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index ea7b1e4ce..1d5bfa68d 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -181,6 +181,18 @@ class OpenCoroutineHandler(TestWebSocketHandler): self.write_message('ok') +class ErrorInOpenHandler(TestWebSocketHandler): + def open(self): + raise Exception("boom") + + +class ErrorInAsyncOpenHandler(TestWebSocketHandler): + @gen.coroutine + def open(self): + yield gen.sleep(0.01) + raise Exception("boom") + + class WebSocketBaseTestCase(AsyncHTTPTestCase): @gen.coroutine def ws_connect(self, path, **kwargs): @@ -225,6 +237,10 @@ class WebSocketTest(WebSocketBaseTestCase): dict(close_future=self.close_future)), ('/open_coroutine', OpenCoroutineHandler, dict(close_future=self.close_future, test=self)), + ("/error_in_open", ErrorInOpenHandler, + dict(close_future=self.close_future)), + ("/error_in_async_open", ErrorInAsyncOpenHandler, + dict(close_future=self.close_future)), ], template_loader=DictLoader({ 'message.html': '{{ message }}', })) @@ -511,6 +527,20 @@ class WebSocketTest(WebSocketBaseTestCase): self.assertEqual(res, 'ok') yield self.close(ws) + @gen_test + def test_error_in_open(self): + with ExpectLog(app_log, "Uncaught exception"): + ws = yield self.ws_connect("/error_in_open") + res = yield ws.read_message() + self.assertIsNone(res) + + @gen_test + def test_error_in_async_open(self): + with ExpectLog(app_log, "Uncaught exception"): + ws = yield self.ws_connect("/error_in_async_open") + res = yield ws.read_message() + self.assertIsNone(res) + if sys.version_info >= (3, 5): NativeCoroutineOnMessageHandler = exec_test(globals(), locals(), """ diff --git a/tornado/websocket.py b/tornado/websocket.py index 0b994fc12..7b77850a6 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -751,10 +751,14 @@ class WebSocketProtocol13(WebSocketProtocol): self.stream = self.handler.stream self.start_pinging() - open_result = self._run_callback(self.handler.open, *self.handler.open_args, - **self.handler.open_kwargs) - if open_result is not None: - yield open_result + try: + open_result = self.handler.open(*self.handler.open_args, **self.handler.open_kwargs) + if open_result is not None: + yield open_result + except Exception: + self.handler.log_exception(*sys.exc_info()) + self._abort() + yield self._receive_frame_loop() def _parse_extensions_header(self, headers):