]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
httpserver: If no X-Scheme header, use the normal request value. 695/head
authorEvan Jones <ej@evanjones.ca>
Tue, 12 Mar 2013 20:23:52 +0000 (16:23 -0400)
committerEvan Jones <ej@evanjones.ca>
Tue, 12 Mar 2013 20:23:52 +0000 (16:23 -0400)
Previously, if xheaders is True and there are no X headers passed (e.g. when
developing locally), scheme was always "http". This makes in "http" or
"https", based on what was actually used for the request.

Add tests for the X-Scheme and X-Forwarded-Proto headers.

tornado/httpserver.py
tornado/test/httpserver_test.py

index 7c255f7a156c1bd437038d990ffa6bbf46e7a8a6..13ed97cb4e1ba1052d6c4e356b3253b1c276fe14 100644 (file)
@@ -393,26 +393,31 @@ class HTTPRequest(object):
         self.version = version
         self.headers = headers or httputil.HTTPHeaders()
         self.body = body or ""
+
+        # set remote IP and protocol
+        self.remote_ip = remote_ip
+        if protocol:
+            self.protocol = protocol
+        elif connection and isinstance(connection.stream,
+                                       iostream.SSLIOStream):
+            self.protocol = "https"
+        else:
+            self.protocol = "http"
+
+        # xheaders can override the defaults
         if connection and connection.xheaders:
             # Squid uses X-Forwarded-For, others use X-Real-Ip
-            self.remote_ip = self.headers.get(
-                "X-Real-Ip", self.headers.get("X-Forwarded-For", remote_ip))
-            if not netutil.is_valid_ip(self.remote_ip):
-                self.remote_ip = remote_ip
+            ip = self.headers.get(
+                "X-Real-Ip", self.headers.get("X-Forwarded-For", self.remote_ip))
+            if netutil.is_valid_ip(ip):
+                self.remote_ip = ip
             # AWS uses X-Forwarded-Proto
-            self.protocol = self.headers.get(
-                "X-Scheme", self.headers.get("X-Forwarded-Proto", protocol))
-            if self.protocol not in ("http", "https"):
-                self.protocol = "http"
-        else:
-            self.remote_ip = remote_ip
-            if protocol:
-                self.protocol = protocol
-            elif connection and isinstance(connection.stream,
-                                           iostream.SSLIOStream):
-                self.protocol = "https"
-            else:
-                self.protocol = "http"
+            proto = self.headers.get(
+                "X-Scheme", self.headers.get("X-Forwarded-Proto", self.protocol))
+            if proto in ("http", "https"):
+                self.protocol = proto
+
+
         self.host = host or self.headers.get("Host") or "127.0.0.1"
         self.files = files or {}
         self.connection = connection
index b3a771c36b00d8dca93f027ecc98c1aa59fa5ca6..123ae2e70eb5bc8c6f02f03964067ec7c34d7e3d 100644 (file)
@@ -346,14 +346,14 @@ class HTTPServerTest(AsyncHTTPTestCase):
 class XHeaderTest(HandlerBaseTestCase):
     class Handler(RequestHandler):
         def get(self):
-            self.write(dict(remote_ip=self.request.remote_ip))
+            self.write(dict(remote_ip=self.request.remote_ip,
+                remote_protocol=self.request.protocol))
 
     def get_httpserver_options(self):
         return dict(xheaders=True)
 
     def test_ip_headers(self):
-        self.assertEqual(self.fetch_json("/")["remote_ip"],
-                         "127.0.0.1")
+        self.assertEqual(self.fetch_json("/")["remote_ip"], "127.0.0.1")
 
         valid_ipv4 = {"X-Real-IP": "4.4.4.4"}
         self.assertEqual(
@@ -375,6 +375,45 @@ class XHeaderTest(HandlerBaseTestCase):
             self.fetch_json("/", headers=invalid_host)["remote_ip"],
             "127.0.0.1")
 
+    def test_scheme_headers(self):
+        self.assertEqual(self.fetch_json("/")["remote_protocol"], "http")
+
+        https_scheme = {"X-Scheme": "https"}
+        self.assertEqual(
+            self.fetch_json("/", headers=https_scheme)["remote_protocol"],
+            "https")
+
+        https_forwarded = {"X-Forwarded-Proto": "https"}
+        self.assertEqual(
+            self.fetch_json("/", headers=https_forwarded)["remote_protocol"],
+            "https")
+
+        bad_forwarded = {"X-Forwarded-Proto": "unknown"}
+        self.assertEqual(
+            self.fetch_json("/", headers=bad_forwarded)["remote_protocol"],
+            "http")
+
+
+class SSLXHeaderTest(AsyncHTTPSTestCase, HandlerBaseTestCase):
+    def get_app(self):
+        return Application([('/', XHeaderTest.Handler)])
+
+    def get_httpserver_options(self):
+        output = super(SSLXHeaderTest, self).get_httpserver_options()
+        output['xheaders'] = True
+        return output
+
+    def test_request_without_xprotocol(self):
+        self.assertEqual(self.fetch_json("/")["remote_protocol"], "https")
+
+        http_scheme = {"X-Scheme": "http"}
+        self.assertEqual(
+            self.fetch_json("/", headers=http_scheme)["remote_protocol"], "http")
+
+        bad_scheme = {"X-Scheme": "unknown"}
+        self.assertEqual(
+            self.fetch_json("/", headers=bad_scheme)["remote_protocol"], "https")
+
 
 class ManualProtocolTest(HandlerBaseTestCase):
     class Handler(RequestHandler):