From: Ben Darnell Date: Sun, 8 Sep 2013 01:14:26 +0000 (-0400) Subject: Allow preconstructed HTTPRequest objects in websocket_connect. X-Git-Tag: v3.2.0b1~85 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b5ec807edc83c8e7d1d12553d635ebe765e5c614;p=thirdparty%2Ftornado.git Allow preconstructed HTTPRequest objects in websocket_connect. In particular this allows for headers to be passed in to simulate browser authentication behavior. --- diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index 0c5a47479..7dc06c5d2 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -1,6 +1,5 @@ from tornado.concurrent import Future -from tornado import gen -from tornado.httpclient import HTTPError +from tornado.httpclient import HTTPError, HTTPRequest from tornado.log import gen_log from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog from tornado.web import Application, RequestHandler @@ -18,6 +17,11 @@ class EchoHandler(WebSocketHandler): self.close_future.set_result(None) +class HeaderHandler(WebSocketHandler): + def open(self): + self.write_message(self.request.headers.get('X-Test', '')) + + class NonWebSocketHandler(RequestHandler): def get(self): self.write('ok') @@ -29,6 +33,7 @@ class WebSocketTest(AsyncHTTPTestCase): return Application([ ('/echo', EchoHandler, dict(close_future=self.close_future)), ('/non_ws', NonWebSocketHandler), + ('/header', HeaderHandler), ]) @gen_test @@ -85,3 +90,12 @@ class WebSocketTest(AsyncHTTPTestCase): ws.write_message('world') ws.stream.close() yield self.close_future + + @gen_test + def test_websocket_headers(self): + # Ensure that arbitrary headers can be passed through websocket_connect. + ws = yield websocket_connect( + HTTPRequest('ws://localhost:%d/header' % self.get_http_port(), + headers={'X-Test': 'hello'})) + response = yield ws.read_message() + self.assertEqual(response, 'hello') diff --git a/tornado/websocket.py b/tornado/websocket.py index 676d21bf8..cd1ca7f51 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -33,7 +33,7 @@ import tornado.web from tornado.concurrent import TracebackFuture from tornado.escape import utf8, native_str -from tornado import httpclient +from tornado import httpclient, httputil from tornado.ioloop import IOLoop from tornado.iostream import StreamClosedError from tornado.log import gen_log, app_log @@ -862,7 +862,14 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None): """ if io_loop is None: io_loop = IOLoop.current() - request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout) + if isinstance(url, httpclient.HTTPRequest): + assert connect_timeout is None + request = url + # Copy and convert the headers dict/object (see comments in + # AsyncHTTPClient.fetch) + request.headers = httputil.HTTPHeaders(request.headers) + else: + request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout) request = httpclient._RequestProxy( request, httpclient.HTTPRequest._DEFAULTS) conn = WebSocketClientConnection(io_loop, request)