]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-93357: Start porting asyncio server test cases to IsolatedAsyncioTestCase (#93369)
authorOleg Iarygin <oleg@arhadthedev.net>
Tue, 4 Oct 2022 17:56:47 +0000 (21:56 +0400)
committerGitHub <noreply@github.com>
Tue, 4 Oct 2022 17:56:47 +0000 (10:56 -0700)
Lay the foundation for further work in `asyncio.test_streams`.

Lib/test/test_asyncio/test_streams.py

index 0c49099bc499a58298e5121134e364f68c271971..d1f8aef4bb9cbdadabe2d7c8ae3dcc12691e3c58 100644 (file)
@@ -566,46 +566,10 @@ class StreamTests(test_utils.TestCase):
         test_utils.run_briefly(self.loop)
         self.assertIs(stream._waiter, None)
 
-    def test_start_server(self):
-
-        class MyServer:
-
-            def __init__(self, loop):
-                self.server = None
-                self.loop = loop
-
-            async def handle_client(self, client_reader, client_writer):
-                data = await client_reader.readline()
-                client_writer.write(data)
-                await client_writer.drain()
-                client_writer.close()
-                await client_writer.wait_closed()
-
-            def start(self):
-                sock = socket.create_server(('127.0.0.1', 0))
-                self.server = self.loop.run_until_complete(
-                    asyncio.start_server(self.handle_client,
-                                         sock=sock))
-                return sock.getsockname()
-
-            def handle_client_callback(self, client_reader, client_writer):
-                self.loop.create_task(self.handle_client(client_reader,
-                                                         client_writer))
-
-            def start_callback(self):
-                sock = socket.create_server(('127.0.0.1', 0))
-                addr = sock.getsockname()
-                sock.close()
-                self.server = self.loop.run_until_complete(
-                    asyncio.start_server(self.handle_client_callback,
-                                         host=addr[0], port=addr[1]))
-                return addr
-
-            def stop(self):
-                if self.server is not None:
-                    self.server.close()
-                    self.loop.run_until_complete(self.server.wait_closed())
-                    self.server = None
+
+class NewStreamTests(unittest.IsolatedAsyncioTestCase):
+
+    async def test_start_server(self):
 
         async def client(addr):
             reader, writer = await asyncio.open_connection(*addr)
@@ -617,61 +581,43 @@ class StreamTests(test_utils.TestCase):
             await writer.wait_closed()
             return msgback
 
-        messages = []
-        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
-
-        # test the server variant with a coroutine as client handler
-        server = MyServer(self.loop)
-        addr = server.start()
-        msg = self.loop.run_until_complete(self.loop.create_task(client(addr)))
-        server.stop()
-        self.assertEqual(msg, b"hello world!\n")
+        async def handle_client(client_reader, client_writer):
+            data = await client_reader.readline()
+            client_writer.write(data)
+            await client_writer.drain()
+            client_writer.close()
+            await client_writer.wait_closed()
+
+        with self.subTest(msg="coroutine"):
+            server = await asyncio.start_server(
+                handle_client,
+                host=socket_helper.HOSTv4
+            )
+            addr = server.sockets[0].getsockname()
+            msg = await client(addr)
+            server.close()
+            await server.wait_closed()
+            self.assertEqual(msg, b"hello world!\n")
 
-        # test the server variant with a callback as client handler
-        server = MyServer(self.loop)
-        addr = server.start_callback()
-        msg = self.loop.run_until_complete(self.loop.create_task(client(addr)))
-        server.stop()
-        self.assertEqual(msg, b"hello world!\n")
+        with self.subTest(msg="callback"):
+            async def handle_client_callback(client_reader, client_writer):
+                asyncio.get_running_loop().create_task(
+                    handle_client(client_reader, client_writer)
+                )
 
-        self.assertEqual(messages, [])
+            server = await asyncio.start_server(
+                handle_client_callback,
+                host=socket_helper.HOSTv4
+            )
+            addr = server.sockets[0].getsockname()
+            reader, writer = await asyncio.open_connection(*addr)
+            msg = await client(addr)
+            server.close()
+            await server.wait_closed()
+            self.assertEqual(msg, b"hello world!\n")
 
     @socket_helper.skip_unless_bind_unix_socket
-    def test_start_unix_server(self):
-
-        class MyServer:
-
-            def __init__(self, loop, path):
-                self.server = None
-                self.loop = loop
-                self.path = path
-
-            async def handle_client(self, client_reader, client_writer):
-                data = await client_reader.readline()
-                client_writer.write(data)
-                await client_writer.drain()
-                client_writer.close()
-                await client_writer.wait_closed()
-
-            def start(self):
-                self.server = self.loop.run_until_complete(
-                    asyncio.start_unix_server(self.handle_client,
-                                              path=self.path))
-
-            def handle_client_callback(self, client_reader, client_writer):
-                self.loop.create_task(self.handle_client(client_reader,
-                                                         client_writer))
-
-            def start_callback(self):
-                start = asyncio.start_unix_server(self.handle_client_callback,
-                                                  path=self.path)
-                self.server = self.loop.run_until_complete(start)
-
-            def stop(self):
-                if self.server is not None:
-                    self.server.close()
-                    self.loop.run_until_complete(self.server.wait_closed())
-                    self.server = None
+    async def test_start_unix_server(self):
 
         async def client(path):
             reader, writer = await asyncio.open_unix_connection(path)
@@ -683,64 +629,42 @@ class StreamTests(test_utils.TestCase):
             await writer.wait_closed()
             return msgback
 
-        messages = []
-        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
-
-        # test the server variant with a coroutine as client handler
-        with test_utils.unix_socket_path() as path:
-            server = MyServer(self.loop, path)
-            server.start()
-            msg = self.loop.run_until_complete(
-                self.loop.create_task(client(path)))
-            server.stop()
-            self.assertEqual(msg, b"hello world!\n")
-
-        # test the server variant with a callback as client handler
-        with test_utils.unix_socket_path() as path:
-            server = MyServer(self.loop, path)
-            server.start_callback()
-            msg = self.loop.run_until_complete(
-                self.loop.create_task(client(path)))
-            server.stop()
-            self.assertEqual(msg, b"hello world!\n")
-
-        self.assertEqual(messages, [])
+        async def handle_client(client_reader, client_writer):
+            data = await client_reader.readline()
+            client_writer.write(data)
+            await client_writer.drain()
+            client_writer.close()
+            await client_writer.wait_closed()
+
+        with self.subTest(msg="coroutine"):
+            with test_utils.unix_socket_path() as path:
+                server = await asyncio.start_unix_server(
+                    handle_client,
+                    path=path
+                )
+                msg = await client(path)
+                server.close()
+                await server.wait_closed()
+                self.assertEqual(msg, b"hello world!\n")
+
+        with self.subTest(msg="callback"):
+            async def handle_client_callback(client_reader, client_writer):
+                asyncio.get_running_loop().create_task(
+                    handle_client(client_reader, client_writer)
+                )
+
+            with test_utils.unix_socket_path() as path:
+                server = await asyncio.start_unix_server(
+                    handle_client_callback,
+                    path=path
+                )
+                msg = await client(path)
+                server.close()
+                await server.wait_closed()
+                self.assertEqual(msg, b"hello world!\n")
 
     @unittest.skipIf(ssl is None, 'No ssl module')
-    def test_start_tls(self):
-
-        class MyServer:
-
-            def __init__(self, loop):
-                self.server = None
-                self.loop = loop
-
-            async def handle_client(self, client_reader, client_writer):
-                data1 = await client_reader.readline()
-                client_writer.write(data1)
-                await client_writer.drain()
-                assert client_writer.get_extra_info('sslcontext') is None
-                await client_writer.start_tls(
-                    test_utils.simple_server_sslcontext())
-                assert client_writer.get_extra_info('sslcontext') is not None
-                data2 = await client_reader.readline()
-                client_writer.write(data2)
-                await client_writer.drain()
-                client_writer.close()
-                await client_writer.wait_closed()
-
-            def start(self):
-                sock = socket.create_server(('127.0.0.1', 0))
-                self.server = self.loop.run_until_complete(
-                    asyncio.start_server(self.handle_client,
-                                         sock=sock))
-                return sock.getsockname()
-
-            def stop(self):
-                if self.server is not None:
-                    self.server.close()
-                    self.loop.run_until_complete(self.server.wait_closed())
-                    self.server = None
+    async def test_start_tls(self):
 
         async def client(addr):
             reader, writer = await asyncio.open_connection(*addr)
@@ -757,18 +681,49 @@ class StreamTests(test_utils.TestCase):
             await writer.wait_closed()
             return msgback1, msgback2
 
-        messages = []
-        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
-
-        server = MyServer(self.loop)
-        addr = server.start()
-        msg1, msg2 = self.loop.run_until_complete(client(addr))
-        server.stop()
-
-        self.assertEqual(messages, [])
+        async def handle_client(client_reader, client_writer):
+            data1 = await client_reader.readline()
+            client_writer.write(data1)
+            await client_writer.drain()
+            assert client_writer.get_extra_info('sslcontext') is None
+            await client_writer.start_tls(
+                test_utils.simple_server_sslcontext())
+            assert client_writer.get_extra_info('sslcontext') is not None
+
+            data2 = await client_reader.readline()
+            client_writer.write(data2)
+            await client_writer.drain()
+            client_writer.close()
+            await client_writer.wait_closed()
+
+        server = await asyncio.start_server(
+            handle_client,
+            host=socket_helper.HOSTv4
+        )
+        addr = server.sockets[0].getsockname()
+
+        msg1, msg2 = await client(addr)
+        server.close()
+        await server.wait_closed()
         self.assertEqual(msg1, b"hello world 1!\n")
         self.assertEqual(msg2, b"hello world 2!\n")
 
+
+class StreamTests2(test_utils.TestCase):
+
+    def setUp(self):
+        super().setUp()
+        self.loop = asyncio.new_event_loop()
+        self.set_event_loop(self.loop)
+
+    def tearDown(self):
+        # just in case if we have transport close callbacks
+        test_utils.run_briefly(self.loop)
+
+        self.loop.close()
+        gc.collect()
+        super().tearDown()
+
     @unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
     def test_read_all_from_pipe_reader(self):
         # See asyncio issue 168.  This test is derived from the example
@@ -986,22 +941,20 @@ os.close(fd)
                 self.assertEqual(str(e), str(e2))
                 self.assertEqual(e.consumed, e2.consumed)
 
-    def test_wait_closed_on_close(self):
-        with test_utils.run_test_server() as httpd:
+    async def test_wait_closed_on_close(self):
+        async with test_utils.run_test_server() as httpd:
             rd, wr = self.loop.run_until_complete(
                 asyncio.open_connection(*httpd.address))
 
             wr.write(b'GET / HTTP/1.0\r\n\r\n')
-            f = rd.readline()
-            data = self.loop.run_until_complete(f)
+            data = await rd.readline()
             self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
-            f = rd.read()
-            data = self.loop.run_until_complete(f)
+            await rd.read()
             self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
             self.assertFalse(wr.is_closing())
             wr.close()
             self.assertTrue(wr.is_closing())
-            self.loop.run_until_complete(wr.wait_closed())
+            await wr.wait_closed()
 
     def test_wait_closed_on_close_with_unread_data(self):
         with test_utils.run_test_server() as httpd:
@@ -1057,15 +1010,10 @@ os.close(fd)
 
         self.assertEqual(messages, [])
 
-    def test_eof_feed_when_closing_writer(self):
+    async def test_eof_feed_when_closing_writer(self):
         # See http://bugs.python.org/issue35065
-        messages = []
-        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
-
-        with test_utils.run_test_server() as httpd:
-            rd, wr = self.loop.run_until_complete(
-                    asyncio.open_connection(*httpd.address))
-
+        async with test_utils.run_test_server() as httpd:
+            rd, wr = await asyncio.open_connection(*httpd.address)
             wr.close()
             f = wr.wait_closed()
             self.loop.run_until_complete(f)
@@ -1074,8 +1022,6 @@ os.close(fd)
             data = self.loop.run_until_complete(f)
             self.assertEqual(data, b'')
 
-        self.assertEqual(messages, [])
-
 
 if __name__ == '__main__':
     unittest.main()