]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Make the websocket_connect error handling more comprehensive.
authorBen Darnell <ben@bendarnell.com>
Tue, 9 Apr 2013 03:11:08 +0000 (23:11 -0400)
committerBen Darnell <ben@bendarnell.com>
Tue, 9 Apr 2013 03:11:08 +0000 (23:11 -0400)
Now covers successful http responses that return a non-websocket body
and network errors that prevent a body from being returned.

Added a connect_timeout parameter to websocket_connect.

docs/releases/v3.0.1.rst
tornado/test/websocket_test.py
tornado/websocket.py

index b2ceda1ee69080d4a9ca0750b7ad58dfbfe231b6..b375dd4e1769d7676b25284de107b2c0eb069d80 100644 (file)
@@ -12,6 +12,8 @@ Apr 8, 2013
   as a (broken) test by ``nose``.
 * Work around a bug in Ubuntu 13.04 betas involving an incomplete backport
   of the `ssl.match_hostname` function.
+* `tornado.websocket.websocket_connect` now fails cleanly when it attempts
+  to connect to a non-websocket url.
 * `tornado.testing.LogTrapTestCase` once again works with byte strings
   on Python 2.
 * The ``request`` attribute of `tornado.httpclient.HTTPResponse` is
index eef5ed55838590293e969a1653f125ec3c862a41..65b02be142e27754b40cdeccd458f3194fdeec56 100644 (file)
@@ -1,17 +1,24 @@
-from tornado.testing import AsyncHTTPTestCase, gen_test
-from tornado.web import Application
-from tornado.websocket import WebSocketHandler, websocket_connect
+from tornado.httpclient import HTTPError
+from tornado.log import gen_log
+from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
+from tornado.web import Application, RequestHandler
+from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketError
 
 
 class EchoHandler(WebSocketHandler):
     def on_message(self, message):
         self.write_message(message, isinstance(message, bytes))
 
+class NonWebSocketHandler(RequestHandler):
+    def get(self):
+        self.write('ok')
+
 
 class WebSocketTest(AsyncHTTPTestCase):
     def get_app(self):
         return Application([
             ('/echo', EchoHandler),
+            ('/non_ws', NonWebSocketHandler),
         ])
 
     @gen_test
@@ -32,14 +39,30 @@ class WebSocketTest(AsyncHTTPTestCase):
         ws.read_message(self.stop)
         response = self.wait().result()
         self.assertEqual(response, 'hello')
-    
+
     @gen_test
-    def test_websocket_fail(self):
-        try:
-            ws = yield websocket_connect(
-                'ws://localhost:%d/no_websock' % self.get_http_port(),
+    def test_websocket_http_fail(self):
+        with self.assertRaises(HTTPError) as cm:
+            yield websocket_connect(
+                'ws://localhost:%d/notfound' % self.get_http_port(),
                 io_loop=self.io_loop)
-        except:
-            pass
-        else:
-            self.fail('Should\'ve caught an Exception')
+        self.assertEqual(cm.exception.code, 404)
+
+    @gen_test
+    def test_websocket_http_success(self):
+        with self.assertRaises(WebSocketError):
+            yield websocket_connect(
+                'ws://localhost:%d/non_ws' % self.get_http_port(),
+                io_loop=self.io_loop)
+
+    @gen_test
+    def test_websocket_network_fail(self):
+        sock, port = bind_unused_port()
+        sock.close()
+        with self.assertRaises(HTTPError) as cm:
+            with ExpectLog(gen_log, ".*"):
+                yield websocket_connect(
+                    'ws://localhost:%d/' % port,
+                    io_loop=self.io_loop,
+                    connect_timeout=0.01)
+        self.assertEqual(cm.exception.code, 599)
index cd27d366ae2bf871473212e4aab316988fb24b98..4d06af965e6b984004ae1ad7c5d2398129251bc0 100644 (file)
@@ -46,6 +46,10 @@ except NameError:
     xrange = range  # py3
 
 
+class WebSocketError(Exception):
+    pass
+
+
 class WebSocketHandler(tornado.web.RequestHandler):
     """Subclass this class to create a basic WebSocket handler.
 
@@ -740,14 +744,19 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
         })
 
         super(WebSocketClientConnection, self).__init__(
-            io_loop, None, request, lambda: None, lambda response: None,
+            io_loop, None, request, lambda: None, self._on_http_response,
             104857600, Resolver(io_loop=io_loop))
 
     def _on_close(self):
         self.on_message(None)
 
-    def _on_body(self, body):
-        self.connect_future.set_exception(Exception('Could not connect.'))
+    def _on_http_response(self, response):
+        if not self.connect_future.done():
+            if response.error:
+                self.connect_future.set_exception(response.error)
+            else:
+                self.connect_future.set_exception(WebSocketError(
+                        "Non-websocket response"))
 
     def _handle_1xx(self, code):
         assert code == 101
@@ -798,7 +807,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
         pass
 
 
-def websocket_connect(url, io_loop=None, callback=None):
+def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None):
     """Client-side websocket support.
 
     Takes a url and returns a Future whose result is a
@@ -806,7 +815,7 @@ def websocket_connect(url, io_loop=None, callback=None):
     """
     if io_loop is None:
         io_loop = IOLoop.current()
-    request = httpclient.HTTPRequest(url)
+    request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
     request = httpclient._RequestProxy(
         request, httpclient.HTTPRequest._DEFAULTS)
     conn = WebSocketClientConnection(io_loop, request)