"Sec-WebSocket-Version: 8\r\n\r\n"))
self.stream.close()
- # Assume cross origin is disallowed by default, while allowing users to
- # choose
- if kwargs.get('allow_cross_origin', False):
- pass
- # Check that the host and origin match
- elif not self.same_origin():
+ if not self.check_origin():
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 403 Cross Origin Websockets Disabled\r\n\r\n"
))
self.stream.close()
- def same_origin(self):
- """Check to see that origin and host match in the headers."""
- # The difference between version 8 and 13 is that in 8 the
- # client sends a "Sec-Websocket-Origin" header and in 13 it's
- # simply "Origin".
- if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8"):
- origin_header = self.request.headers.get("Sec-Websocket-Origin")
- else:
- origin_header = self.request.headers.get("Origin")
-
- host = self.request.headers.get("Host")
-
- # If no header is provided, assume we can't verify origin
- if(origin_header is None or host is None):
- return False
-
- parsed_origin = urlparse(origin_header)
- origin = parsed_origin.netloc
-
- # Check to see that origin matches host directly, including ports
- return origin == host
def write_message(self, message, binary=False):
"""Sends the given message to the client of this Web Socket.
self.ws_connection.close(code, reason)
self.ws_connection = None
+ def check_origin(self, allowed_origins=None):
+ """Override to enable support for allowing alternate origins.
+
+ By default, this checks to see that requests that provide both a host
+ origin have the same origin and host This is a security protection
+ against cross site scripting attacks on browsers,
+ since WebSockets don't have CORS headers."""
+
+ # Handle WebSocket Origin naming convention differences
+ # The difference between version 8 and 13 is that in 8 the
+ # client sends a "Sec-Websocket-Origin" header and in 13 it's
+ # simply "Origin".
+ if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8"):
+ origin_header = self.request.headers.get("Sec-Websocket-Origin")
+ else:
+ origin_header = self.request.headers.get("Origin")
+
+ host = self.request.headers.get("Host")
+
+ # If no header is provided, assume request is not coming from a browser
+ if(origin_header is None or host is None):
+ return True
+
+ parsed_origin = urlparse(origin_header)
+ origin = parsed_origin.netloc
+
+ if origin in allowed_origins:
+ return True
+
+ # Check to see that origin matches host directly, including ports
+ return origin == host
+
def allow_draft76(self):
"""Override to enable support for the older "draft76" protocol.