From: Kyle Kelley Date: Sun, 26 Jan 2014 00:07:44 +0000 (-0700) Subject: Modify origin handling & normalization. Add tests. X-Git-Tag: v4.0.0b1~35^2~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9963e737aaa1ff158897ef9f92c361894005e669;p=thirdparty%2Ftornado.git Modify origin handling & normalization. Add tests. --- diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index dae5d88b5..1873af3fa 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -189,23 +189,50 @@ class WebSocketTest(AsyncHTTPTestCase): yield self.close_future @gen_test - def test_check_origin_invalid(self): - '''Currently a failing test''' + def test_check_origin_valid2(self): port = self.get_http_port() url = 'ws://localhost:%d/echo' % port - headers = {'Origin': 'http://somewhereelse.com'} + headers = {'Origin': 'localhost:%d' % port} ws = yield websocket_connect(HTTPRequest(url, headers=headers), io_loop=self.io_loop) ws.write_message('hello') - response = yield ws.read_message() - - self.assertEqual(response, None) + self.assertEqual(response, 'hello') ws.close() yield self.close_future + @gen_test + def test_check_origin_invalid(self): + port = self.get_http_port() + + url = 'ws://localhost:%d/echo' % port + # Host is localhost, which should not be accessible from some other + # domain + headers = {'Origin': 'http://somewhereelse.com'} + + with self.assertRaises(HTTPError) as cm: + yield websocket_connect(HTTPRequest(url, headers=headers), + io_loop=self.io_loop) + + self.assertEqual(cm.exception.code, 403) + + @gen_test + def test_check_origin_invalid2(self): + port = self.get_http_port() + + url = 'ws://localhost:%d/echo' % port + # subdomains should be invalid by default + headers = {'Origin': 'http://subtenant.somewhereelse.com', + 'Host': 'subtenant2.somewhereelse.com'} + + with self.assertRaises(HTTPError) as cm: + yield websocket_connect(HTTPRequest(url, headers=headers), + io_loop=self.io_loop) + + self.assertEqual(cm.exception.code, 403) + class MaskFunctionMixin(object): # Subclasses should define self.mask(mask, data) diff --git a/tornado/websocket.py b/tornado/websocket.py index 22369c87c..1ba905b71 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -145,9 +145,34 @@ class WebSocketHandler(tornado.web.RequestHandler): self.stream.close() return + # Handle WebSocket Origin naming convention differences # The difference between version 8 and 13 is that in 8 the # client sends a "Sec-Websocket-Origin" header and in 13 it's # simply "Origin". + if "Origin" in self.request.headers: + origin = self.request.headers.get("Origin") + else: + origin = self.request.headers.get("Sec-Websocket-Origin", None) + + # If we have an origin, normalize + if(origin): + # Due to how stdlib's urlparse is implemented, urls without a // + # are interpreted to be paths (resulting in netloc being None) + if("//" not in origin): + origin = "//" + origin + parsed_origin = urlparse(origin) + origin = parsed_origin.netloc + origin = origin.lower() + + # If there was an origin header, check to make sure it matches + # according to check_origin + if not self.check_origin(origin): + self.stream.write(tornado.escape.utf8( + "HTTP/1.1 403 Cross Origin Websockets Disabled\r\n\r\n" + )) + self.stream.close() + return + if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"): self.ws_connection = WebSocketProtocol13(self) self.ws_connection.accept_connection() @@ -161,13 +186,6 @@ class WebSocketHandler(tornado.web.RequestHandler): "Sec-WebSocket-Version: 8\r\n\r\n")) self.stream.close() - if not self.check_origin(): - self.stream.write(tornado.escape.utf8( - "HTTP/1.1 403 Cross Origin Websockets Disabled\r\n\r\n" - )) - self.stream.close() - - def write_message(self, message, binary=False): """Sends the given message to the client of this Web Socket. @@ -264,37 +282,25 @@ class WebSocketHandler(tornado.web.RequestHandler): self.ws_connection.close(code, reason) self.ws_connection = None - def check_origin(self, allowed_origins=None): + def check_origin(self, origin): """Override to enable support for allowing alternate origins. - By default, this checks to see that requests that provide both a host - origin have the same origin and host This is a security protection - against cross site scripting attacks on browsers, - since WebSockets don't have CORS headers. + By default, this checks to see that the host matches the origin + provided. + + This is a security protection against cross site scripting attacks on + browsers, since WebSockets don't have CORS headers. - >>> self.check_origins(allowed_origins=['localhost']) + >>> self.check_origin(origin='localhost') + True """ - # Handle WebSocket Origin naming convention differences - # The difference between version 8 and 13 is that in 8 the - # client sends a "Sec-Websocket-Origin" header and in 13 it's - # simply "Origin". - if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8"): - origin_header = self.request.headers.get("Sec-Websocket-Origin") - else: - origin_header = self.request.headers.get("Origin") - - host = self.request.headers.get("Host") - - # If no header is provided, assume request is not coming from a browser - if(origin_header is None or host is None): + # When origin is None, assume it didn't come from a browser and we can + # pass it on + if origin is None: return True - parsed_origin = urlparse(origin_header) - origin = parsed_origin.netloc - - if allowed_origins and origin in allowed_origins: - return True + host = self.request.headers.get("Host") # Check to see that origin matches host directly, including ports return origin == host