class HeaderHandler(TestWebSocketHandler):
def open(self):
+ try:
+ # In a websocket context, many RequestHandler methods
+ # raise RuntimeErrors.
+ self.set_status(503)
+ raise Exception("did not get expected exception")
+ except RuntimeError:
+ pass
self.write_message(self.request.headers.get('X-Test', ''))
dict(close_future=self.close_future)),
])
+ def test_http_request(self):
+ # WS server, HTTP client.
+ response = self.fetch('/echo')
+ self.assertEqual(response.code, 400)
+
@gen_test
def test_websocket_gen(self):
ws = yield websocket_connect(
self.ws_connection = None
self.close_code = None
self.close_reason = None
+ self.stream = None
@tornado.web.asynchronous
def get(self, *args, **kwargs):
self.open_args = args
self.open_kwargs = kwargs
- self.stream = self.request.connection.detach()
- self.stream.set_close_callback(self.on_connection_close)
-
# Upgrade header should be present and should be equal to WebSocket
if self.request.headers.get("Upgrade", "").lower() != 'websocket':
- self.stream.write(tornado.escape.utf8(
- "HTTP/1.1 400 Bad Request\r\n\r\n"
- "Can \"Upgrade\" only to \"WebSocket\"."
- ))
- self.stream.close()
+ self.set_status(400)
+ self.finish("Can \"Upgrade\" only to \"WebSocket\".")
return
# Connection header should be upgrade. Some proxy servers/load balancers
headers = self.request.headers
connection = map(lambda s: s.strip().lower(), headers.get("Connection", "").split(","))
if 'upgrade' not in connection:
- self.stream.write(tornado.escape.utf8(
- "HTTP/1.1 400 Bad Request\r\n\r\n"
- "\"Connection\" must be \"Upgrade\"."
- ))
- self.stream.close()
+ self.set_status(400)
+ self.finish("\"Connection\" must be \"Upgrade\".")
return
# Handle WebSocket Origin naming convention differences
# according to check_origin. When the origin is None, we assume it
# did not come from a browser and that it can be passed on.
if origin is not None and not self.check_origin(origin):
- self.stream.write(tornado.escape.utf8(
- "HTTP/1.1 403 Cross Origin Websockets Disabled\r\n\r\n"
- ))
- self.stream.close()
+ self.set_status(403)
+ self.finish("Cross origin websockets not allowed")
return
+ self.stream = self.request.connection.detach()
+ self.stream.set_close_callback(self.on_connection_close)
+
if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
self.ws_connection = WebSocketProtocol13(self)
self.ws_connection.accept_connection()
"""
return "wss" if self.request.protocol == "https" else "ws"
- def _not_supported(self, *args, **kwargs):
- raise Exception("Method not supported for Web Sockets")
-
def on_connection_close(self):
if self.ws_connection:
self.ws_connection.on_connection_close()
self.on_close()
+def _wrap_method(method):
+ def _disallow_for_websocket(self, *args, **kwargs):
+ if self.stream is None:
+ method(self, *args, **kwargs)
+ else:
+ raise RuntimeError("Method not supported for Web Sockets")
+ return _disallow_for_websocket
for method in ["write", "redirect", "set_header", "send_error", "set_cookie",
"set_status", "flush", "finish"]:
- setattr(WebSocketHandler, method, WebSocketHandler._not_supported)
+ setattr(WebSocketHandler, method,
+ _wrap_method(getattr(WebSocketHandler, method)))
class WebSocketProtocol(object):