]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Modify origin handling & normalization. Add tests.
authorKyle Kelley <kyle.kelley@rackspace.com>
Sun, 26 Jan 2014 00:07:44 +0000 (17:07 -0700)
committerKyle Kelley <kyle.kelley@rackspace.com>
Thu, 8 May 2014 18:42:21 +0000 (13:42 -0500)
tornado/test/websocket_test.py
tornado/websocket.py

index dae5d88b5a981fafd18c562cfc3b4054884c1da1..1873af3fa21e34e30a5ce0d5424f700724b9f94b 100644 (file)
@@ -189,23 +189,50 @@ class WebSocketTest(AsyncHTTPTestCase):
         yield self.close_future
 
     @gen_test
-    def test_check_origin_invalid(self):
-        '''Currently a failing test'''
+    def test_check_origin_valid2(self):
         port = self.get_http_port()
 
         url = 'ws://localhost:%d/echo' % port
-        headers = {'Origin': 'http://somewhereelse.com'}
+        headers = {'Origin': 'localhost:%d' % port}
 
         ws = yield websocket_connect(HTTPRequest(url, headers=headers),
             io_loop=self.io_loop)
         ws.write_message('hello')
-
         response = yield ws.read_message()
-
-        self.assertEqual(response, None)
+        self.assertEqual(response, 'hello')
         ws.close()
         yield self.close_future
 
+    @gen_test
+    def test_check_origin_invalid(self):
+        port = self.get_http_port()
+
+        url = 'ws://localhost:%d/echo' % port
+        # Host is localhost, which should not be accessible from some other
+        # domain
+        headers = {'Origin': 'http://somewhereelse.com'}
+
+        with self.assertRaises(HTTPError) as cm:
+            yield websocket_connect(HTTPRequest(url, headers=headers),
+                                    io_loop=self.io_loop)
+
+        self.assertEqual(cm.exception.code, 403)
+
+    @gen_test
+    def test_check_origin_invalid2(self):
+        port = self.get_http_port()
+
+        url = 'ws://localhost:%d/echo' % port
+        # subdomains should be invalid by default
+        headers = {'Origin': 'http://subtenant.somewhereelse.com',
+                   'Host': 'subtenant2.somewhereelse.com'}
+
+        with self.assertRaises(HTTPError) as cm:
+            yield websocket_connect(HTTPRequest(url, headers=headers),
+                                    io_loop=self.io_loop)
+
+        self.assertEqual(cm.exception.code, 403)
+
 
 class MaskFunctionMixin(object):
     # Subclasses should define self.mask(mask, data)
index 22369c87cf02dccb087d6b40f3bb0efc871f330d..1ba905b71bf1365c944a9204e3a41fa6f2e85a08 100644 (file)
@@ -145,9 +145,34 @@ class WebSocketHandler(tornado.web.RequestHandler):
             self.stream.close()
             return
 
+        # Handle WebSocket Origin naming convention differences
         # The difference between version 8 and 13 is that in 8 the
         # client sends a "Sec-Websocket-Origin" header and in 13 it's
         # simply "Origin".
+        if "Origin" in self.request.headers:
+            origin = self.request.headers.get("Origin")
+        else:
+            origin = self.request.headers.get("Sec-Websocket-Origin", None)
+
+        # If we have an origin, normalize
+        if(origin):
+            # Due to how stdlib's urlparse is implemented, urls without a //
+            # are interpreted to be paths (resulting in netloc being None)
+            if("//" not in origin):
+                origin = "//" + origin
+            parsed_origin = urlparse(origin)
+            origin = parsed_origin.netloc
+            origin = origin.lower()
+
+        # If there was an origin header, check to make sure it matches
+        # according to check_origin
+        if not self.check_origin(origin):
+            self.stream.write(tornado.escape.utf8(
+                "HTTP/1.1 403 Cross Origin Websockets Disabled\r\n\r\n"
+            ))
+            self.stream.close()
+            return
+
         if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
             self.ws_connection = WebSocketProtocol13(self)
             self.ws_connection.accept_connection()
@@ -161,13 +186,6 @@ class WebSocketHandler(tornado.web.RequestHandler):
                 "Sec-WebSocket-Version: 8\r\n\r\n"))
             self.stream.close()
 
-        if not self.check_origin():
-            self.stream.write(tornado.escape.utf8(
-                "HTTP/1.1 403 Cross Origin Websockets Disabled\r\n\r\n"
-            ))
-            self.stream.close()
-
-
 
     def write_message(self, message, binary=False):
         """Sends the given message to the client of this Web Socket.
@@ -264,37 +282,25 @@ class WebSocketHandler(tornado.web.RequestHandler):
             self.ws_connection.close(code, reason)
             self.ws_connection = None
 
-    def check_origin(self, allowed_origins=None):
+    def check_origin(self, origin):
         """Override to enable support for allowing alternate origins.
         
-        By default, this checks to see that requests that provide both a host
-        origin have the same origin and host This is a security protection
-        against cross site scripting attacks on browsers,
-        since WebSockets don't have CORS headers.
+        By default, this checks to see that the host matches the origin
+        provided.
+
+        This is a security protection against cross site scripting attacks on
+        browsers, since WebSockets don't have CORS headers.
         
-        >>> self.check_origins(allowed_origins=['localhost'])
+        >>> self.check_origin(origin='localhost')
+        True
         
         """
-        # Handle WebSocket Origin naming convention differences
-        # The difference between version 8 and 13 is that in 8 the
-        # client sends a "Sec-Websocket-Origin" header and in 13 it's
-        # simply "Origin".
-        if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8"):
-            origin_header = self.request.headers.get("Sec-Websocket-Origin")
-        else:
-            origin_header = self.request.headers.get("Origin")
-
-        host = self.request.headers.get("Host")
-
-        # If no header is provided, assume request is not coming from a browser
-        if(origin_header is None or host is None):
+        # When origin is None, assume it didn't come from a browser and we can
+        # pass it on
+        if origin is None:
             return True
 
-        parsed_origin = urlparse(origin_header)
-        origin = parsed_origin.netloc
-
-        if allowed_origins and origin in allowed_origins:
-            return True
+        host = self.request.headers.get("Host")
 
         # Check to see that origin matches host directly, including ports
         return origin == host