]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
websocket: Catch errors in async open() correctly 2589/head
authorBen Darnell <ben@bendarnell.com>
Mon, 11 Feb 2019 02:02:25 +0000 (21:02 -0500)
committerPierce Lopez <pierce.lopez@gmail.com>
Sun, 24 Feb 2019 20:19:35 +0000 (15:19 -0500)
Previously if open() was a coroutine and raised an error, the
connection would be left open.

Fixes #2570
backport of dc354c57adac to 5.1

TestWebsocketHandler always requires close_future in this branch

tornado/test/websocket_test.py
tornado/websocket.py

index ea7b1e4ceff2b86134d70f605316f60c08006fcb..1d5bfa68da08415fffeb332524ff21facaa5ce39 100644 (file)
@@ -181,6 +181,18 @@ class OpenCoroutineHandler(TestWebSocketHandler):
         self.write_message('ok')
 
 
+class ErrorInOpenHandler(TestWebSocketHandler):
+    def open(self):
+        raise Exception("boom")
+
+
+class ErrorInAsyncOpenHandler(TestWebSocketHandler):
+    @gen.coroutine
+    def open(self):
+        yield gen.sleep(0.01)
+        raise Exception("boom")
+
+
 class WebSocketBaseTestCase(AsyncHTTPTestCase):
     @gen.coroutine
     def ws_connect(self, path, **kwargs):
@@ -225,6 +237,10 @@ class WebSocketTest(WebSocketBaseTestCase):
              dict(close_future=self.close_future)),
             ('/open_coroutine', OpenCoroutineHandler,
              dict(close_future=self.close_future, test=self)),
+            ("/error_in_open", ErrorInOpenHandler,
+             dict(close_future=self.close_future)),
+            ("/error_in_async_open", ErrorInAsyncOpenHandler,
+             dict(close_future=self.close_future)),
         ], template_loader=DictLoader({
             'message.html': '<b>{{ message }}</b>',
         }))
@@ -511,6 +527,20 @@ class WebSocketTest(WebSocketBaseTestCase):
         self.assertEqual(res, 'ok')
         yield self.close(ws)
 
+    @gen_test
+    def test_error_in_open(self):
+        with ExpectLog(app_log, "Uncaught exception"):
+            ws = yield self.ws_connect("/error_in_open")
+            res = yield ws.read_message()
+        self.assertIsNone(res)
+
+    @gen_test
+    def test_error_in_async_open(self):
+        with ExpectLog(app_log, "Uncaught exception"):
+            ws = yield self.ws_connect("/error_in_async_open")
+            res = yield ws.read_message()
+        self.assertIsNone(res)
+
 
 if sys.version_info >= (3, 5):
     NativeCoroutineOnMessageHandler = exec_test(globals(), locals(), """
index 0b994fc123c4a3ee88a23a95a50300873b1e2992..7b77850a6cd12e1c882254604380f5146a0a1d20 100644 (file)
@@ -751,10 +751,14 @@ class WebSocketProtocol13(WebSocketProtocol):
         self.stream = self.handler.stream
 
         self.start_pinging()
-        open_result = self._run_callback(self.handler.open, *self.handler.open_args,
-                                         **self.handler.open_kwargs)
-        if open_result is not None:
-            yield open_result
+        try:
+            open_result = self.handler.open(*self.handler.open_args, **self.handler.open_kwargs)
+            if open_result is not None:
+                yield open_result
+        except Exception:
+            self.handler.log_exception(*sys.exc_info())
+            self._abort()
+
         yield self._receive_frame_loop()
 
     def _parse_extensions_header(self, headers):