yield self.close_future
@gen_test
- def test_check_origin_invalid(self):
- '''Currently a failing test'''
+ def test_check_origin_valid2(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
- headers = {'Origin': 'http://somewhereelse.com'}
+ headers = {'Origin': 'localhost:%d' % port}
ws = yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
ws.write_message('hello')
-
response = yield ws.read_message()
-
- self.assertEqual(response, None)
+ self.assertEqual(response, 'hello')
ws.close()
yield self.close_future
+ @gen_test
+ def test_check_origin_invalid(self):
+ port = self.get_http_port()
+
+ url = 'ws://localhost:%d/echo' % port
+ # Host is localhost, which should not be accessible from some other
+ # domain
+ headers = {'Origin': 'http://somewhereelse.com'}
+
+ with self.assertRaises(HTTPError) as cm:
+ yield websocket_connect(HTTPRequest(url, headers=headers),
+ io_loop=self.io_loop)
+
+ self.assertEqual(cm.exception.code, 403)
+
+ @gen_test
+ def test_check_origin_invalid2(self):
+ port = self.get_http_port()
+
+ url = 'ws://localhost:%d/echo' % port
+ # subdomains should be invalid by default
+ headers = {'Origin': 'http://subtenant.somewhereelse.com',
+ 'Host': 'subtenant2.somewhereelse.com'}
+
+ with self.assertRaises(HTTPError) as cm:
+ yield websocket_connect(HTTPRequest(url, headers=headers),
+ io_loop=self.io_loop)
+
+ self.assertEqual(cm.exception.code, 403)
+
class MaskFunctionMixin(object):
# Subclasses should define self.mask(mask, data)
self.stream.close()
return
+ # Handle WebSocket Origin naming convention differences
# The difference between version 8 and 13 is that in 8 the
# client sends a "Sec-Websocket-Origin" header and in 13 it's
# simply "Origin".
+ if "Origin" in self.request.headers:
+ origin = self.request.headers.get("Origin")
+ else:
+ origin = self.request.headers.get("Sec-Websocket-Origin", None)
+
+ # If we have an origin, normalize
+ if(origin):
+ # Due to how stdlib's urlparse is implemented, urls without a //
+ # are interpreted to be paths (resulting in netloc being None)
+ if("//" not in origin):
+ origin = "//" + origin
+ parsed_origin = urlparse(origin)
+ origin = parsed_origin.netloc
+ origin = origin.lower()
+
+ # If there was an origin header, check to make sure it matches
+ # according to check_origin
+ if not self.check_origin(origin):
+ self.stream.write(tornado.escape.utf8(
+ "HTTP/1.1 403 Cross Origin Websockets Disabled\r\n\r\n"
+ ))
+ self.stream.close()
+ return
+
if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
self.ws_connection = WebSocketProtocol13(self)
self.ws_connection.accept_connection()
"Sec-WebSocket-Version: 8\r\n\r\n"))
self.stream.close()
- if not self.check_origin():
- self.stream.write(tornado.escape.utf8(
- "HTTP/1.1 403 Cross Origin Websockets Disabled\r\n\r\n"
- ))
- self.stream.close()
-
-
def write_message(self, message, binary=False):
"""Sends the given message to the client of this Web Socket.
self.ws_connection.close(code, reason)
self.ws_connection = None
- def check_origin(self, allowed_origins=None):
+ def check_origin(self, origin):
"""Override to enable support for allowing alternate origins.
- By default, this checks to see that requests that provide both a host
- origin have the same origin and host This is a security protection
- against cross site scripting attacks on browsers,
- since WebSockets don't have CORS headers.
+ By default, this checks to see that the host matches the origin
+ provided.
+
+ This is a security protection against cross site scripting attacks on
+ browsers, since WebSockets don't have CORS headers.
- >>> self.check_origins(allowed_origins=['localhost'])
+ >>> self.check_origin(origin='localhost')
+ True
"""
- # Handle WebSocket Origin naming convention differences
- # The difference between version 8 and 13 is that in 8 the
- # client sends a "Sec-Websocket-Origin" header and in 13 it's
- # simply "Origin".
- if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8"):
- origin_header = self.request.headers.get("Sec-Websocket-Origin")
- else:
- origin_header = self.request.headers.get("Origin")
-
- host = self.request.headers.get("Host")
-
- # If no header is provided, assume request is not coming from a browser
- if(origin_header is None or host is None):
+ # When origin is None, assume it didn't come from a browser and we can
+ # pass it on
+ if origin is None:
return True
- parsed_origin = urlparse(origin_header)
- origin = parsed_origin.netloc
-
- if allowed_origins and origin in allowed_origins:
- return True
+ host = self.request.headers.get("Host")
# Check to see that origin matches host directly, including ports
return origin == host