]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Rewrite IP check to use getaddrinfo instead of inet_pton (more portable).
authorBen Darnell <ben@bendarnell.com>
Mon, 2 Jan 2012 06:54:36 +0000 (22:54 -0800)
committerBen Darnell <ben@bendarnell.com>
Mon, 2 Jan 2012 06:54:36 +0000 (22:54 -0800)
Closes #392.

tornado/httpserver.py
tornado/test/httpserver_test.py

index 10723590ab5c25311dbe53ba534204a9494abf5e..e692ba8a9ca34691fb389a1bc6e2cb7fec800635 100644 (file)
@@ -362,7 +362,7 @@ class HTTPRequest(object):
             # Squid uses X-Forwarded-For, others use X-Real-Ip
             self.remote_ip = self.headers.get(
                 "X-Real-Ip", self.headers.get("X-Forwarded-For", remote_ip))
-            if not self.__valid_ip(self.remote_ip):
+            if not self._valid_ip(self.remote_ip):
                 self.remote_ip = remote_ip
             # AWS uses X-Forwarded-Proto
             self.protocol = self.headers.get(
@@ -460,13 +460,15 @@ class HTTPRequest(object):
         return "%s(%s, headers=%s)" % (
             self.__class__.__name__, args, dict(self.headers))
 
-    def __valid_ip(self, ip):
+    def _valid_ip(self, ip):
         try:
-            address = socket.inet_pton(socket.AF_INET, ip)
-        except socket.error:
-            try:
-                address = socket.inet_pton(socket.AF_INET6, ip)
-            except socket.error:
+            res = socket.getaddrinfo(ip, 0, socket.AF_UNSPEC,
+                                     socket.SOCK_STREAM,
+                                     0, socket.AI_NUMERICHOST)
+            return bool(res)
+        except socket.gaierror, e:
+            if e.args[0] == socket.EAI_NONAME:
                 return False
-
+            raise
         return True
+
index 5358ee29d12e264f42e8a7573cb82103fc05c774..036348522476e40e4979790ed9495bcb16e13c9c 100644 (file)
@@ -20,6 +20,15 @@ try:
 except ImportError:
     ssl = None
 
+class HandlerBaseTestCase(AsyncHTTPTestCase, LogTrapTestCase):
+    def get_app(self):
+        return Application([('/', self.__class__.Handler)])
+
+    def fetch_json(self, *args, **kwargs):
+        response = self.fetch(*args, **kwargs)
+        response.rethrow()
+        return json_decode(response.body)
+
 class HelloWorldRequestHandler(RequestHandler):
     def initialize(self, protocol="http"):
         self.expected_protocol = protocol
@@ -236,6 +245,39 @@ class HTTPServerTest(AsyncHTTPTestCase, LogTrapTestCase):
         data = json_decode(response.body)
         self.assertEqual(data, {})
 
+class XHeaderTest(HandlerBaseTestCase):
+    class Handler(RequestHandler):
+        def get(self):
+            self.write(dict(remote_ip=self.request.remote_ip))
+
+    def get_httpserver_options(self):
+        return dict(xheaders=True)
+
+    def test_ip_headers(self):
+        self.assertEqual(self.fetch_json("/")["remote_ip"],
+                         "127.0.0.1")
+
+        valid_ipv4 = {"X-Real-IP": "4.4.4.4"}
+        self.assertEqual(
+            self.fetch_json("/", headers=valid_ipv4)["remote_ip"],
+            "4.4.4.4")
+
+        valid_ipv6 = {"X-Real-IP": "2620:0:1cfe:face:b00c::3"}
+        self.assertEqual(
+            self.fetch_json("/", headers=valid_ipv6)["remote_ip"],
+            "2620:0:1cfe:face:b00c::3")
+
+        invalid_chars = {"X-Real-IP": "4.4.4.4<script>"}
+        self.assertEqual(
+            self.fetch_json("/", headers=invalid_chars)["remote_ip"],
+            "127.0.0.1")
+
+        invalid_host = {"X-Real-IP": "www.google.com"}
+        self.assertEqual(
+            self.fetch_json("/", headers=invalid_host)["remote_ip"],
+            "127.0.0.1")
+
+
 class UnixSocketTest(AsyncTestCase, LogTrapTestCase):
     """HTTPServers can listen on Unix sockets too.