From: Evan Jones Date: Tue, 12 Mar 2013 20:23:52 +0000 (-0400) Subject: httpserver: If no X-Scheme header, use the normal request value. X-Git-Tag: v3.0.0~37^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fpull%2F695%2Fhead;p=thirdparty%2Ftornado.git httpserver: If no X-Scheme header, use the normal request value. Previously, if xheaders is True and there are no X headers passed (e.g. when developing locally), scheme was always "http". This makes in "http" or "https", based on what was actually used for the request. Add tests for the X-Scheme and X-Forwarded-Proto headers. --- diff --git a/tornado/httpserver.py b/tornado/httpserver.py index 7c255f7a1..13ed97cb4 100644 --- a/tornado/httpserver.py +++ b/tornado/httpserver.py @@ -393,26 +393,31 @@ class HTTPRequest(object): self.version = version self.headers = headers or httputil.HTTPHeaders() self.body = body or "" + + # set remote IP and protocol + self.remote_ip = remote_ip + if protocol: + self.protocol = protocol + elif connection and isinstance(connection.stream, + iostream.SSLIOStream): + self.protocol = "https" + else: + self.protocol = "http" + + # xheaders can override the defaults if connection and connection.xheaders: # Squid uses X-Forwarded-For, others use X-Real-Ip - self.remote_ip = self.headers.get( - "X-Real-Ip", self.headers.get("X-Forwarded-For", remote_ip)) - if not netutil.is_valid_ip(self.remote_ip): - self.remote_ip = remote_ip + ip = self.headers.get( + "X-Real-Ip", self.headers.get("X-Forwarded-For", self.remote_ip)) + if netutil.is_valid_ip(ip): + self.remote_ip = ip # AWS uses X-Forwarded-Proto - self.protocol = self.headers.get( - "X-Scheme", self.headers.get("X-Forwarded-Proto", protocol)) - if self.protocol not in ("http", "https"): - self.protocol = "http" - else: - self.remote_ip = remote_ip - if protocol: - self.protocol = protocol - elif connection and isinstance(connection.stream, - iostream.SSLIOStream): - self.protocol = "https" - else: - self.protocol = "http" + proto = self.headers.get( + "X-Scheme", self.headers.get("X-Forwarded-Proto", self.protocol)) + if proto in ("http", "https"): + self.protocol = proto + + self.host = host or self.headers.get("Host") or "127.0.0.1" self.files = files or {} self.connection = connection diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py index b3a771c36..123ae2e70 100644 --- a/tornado/test/httpserver_test.py +++ b/tornado/test/httpserver_test.py @@ -346,14 +346,14 @@ class HTTPServerTest(AsyncHTTPTestCase): class XHeaderTest(HandlerBaseTestCase): class Handler(RequestHandler): def get(self): - self.write(dict(remote_ip=self.request.remote_ip)) + self.write(dict(remote_ip=self.request.remote_ip, + remote_protocol=self.request.protocol)) def get_httpserver_options(self): return dict(xheaders=True) def test_ip_headers(self): - self.assertEqual(self.fetch_json("/")["remote_ip"], - "127.0.0.1") + self.assertEqual(self.fetch_json("/")["remote_ip"], "127.0.0.1") valid_ipv4 = {"X-Real-IP": "4.4.4.4"} self.assertEqual( @@ -375,6 +375,45 @@ class XHeaderTest(HandlerBaseTestCase): self.fetch_json("/", headers=invalid_host)["remote_ip"], "127.0.0.1") + def test_scheme_headers(self): + self.assertEqual(self.fetch_json("/")["remote_protocol"], "http") + + https_scheme = {"X-Scheme": "https"} + self.assertEqual( + self.fetch_json("/", headers=https_scheme)["remote_protocol"], + "https") + + https_forwarded = {"X-Forwarded-Proto": "https"} + self.assertEqual( + self.fetch_json("/", headers=https_forwarded)["remote_protocol"], + "https") + + bad_forwarded = {"X-Forwarded-Proto": "unknown"} + self.assertEqual( + self.fetch_json("/", headers=bad_forwarded)["remote_protocol"], + "http") + + +class SSLXHeaderTest(AsyncHTTPSTestCase, HandlerBaseTestCase): + def get_app(self): + return Application([('/', XHeaderTest.Handler)]) + + def get_httpserver_options(self): + output = super(SSLXHeaderTest, self).get_httpserver_options() + output['xheaders'] = True + return output + + def test_request_without_xprotocol(self): + self.assertEqual(self.fetch_json("/")["remote_protocol"], "https") + + http_scheme = {"X-Scheme": "http"} + self.assertEqual( + self.fetch_json("/", headers=http_scheme)["remote_protocol"], "http") + + bad_scheme = {"X-Scheme": "unknown"} + self.assertEqual( + self.fetch_json("/", headers=bad_scheme)["remote_protocol"], "https") + class ManualProtocolTest(HandlerBaseTestCase): class Handler(RequestHandler):