from tornado.concurrent import Future
-from tornado import gen
-from tornado.httpclient import HTTPError
+from tornado.httpclient import HTTPError, HTTPRequest
from tornado.log import gen_log
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
from tornado.web import Application, RequestHandler
self.close_future.set_result(None)
+class HeaderHandler(WebSocketHandler):
+ def open(self):
+ self.write_message(self.request.headers.get('X-Test', ''))
+
+
class NonWebSocketHandler(RequestHandler):
def get(self):
self.write('ok')
return Application([
('/echo', EchoHandler, dict(close_future=self.close_future)),
('/non_ws', NonWebSocketHandler),
+ ('/header', HeaderHandler),
])
@gen_test
ws.write_message('world')
ws.stream.close()
yield self.close_future
+
+ @gen_test
+ def test_websocket_headers(self):
+ # Ensure that arbitrary headers can be passed through websocket_connect.
+ ws = yield websocket_connect(
+ HTTPRequest('ws://localhost:%d/header' % self.get_http_port(),
+ headers={'X-Test': 'hello'}))
+ response = yield ws.read_message()
+ self.assertEqual(response, 'hello')
from tornado.concurrent import TracebackFuture
from tornado.escape import utf8, native_str
-from tornado import httpclient
+from tornado import httpclient, httputil
from tornado.ioloop import IOLoop
from tornado.iostream import StreamClosedError
from tornado.log import gen_log, app_log
"""
if io_loop is None:
io_loop = IOLoop.current()
- request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
+ if isinstance(url, httpclient.HTTPRequest):
+ assert connect_timeout is None
+ request = url
+ # Copy and convert the headers dict/object (see comments in
+ # AsyncHTTPClient.fetch)
+ request.headers = httputil.HTTPHeaders(request.headers)
+ else:
+ request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
request = httpclient._RequestProxy(
request, httpclient.HTTPRequest._DEFAULTS)
conn = WebSocketClientConnection(io_loop, request)