]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Simplify the Resolver interface.
authorBen Darnell <ben@bendarnell.com>
Sun, 24 Feb 2013 19:19:23 +0000 (14:19 -0500)
committerBen Darnell <ben@bendarnell.com>
Sun, 24 Feb 2013 19:19:23 +0000 (14:19 -0500)
Most callers do not need the full generality of getaddrinfo, and
most alternative resolver implementations cannot provide it.

tornado/netutil.py
tornado/platform/caresresolver.py
tornado/platform/twisted.py
tornado/simple_httpclient.py
tornado/test/netutil_test.py

index 576120188e1e90ac7ccd44aa79cfcf94562b9a5e..46e4227ecda84bc50ffae5eb631a7a0f4757970c 100644 (file)
@@ -168,17 +168,18 @@ class Resolver(Configurable):
     def configurable_default(cls):
         return BlockingResolver
 
-    def getaddrinfo(self, *args, **kwargs):
+    def resolve(self, host, port, family=socket.AF_UNSPEC, callback=None):
         """Resolves an address.
 
-        The arguments to this function are the same as to
-        `socket.getaddrinfo`, with the addition of an optional
-        keyword-only ``callback`` argument.
+        The ``host`` argument is a string which may be a hostname or a
+        literal IP address.
 
-        Returns a `Future` whose result is the same as the return
-        value of `socket.getaddrinfo`.  If a callback is passed,
-        it will be run with the `Future` as an argument when it
-        is complete.
+        Returns a `Future` whose result is a list of (family, address)
+        pairs, where address is a tuple suitable to pass to
+        `socket.connect` (i.e. a (host, port) pair for IPv4;
+        additional fields may be present for IPv6). If a callback is
+        passed, it will be run with the `Future` as an argument when
+        it is complete.
         """
         raise NotImplementedError()
 
@@ -189,8 +190,12 @@ class ExecutorResolver(Resolver):
         self.executor = executor or dummy_executor
 
     @run_on_executor
-    def getaddrinfo(self, *args, **kwargs):
-        return socket.getaddrinfo(*args, **kwargs)
+    def resolve(self, host, port, family=socket.AF_UNSPEC):
+        addrinfo = socket.getaddrinfo(host, port, family)
+        results = []
+        for family, socktype, proto, canonname, address in addrinfo:
+            results.append((family, address))
+        return results
 
 class BlockingResolver(ExecutorResolver):
     def initialize(self, io_loop=None):
@@ -214,12 +219,12 @@ class OverrideResolver(Resolver):
         self.resolver = resolver
         self.mapping = mapping
 
-    def getaddrinfo(self, host, port, *args, **kwargs):
+    def resolve(self, host, port, *args, **kwargs):
         if (host, port) in self.mapping:
             host, port = self.mapping[(host, port)]
         elif host in self.mapping:
             host = self.mapping[host]
-        return self.resolver.getaddrinfo(host, port, *args, **kwargs)
+        return self.resolver.resolve(host, port, *args, **kwargs)
 
 
 
index fe8cc377c4f723198e5bcb0b9787c5e2b0427db3..83fd86a8cdf4958ffe8cca99868e17e75f7a014e 100644 (file)
@@ -47,8 +47,7 @@ class CaresResolver(Resolver):
 
     @return_future
     @gen.engine
-    def getaddrinfo(self, host, port, family=0, socktype=0, proto=0,
-                    flags=0, callback=None):
+    def resolve(self, host, port, family=0, callback=None):
         if is_valid_ip(host):
             addresses = [host]
         else:
@@ -73,5 +72,5 @@ class CaresResolver(Resolver):
             if family != socket.AF_UNSPEC and family != address_family:
                 raise Exception('Requested socket family %d but got %d' %
                                 (family, address_family))
-            addrinfo.append((address_family, socktype, proto, '', (address, port)))
+            addrinfo.append((address_family, (address, port)))
         callback(addrinfo)
index 419b98832b7ee36739a8674ac2c08af4384f5b1f..eb6c402dcd65e8389c7608b263547272451cf139 100644 (file)
@@ -515,8 +515,7 @@ class TwistedResolver(Resolver):
 
     @return_future
     @gen.engine
-    def getaddrinfo(self, host, port, family=0, socktype=0, proto=0,
-                    flags=0, callback=None):
+    def resolve(self, host, port, family=0, callback=None):
         # getHostByName doesn't accept IP addresses, so if the input
         # looks like an IP address just return it immediately.
         if twisted.internet.abstract.isIPAddress(host):
@@ -538,6 +537,6 @@ class TwistedResolver(Resolver):
             raise Exception('Requested socket family %d but got %d' %
                             (family, resolved_family))
         result = [
-            (resolved_family, socktype, proto, '', (resolved, port)),
+            (resolved_family, (resolved, port)),
             ]
         self.io_loop.add_callback(callback, result)
index f33ed242efd37ca4f81662a223c9801b48f1e7aa..92021f2663c3ec29eac96d1e65397c2aae9981de 100644 (file)
@@ -150,12 +150,10 @@ class _HTTPConnection(object):
                 # so restrict to ipv4 by default.
                 af = socket.AF_INET
 
-            self.resolver.getaddrinfo(
-                host, port, af, socket.SOCK_STREAM, 0, 0,
-                callback=self._on_resolve)
+            self.resolver.resolve(host, port, af, callback=self._on_resolve)
 
     def _on_resolve(self, future):
-        af, socktype, proto, canonname, sockaddr = future.result()[0]
+        af, sockaddr = future.result()[0]
 
         if self.parsed.scheme == "https":
             ssl_options = {}
@@ -189,12 +187,12 @@ class _HTTPConnection(object):
                 # information.
                 ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3
 
-            self.stream = SSLIOStream(socket.socket(af, socktype, proto),
+            self.stream = SSLIOStream(socket.socket(af),
                                       io_loop=self.io_loop,
                                       ssl_options=ssl_options,
                                       max_buffer_size=self.max_buffer_size)
         else:
-            self.stream = IOStream(socket.socket(af, socktype, proto),
+            self.stream = IOStream(socket.socket(af),
                                    io_loop=self.io_loop,
                                    max_buffer_size=self.max_buffer_size)
         timeout = min(self.request.connect_timeout, self.request.request_timeout)
index 2ba5d8abf10a8d42ee61da821cb017007b2309b8..b31e1951943bfaa3251873b08330670d9a461a56 100644 (file)
@@ -14,29 +14,17 @@ except ImportError:
 
 class _ResolverTestMixin(object):
     def test_localhost(self):
-        # Note that windows returns IPPROTO_IP unless we specifically
-        # ask for IPPROTO_TCP (either will work to create a socket,
-        # but this test looks for an exact match)
-        self.resolver.getaddrinfo('localhost', 80, socket.AF_UNSPEC,
-                                  socket.SOCK_STREAM,
-                                  socket.IPPROTO_TCP,
-                                  callback=self.stop)
+        self.resolver.resolve('localhost', 80, callback=self.stop)
         future = self.wait()
-        self.assertIn(
-            (socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, '',
-             ('127.0.0.1', 80)),
-            future.result())
+        self.assertIn((socket.AF_INET, ('127.0.0.1', 80)),
+                      future.result())
 
     @gen_test
     def test_future_interface(self):
-        addrinfo = yield self.resolver.getaddrinfo(
-            'localhost', 80, socket.AF_UNSPEC,
-            socket.SOCK_STREAM, socket.IPPROTO_TCP)
-        self.assertIn(
-            (socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, '',
-             ('127.0.0.1', 80)),
-            addrinfo)
-
+        addrinfo = yield self.resolver.resolve('localhost', 80,
+                                               socket.AF_UNSPEC)
+        self.assertIn((socket.AF_INET, ('127.0.0.1', 80)),
+                      addrinfo)
 
 
 class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):