]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add family argument to OverrideResolver->resolve method (#2201)
authorBoris <ttyv00@gmail.com>
Tue, 26 Dec 2017 23:05:28 +0000 (02:05 +0300)
committerBen Darnell <ben@bendarnell.com>
Tue, 26 Dec 2017 23:05:28 +0000 (18:05 -0500)
tornado/netutil.py
tornado/test/netutil_test.py

index caaa09090cc55cc20c10724e2090a2e0d2bff968..45d9e36c0a8e019e20b271bb268541d3e4a151a2 100644 (file)
@@ -464,7 +464,8 @@ class OverrideResolver(Resolver):
     This can be used to make local DNS changes (e.g. for testing)
     without modifying system-wide settings.
 
-    The mapping can contain either host strings or host-port pairs.
+    The mapping can contain either host strings or host-port pairs or
+    host-port-family triplets.
     """
     def initialize(self, resolver, mapping):
         self.resolver = resolver
@@ -473,12 +474,14 @@ class OverrideResolver(Resolver):
     def close(self):
         self.resolver.close()
 
-    def resolve(self, host, port, *args, **kwargs):
-        if (host, port) in self.mapping:
+    def resolve(self, host, port, family=socket.AF_UNSPEC, *args, **kwargs):
+        if (host, port, family) in self.mapping:
+            host, port = self.mapping[(host, port, family)]
+        elif (host, port) in self.mapping:
             host, port = self.mapping[(host, port)]
         elif host in self.mapping:
             host = self.mapping[host]
-        return self.resolver.resolve(host, port, *args, **kwargs)
+        return self.resolver.resolve(host, port, family, *args, **kwargs)
 
 
 # These are the keyword arguments to ssl.wrap_socket that must be translated
index be581764b8bb8564288af7c07711d35f042cb158..fd284dad67bc5e91a5ee7e05c65f1ad22b97e2e6 100644 (file)
@@ -8,7 +8,7 @@ from subprocess import Popen
 import sys
 import time
 
-from tornado.netutil import BlockingResolver, ThreadedResolver, is_valid_ip, bind_sockets
+from tornado.netutil import BlockingResolver, OverrideResolver, ThreadedResolver, is_valid_ip, bind_sockets
 from tornado.stack_context import ExceptionStackContext
 from tornado.testing import AsyncTestCase, gen_test, bind_unused_port
 from tornado.test.util import unittest, skipIfNoNetwork
@@ -96,6 +96,26 @@ class BlockingResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
         super(BlockingResolverErrorTest, self).tearDown()
 
 
+class OverrideResolverTest(AsyncTestCase, _ResolverTestMixin):
+    def setUp(self):
+        super(OverrideResolverTest, self).setUp()
+        mapping = {
+            ('google.com', 80): ('1.2.3.4', 80),
+            ('google.com', 80, socket.AF_INET): ('1.2.3.4', 80),
+            ('google.com', 80, socket.AF_INET6): ('2a02:6b8:7c:40c:c51e:495f:e23a:3', 80)
+        }
+        self.resolver = OverrideResolver(BlockingResolver(), mapping)
+
+    def test_resolve_multiaddr(self):
+        self.resolver.resolve('google.com', 80, socket.AF_INET, callback=self.stop)
+        result = self.wait()
+        self.assertIn((socket.AF_INET, ('1.2.3.4', 80)), result)
+
+        self.resolver.resolve('google.com', 80, socket.AF_INET6, callback=self.stop)
+        result = self.wait()
+        self.assertIn((socket.AF_INET6, ('2a02:6b8:7c:40c:c51e:495f:e23a:3', 80, 0, 0)), result)
+
+
 @skipIfNoNetwork
 @unittest.skipIf(futures is None, "futures module not present")
 class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):