-from tornado.testing import AsyncHTTPTestCase, gen_test
-from tornado.web import Application
-from tornado.websocket import WebSocketHandler, websocket_connect
+from tornado.httpclient import HTTPError
+from tornado.log import gen_log
+from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
+from tornado.web import Application, RequestHandler
+from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketError
class EchoHandler(WebSocketHandler):
def on_message(self, message):
self.write_message(message, isinstance(message, bytes))
+class NonWebSocketHandler(RequestHandler):
+ def get(self):
+ self.write('ok')
+
class WebSocketTest(AsyncHTTPTestCase):
def get_app(self):
return Application([
('/echo', EchoHandler),
+ ('/non_ws', NonWebSocketHandler),
])
@gen_test
ws.read_message(self.stop)
response = self.wait().result()
self.assertEqual(response, 'hello')
-
+
@gen_test
- def test_websocket_fail(self):
- try:
- ws = yield websocket_connect(
- 'ws://localhost:%d/no_websock' % self.get_http_port(),
+ def test_websocket_http_fail(self):
+ with self.assertRaises(HTTPError) as cm:
+ yield websocket_connect(
+ 'ws://localhost:%d/notfound' % self.get_http_port(),
io_loop=self.io_loop)
- except:
- pass
- else:
- self.fail('Should\'ve caught an Exception')
+ self.assertEqual(cm.exception.code, 404)
+
+ @gen_test
+ def test_websocket_http_success(self):
+ with self.assertRaises(WebSocketError):
+ yield websocket_connect(
+ 'ws://localhost:%d/non_ws' % self.get_http_port(),
+ io_loop=self.io_loop)
+
+ @gen_test
+ def test_websocket_network_fail(self):
+ sock, port = bind_unused_port()
+ sock.close()
+ with self.assertRaises(HTTPError) as cm:
+ with ExpectLog(gen_log, ".*"):
+ yield websocket_connect(
+ 'ws://localhost:%d/' % port,
+ io_loop=self.io_loop,
+ connect_timeout=0.01)
+ self.assertEqual(cm.exception.code, 599)
xrange = range # py3
+class WebSocketError(Exception):
+ pass
+
+
class WebSocketHandler(tornado.web.RequestHandler):
"""Subclass this class to create a basic WebSocket handler.
})
super(WebSocketClientConnection, self).__init__(
- io_loop, None, request, lambda: None, lambda response: None,
+ io_loop, None, request, lambda: None, self._on_http_response,
104857600, Resolver(io_loop=io_loop))
def _on_close(self):
self.on_message(None)
- def _on_body(self, body):
- self.connect_future.set_exception(Exception('Could not connect.'))
+ def _on_http_response(self, response):
+ if not self.connect_future.done():
+ if response.error:
+ self.connect_future.set_exception(response.error)
+ else:
+ self.connect_future.set_exception(WebSocketError(
+ "Non-websocket response"))
def _handle_1xx(self, code):
assert code == 101
pass
-def websocket_connect(url, io_loop=None, callback=None):
+def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None):
"""Client-side websocket support.
Takes a url and returns a Future whose result is a
"""
if io_loop is None:
io_loop = IOLoop.current()
- request = httpclient.HTTPRequest(url)
+ request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
request = httpclient._RequestProxy(
request, httpclient.HTTPRequest._DEFAULTS)
conn = WebSocketClientConnection(io_loop, request)