]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
websocket: Catch errors in async open() correctly 2581/head
authorBen Darnell <ben@bendarnell.com>
Sat, 2 Feb 2019 18:44:56 +0000 (13:44 -0500)
committerBen Darnell <ben@bendarnell.com>
Sat, 2 Feb 2019 18:44:56 +0000 (13:44 -0500)
Previously if open() was a coroutine and raised an error, the
connection would be left open.

Fixes #2570

tornado/test/websocket_test.py
tornado/websocket.py

index 1a5a27281f00b41c613831b2ac22fb392b650796..715ecf1e1c8f0a1b33d130aa1325d0f9b1a81e69 100644 (file)
@@ -187,6 +187,17 @@ class OpenCoroutineHandler(TestWebSocketHandler):
         self.write_message("ok")
 
 
+class ErrorInOpenHandler(TestWebSocketHandler):
+    def open(self):
+        raise Exception("boom")
+
+
+class ErrorInAsyncOpenHandler(TestWebSocketHandler):
+    async def open(self):
+        await asyncio.sleep(0)
+        raise Exception("boom")
+
+
 class WebSocketBaseTestCase(AsyncHTTPTestCase):
     @gen.coroutine
     def ws_connect(self, path, **kwargs):
@@ -245,6 +256,8 @@ class WebSocketTest(WebSocketBaseTestCase):
                     OpenCoroutineHandler,
                     dict(close_future=self.close_future, test=self),
                 ),
+                ("/error_in_open", ErrorInOpenHandler),
+                ("/error_in_async_open", ErrorInAsyncOpenHandler),
             ],
             template_loader=DictLoader({"message.html": "<b>{{ message }}</b>"}),
         )
@@ -535,6 +548,20 @@ class WebSocketTest(WebSocketBaseTestCase):
         res = yield ws.read_message()
         self.assertEqual(res, "ok")
 
+    @gen_test
+    def test_error_in_open(self):
+        with ExpectLog(app_log, "Uncaught exception"):
+            ws = yield self.ws_connect("/error_in_open")
+            res = yield ws.read_message()
+        self.assertIsNone(res)
+
+    @gen_test
+    def test_error_in_async_open(self):
+        with ExpectLog(app_log, "Uncaught exception"):
+            ws = yield self.ws_connect("/error_in_async_open")
+            res = yield ws.read_message()
+        self.assertIsNone(res)
+
 
 class NativeCoroutineOnMessageHandler(TestWebSocketHandler):
     def initialize(self, **kwargs):
index 50665c694db16afd1765d9735e81e6669309ec37..ef13751ae19299300d13a073cb9298d0811b0511 100644 (file)
@@ -948,11 +948,15 @@ class WebSocketProtocol13(WebSocketProtocol):
         self.stream = handler._detach_stream()
 
         self.start_pinging()
-        open_result = self._run_callback(
-            handler.open, *handler.open_args, **handler.open_kwargs
-        )
-        if open_result is not None:
-            await open_result
+        try:
+            open_result = handler.open(*handler.open_args, **handler.open_kwargs)
+            if open_result is not None:
+                await open_result
+        except Exception:
+            handler.log_exception(*sys.exc_info())
+            self._abort()
+            return
+
         await self._receive_frame_loop()
 
     def _parse_extensions_header(