]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Allow preconstructed HTTPRequest objects in websocket_connect.
authorBen Darnell <ben@bendarnell.com>
Sun, 8 Sep 2013 01:14:26 +0000 (21:14 -0400)
committerBen Darnell <ben@bendarnell.com>
Sun, 8 Sep 2013 01:14:26 +0000 (21:14 -0400)
In particular this allows for headers to be passed in to simulate
browser authentication behavior.

tornado/test/websocket_test.py
tornado/websocket.py

index 0c5a474790b6bdd708699d308c6593705750bc62..7dc06c5d261ceb9cf739cf8a44ef2cb79962ea14 100644 (file)
@@ -1,6 +1,5 @@
 from tornado.concurrent import Future
-from tornado import gen
-from tornado.httpclient import HTTPError
+from tornado.httpclient import HTTPError, HTTPRequest
 from tornado.log import gen_log
 from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
 from tornado.web import Application, RequestHandler
@@ -18,6 +17,11 @@ class EchoHandler(WebSocketHandler):
         self.close_future.set_result(None)
 
 
+class HeaderHandler(WebSocketHandler):
+    def open(self):
+        self.write_message(self.request.headers.get('X-Test', ''))
+
+
 class NonWebSocketHandler(RequestHandler):
     def get(self):
         self.write('ok')
@@ -29,6 +33,7 @@ class WebSocketTest(AsyncHTTPTestCase):
         return Application([
             ('/echo', EchoHandler, dict(close_future=self.close_future)),
             ('/non_ws', NonWebSocketHandler),
+            ('/header', HeaderHandler),
         ])
 
     @gen_test
@@ -85,3 +90,12 @@ class WebSocketTest(AsyncHTTPTestCase):
         ws.write_message('world')
         ws.stream.close()
         yield self.close_future
+
+    @gen_test
+    def test_websocket_headers(self):
+        # Ensure that arbitrary headers can be passed through websocket_connect.
+        ws = yield websocket_connect(
+            HTTPRequest('ws://localhost:%d/header' % self.get_http_port(),
+                        headers={'X-Test': 'hello'}))
+        response = yield ws.read_message()
+        self.assertEqual(response, 'hello')
index 676d21bf86a8b73bf2f500957eeecc7340367ee0..cd1ca7f51d59c9d7d618f9e78d071ef0dd7c8ccd 100644 (file)
@@ -33,7 +33,7 @@ import tornado.web
 
 from tornado.concurrent import TracebackFuture
 from tornado.escape import utf8, native_str
-from tornado import httpclient
+from tornado import httpclient, httputil
 from tornado.ioloop import IOLoop
 from tornado.iostream import StreamClosedError
 from tornado.log import gen_log, app_log
@@ -862,7 +862,14 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None):
     """
     if io_loop is None:
         io_loop = IOLoop.current()
-    request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
+    if isinstance(url, httpclient.HTTPRequest):
+        assert connect_timeout is None
+        request = url
+        # Copy and convert the headers dict/object (see comments in
+        # AsyncHTTPClient.fetch)
+        request.headers = httputil.HTTPHeaders(request.headers)
+    else:
+        request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
     request = httpclient._RequestProxy(
         request, httpclient.HTTPRequest._DEFAULTS)
     conn = WebSocketClientConnection(io_loop, request)