]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
TCPClient: connect using specific source IP address
authorMircea Ulinic <mirucha@cloudflare.com>
Wed, 25 Jan 2017 11:45:31 +0000 (11:45 +0000)
committerMircea Ulinic <mirucha@cloudflare.com>
Wed, 25 Jan 2017 12:04:13 +0000 (12:04 +0000)
tornado/iostream.py
tornado/tcpclient.py

index bcf444148c0ff2b639571ccb8796d569d8c7bf53..458a6e1189774348cb0fe69f3efa5fab8709ff9d 100644 (file)
@@ -1023,7 +1023,8 @@ class IOStream(BaseIOStream):
     def write_to_fd(self, data):
         return self.socket.send(data)
 
-    def connect(self, address, callback=None, server_hostname=None):
+    def connect(self, address, callback=None, server_hostname=None,
+                source_ip=None):
         """Connects the socket to a remote address without blocking.
 
         May only be called if the socket passed to the constructor was
@@ -1047,6 +1048,9 @@ class IOStream(BaseIOStream):
         ``ssl_options``) and SNI (if supported; requires Python
         2.7.9+).
 
+        If ``source_ip`` is specified, will try to use a certain source
+        IP address to establish the connection.
+
         Note that it is safe to call `IOStream.write
         <BaseIOStream.write>` while the connection is pending, in
         which case the data will be written as soon as the connection
@@ -1068,6 +1072,17 @@ class IOStream(BaseIOStream):
             future = None
         else:
             future = self._connect_future = TracebackFuture()
+        if source_ip:
+            try:
+                self.socket.bind((source_ip, 0))
+                # port = 0, will not bind to a specific port
+            except socket.error as e:
+                gen_log.error("Unable to use the source IP %s",
+                                    source_ip)
+                gen_log.error("%s: %s", self.socket.fileno(), e)
+                # log the error and move on
+                # will try to connect using the loopback
+                gen_log.warning("Using the loopback IP address as source.")
         try:
             self.socket.connect(address)
         except socket.error as e:
@@ -1346,7 +1361,9 @@ class SSLIOStream(IOStream):
             return
         super(SSLIOStream, self)._handle_write()
 
-    def connect(self, address, callback=None, server_hostname=None):
+    def connect(self, address, callback=None, server_hostname=None,
+                source_ip=None):
+        # source_ip not used here
         self._server_hostname = server_hostname
         # Pass a dummy callback to super.connect(), which is slightly
         # more efficient than letting it return a Future we ignore.
index 111468607939fac7bd5bd1716fc46bc9663602c0..0b94da23deaecce3d0d2718e8468030c07fc0740 100644 (file)
@@ -47,7 +47,7 @@ class _Connector(object):
     http://tools.ietf.org/html/rfc6555
 
     """
-    def __init__(self, addrinfo, io_loop, connect):
+    def __init__(self, addrinfo, io_loop, connect, source_ip=None):
         self.io_loop = io_loop
         self.connect = connect
 
@@ -56,6 +56,7 @@ class _Connector(object):
         self.last_error = None
         self.remaining = len(addrinfo)
         self.primary_addrs, self.secondary_addrs = self.split(addrinfo)
+        self.source_ip = source_ip
 
     @staticmethod
     def split(addrinfo):
@@ -93,7 +94,7 @@ class _Connector(object):
                 self.future.set_exception(self.last_error or
                                           IOError("connection failed"))
             return
-        future = self.connect(af, addr)
+        future = self.connect(af, addr, source_ip=self.source_ip)
         future.add_done_callback(functools.partial(self.on_connect_done,
                                                    addrs, af, addr))
 
@@ -155,16 +156,23 @@ class TCPClient(object):
 
     @gen.coroutine
     def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None,
-                max_buffer_size=None):
+                max_buffer_size=None, source_ip=None):
         """Connect to the given host and port.
 
         Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
         ``ssl_options`` is not None).
+
+        source_ip
+            Specify the source IP address to use when establishing
+            the connection. In case the user needs to resolve and
+            use a specific interface, it has to be handled outside
+            of Tornado as this depneds very much on the platform.
         """
         addrinfo = yield self.resolver.resolve(host, port, af)
         connector = _Connector(
             addrinfo, self.io_loop,
-            functools.partial(self._create_stream, max_buffer_size))
+            functools.partial(self._create_stream, max_buffer_size),
+            source_ip=source_ip)
         af, addr, stream = yield connector.start()
         # TODO: For better performance we could cache the (af, addr)
         # information here and re-use it on subsequent connections to
@@ -174,16 +182,16 @@ class TCPClient(object):
                                             server_hostname=host)
         raise gen.Return(stream)
 
-    def _create_stream(self, max_buffer_size, af, addr):
+    def _create_stream(self, max_buffer_size, af, addr, source_ip=None):
         # Always connect in plaintext; we'll convert to ssl if necessary
         # after one connection has completed.
         try:
             stream = IOStream(socket.socket(af),
-                            io_loop=self.io_loop,
-                            max_buffer_size=max_buffer_size)
+                              io_loop=self.io_loop,
+                              max_buffer_size=max_buffer_size)
         except socket.error as e:
             fu = Future()
             fu.set_exception(e)
             return fu
         else:
-            return stream.connect(addr)
+            return stream.connect(addr, source_ip=source_ip)