ssl_handshake_timeout, ssl_shutdown_timeout=None):
self._loop = loop
self._sockets = sockets
- # Weak references so we don't break Transport's ability to
- # detect abandoned transports
- self._clients = weakref.WeakSet()
+ self._active_count = 0
self._waiters = []
self._protocol_factory = protocol_factory
self._backlog = backlog
def __repr__(self):
return f'<{self.__class__.__name__} sockets={self.sockets!r}>'
- def _attach(self, transport):
+ def _attach(self):
assert self._sockets is not None
- self._clients.add(transport)
+ self._active_count += 1
- def _detach(self, transport):
- self._clients.discard(transport)
- if len(self._clients) == 0 and self._sockets is None:
+ def _detach(self):
+ assert self._active_count > 0
+ self._active_count -= 1
+ if self._active_count == 0 and self._sockets is None:
self._wakeup()
def _wakeup(self):
self._serving_forever_fut.cancel()
self._serving_forever_fut = None
- if len(self._clients) == 0:
+ if self._active_count == 0:
self._wakeup()
- def close_clients(self):
- for transport in self._clients.copy():
- transport.close()
-
- def abort_clients(self):
- for transport in self._clients.copy():
- transport.abort()
-
async def start_serving(self):
self._start_serving()
# Skip one loop iteration so that all 'loop.add_reader'
class TestServer2(unittest.IsolatedAsyncioTestCase):
async def test_wait_closed_basic(self):
- async def serve(rd, wr):
- try:
- await rd.read()
- finally:
- wr.close()
- await wr.wait_closed()
+ async def serve(*args):
+ pass
srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
self.addCleanup(srv.close)
self.assertFalse(task1.done())
# active count != 0, not closed: should block
- addr = srv.sockets[0].getsockname()
- (rd, wr) = await asyncio.open_connection(addr[0], addr[1])
+ srv._attach()
task2 = asyncio.create_task(srv.wait_closed())
await asyncio.sleep(0)
self.assertFalse(task1.done())
self.assertFalse(task2.done())
self.assertFalse(task3.done())
- wr.close()
- await wr.wait_closed()
+ srv._detach()
# active count == 0, closed: should unblock
await task1
await task2
async def test_wait_closed_race(self):
# Test a regression in 3.12.0, should be fixed in 3.12.1
- async def serve(rd, wr):
- try:
- await rd.read()
- finally:
- wr.close()
- await wr.wait_closed()
+ async def serve(*args):
+ pass
srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
self.addCleanup(srv.close)
task = asyncio.create_task(srv.wait_closed())
await asyncio.sleep(0)
self.assertFalse(task.done())
- addr = srv.sockets[0].getsockname()
- (rd, wr) = await asyncio.open_connection(addr[0], addr[1])
+ srv._attach()
loop = asyncio.get_running_loop()
loop.call_soon(srv.close)
- loop.call_soon(wr.close)
+ loop.call_soon(srv._detach)
await srv.wait_closed()
- async def test_close_clients(self):
- async def serve(rd, wr):
- try:
- await rd.read()
- finally:
- wr.close()
- await wr.wait_closed()
-
- srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
- self.addCleanup(srv.close)
-
- addr = srv.sockets[0].getsockname()
- (rd, wr) = await asyncio.open_connection(addr[0], addr[1])
- self.addCleanup(wr.close)
-
- task = asyncio.create_task(srv.wait_closed())
- await asyncio.sleep(0)
- self.assertFalse(task.done())
-
- srv.close()
- srv.close_clients()
- await asyncio.sleep(0)
- await asyncio.sleep(0)
- self.assertTrue(task.done())
-
- async def test_abort_clients(self):
- async def serve(rd, wr):
- nonlocal s_rd, s_wr
- s_rd = rd
- s_wr = wr
- await wr.wait_closed()
-
- s_rd = s_wr = None
- srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
- self.addCleanup(srv.close)
-
- addr = srv.sockets[0].getsockname()
- (c_rd, c_wr) = await asyncio.open_connection(addr[0], addr[1], limit=4096)
- self.addCleanup(c_wr.close)
-
- # Limit the socket buffers so we can reliably overfill them
- s_sock = s_wr.get_extra_info('socket')
- s_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 65536)
- c_sock = c_wr.get_extra_info('socket')
- c_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 65536)
-
- # Get the reader in to a paused state by sending more than twice
- # the configured limit
- s_wr.write(b'a' * 4096)
- s_wr.write(b'a' * 4096)
- s_wr.write(b'a' * 4096)
- while c_wr.transport.is_reading():
- await asyncio.sleep(0)
-
- # Get the writer in a waiting state by sending data until the
- # socket buffers are full on both server and client sockets and
- # the kernel stops accepting more data
- s_wr.write(b'a' * c_sock.getsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF))
- s_wr.write(b'a' * s_sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF))
- self.assertNotEqual(s_wr.transport.get_write_buffer_size(), 0)
-
- task = asyncio.create_task(srv.wait_closed())
- await asyncio.sleep(0)
- self.assertFalse(task.done())
- srv.close()
- srv.abort_clients()
- await asyncio.sleep(0)
- await asyncio.sleep(0)
- self.assertTrue(task.done())
# Test the various corner cases of Unix server socket removal