]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-98388: add tests for happy eyeballs (#136368)
authorKumar Aditya <kumaraditya@python.org>
Mon, 7 Jul 2025 18:00:27 +0000 (23:30 +0530)
committerGitHub <noreply@github.com>
Mon, 7 Jul 2025 18:00:27 +0000 (23:30 +0530)
Lib/test/test_asyncio/test_base_events.py

index 12179eb0c9e27447731874a725b606e03b5a2b9c..22ae0ef35811564c6a942f96dbdd10dae1130415 100644 (file)
@@ -150,6 +150,29 @@ class BaseEventTests(test_utils.TestCase):
                                                    socket.SOCK_STREAM,
                                                    socket.IPPROTO_TCP))
 
+    def test_interleave_addrinfos(self):
+        self.maxDiff = None
+        SIX_A = (socket.AF_INET6, 0, 0, '', ('2001:db8::1', 1))
+        SIX_B = (socket.AF_INET6, 0, 0, '', ('2001:db8::2', 2))
+        SIX_C = (socket.AF_INET6, 0, 0, '', ('2001:db8::3', 3))
+        SIX_D = (socket.AF_INET6, 0, 0, '', ('2001:db8::4', 4))
+        FOUR_A = (socket.AF_INET, 0, 0, '', ('192.0.2.1', 5))
+        FOUR_B = (socket.AF_INET, 0, 0, '', ('192.0.2.2', 6))
+        FOUR_C = (socket.AF_INET, 0, 0, '', ('192.0.2.3', 7))
+        FOUR_D = (socket.AF_INET, 0, 0, '', ('192.0.2.4', 8))
+
+        addrinfos = [SIX_A, SIX_B, SIX_C, FOUR_A, FOUR_B, FOUR_C, FOUR_D, SIX_D]
+        expected = [SIX_A, FOUR_A, SIX_B, FOUR_B, SIX_C, FOUR_C, SIX_D, FOUR_D]
+
+        self.assertEqual(expected, base_events._interleave_addrinfos(addrinfos))
+
+        expected_fafc_2 = [SIX_A, SIX_B, FOUR_A, SIX_C, FOUR_B, SIX_D, FOUR_C, FOUR_D]
+        self.assertEqual(
+            expected_fafc_2,
+            base_events._interleave_addrinfos(addrinfos, first_address_family_count=2),
+        )
+
+
 
 class BaseEventLoopTests(test_utils.TestCase):
 
@@ -1053,6 +1076,71 @@ class BaseEventLoopTests(test_utils.TestCase):
             test_utils.run_briefly(self.loop)
             self.assertTrue(status['finalized'])
 
+    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'no IPv6 support')
+    @patch_socket
+    def test_create_connection_happy_eyeballs(self, m_socket):
+
+        class MyProto(asyncio.Protocol):
+            pass
+
+        async def getaddrinfo(*args, **kw):
+            return [(socket.AF_INET6, 0, 0, '', ('2001:db8::1', 1)),
+                    (socket.AF_INET, 0, 0, '', ('192.0.2.1', 5))]
+
+        async def sock_connect(sock, address):
+            if address[0] == '2001:db8::1':
+                await asyncio.sleep(1)
+            sock.connect(address)
+
+        loop = asyncio.new_event_loop()
+        loop._add_writer = mock.Mock()
+        loop._add_writer = mock.Mock()
+        loop._add_reader = mock.Mock()
+        loop.getaddrinfo = getaddrinfo
+        loop.sock_connect = sock_connect
+
+        coro = loop.create_connection(MyProto, 'example.com', 80, happy_eyeballs_delay=0.3)
+        transport, protocol = loop.run_until_complete(coro)
+        try:
+            sock = transport._sock
+            sock.connect.assert_called_with(('192.0.2.1', 5))
+        finally:
+            transport.close()
+            test_utils.run_briefly(loop)  # allow transport to close
+            loop.close()
+
+    @patch_socket
+    def test_create_connection_happy_eyeballs_ipv4_only(self, m_socket):
+
+        class MyProto(asyncio.Protocol):
+            pass
+
+        async def getaddrinfo(*args, **kw):
+            return [(socket.AF_INET, 0, 0, '', ('192.0.2.1', 5)),
+                    (socket.AF_INET, 0, 0, '', ('192.0.2.2', 6))]
+
+        async def sock_connect(sock, address):
+            if address[0] == '192.0.2.1':
+                await asyncio.sleep(1)
+            sock.connect(address)
+
+        loop = asyncio.new_event_loop()
+        loop._add_writer = mock.Mock()
+        loop._add_writer = mock.Mock()
+        loop._add_reader = mock.Mock()
+        loop.getaddrinfo = getaddrinfo
+        loop.sock_connect = sock_connect
+
+        coro = loop.create_connection(MyProto, 'example.com', 80, happy_eyeballs_delay=0.3)
+        transport, protocol = loop.run_until_complete(coro)
+        try:
+            sock = transport._sock
+            sock.connect.assert_called_with(('192.0.2.2', 6))
+        finally:
+            transport.close()
+            test_utils.run_briefly(loop)  # allow transport to close
+            loop.close()
+
 
 class MyProto(asyncio.Protocol):
     done = None