From: schlamar Date: Thu, 23 Jan 2014 11:23:53 +0000 (+0100) Subject: Fixed automatic port allocation in bind_sockets. X-Git-Tag: v4.0.0b1~114^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fpull%2F977%2Fhead;p=thirdparty%2Ftornado.git Fixed automatic port allocation in bind_sockets. bind_sockets should use the same port on IPv4 and IPv6 if port=None. --- diff --git a/tornado/netutil.py b/tornado/netutil.py index 7c89dc511..a4699138e 100644 --- a/tornado/netutil.py +++ b/tornado/netutil.py @@ -77,6 +77,7 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128, flags family = socket.AF_INET if flags is None: flags = socket.AI_PASSIVE + bound_port = None for res in set(socket.getaddrinfo(address, port, family, socket.SOCK_STREAM, 0, flags)): af, socktype, proto, canonname, sockaddr = res @@ -100,8 +101,16 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128, flags # Python 2.x on windows doesn't have IPPROTO_IPV6. if hasattr(socket, "IPPROTO_IPV6"): sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1) + + # automatic port allocation with port=None + # should bind on the same port on IPv4 and IPv6 + host, requested_port = sockaddr[:2] + if requested_port == 0 and bound_port is not None: + sockaddr = tuple([host, bound_port] + list(sockaddr[2:])) + sock.setblocking(0) sock.bind(sockaddr) + bound_port = sock.getsockname()[1] sock.listen(backlog) sockets.append(sock) return sockets diff --git a/tornado/test/netutil_test.py b/tornado/test/netutil_test.py index ea8d51a53..9f707a8d2 100644 --- a/tornado/test/netutil_test.py +++ b/tornado/test/netutil_test.py @@ -6,7 +6,7 @@ from subprocess import Popen import sys import time -from tornado.netutil import BlockingResolver, ThreadedResolver, is_valid_ip +from tornado.netutil import BlockingResolver, ThreadedResolver, is_valid_ip, bind_sockets from tornado.stack_context import ExceptionStackContext from tornado.testing import AsyncTestCase, gen_test from tornado.test.util import unittest @@ -144,3 +144,10 @@ class IsValidIPTest(unittest.TestCase): self.assertTrue(not is_valid_ip(' ')) self.assertTrue(not is_valid_ip('\n')) self.assertTrue(not is_valid_ip('\x00')) + + +class TestPortAllocation(unittest.TestCase): + def test_same_port_allocation(self): + sockets = bind_sockets(None, 'localhost') + port = sockets[0].getsockname()[1] + self.assertTrue(all(s.getsockname()[1] == port for s in sockets[1:]))