if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
fut = self.create_future()
- self._sock_accept(fut, False, sock)
+ self._sock_accept(fut, sock)
return await fut
- def _sock_accept(self, fut, registered, sock):
+ def _sock_accept(self, fut, sock):
fd = sock.fileno()
- if registered:
- self.remove_reader(fd)
- if fut.done():
- return
try:
conn, address = sock.accept()
conn.setblocking(False)
except (BlockingIOError, InterruptedError):
- self.add_reader(fd, self._sock_accept, fut, True, sock)
+ self._ensure_fd_no_transport(fd)
+ handle = self._add_reader(fd, self._sock_accept, fut, sock)
+ fut.add_done_callback(
+ functools.partial(self._sock_read_done, fd, handle=handle))
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
conn.close()
listener.close()
+ def test_cancel_sock_accept(self):
+ listener = socket.socket()
+ listener.setblocking(False)
+ listener.bind(('127.0.0.1', 0))
+ listener.listen(1)
+ sockaddr = listener.getsockname()
+ f = asyncio.wait_for(self.loop.sock_accept(listener), 0.1)
+ with self.assertRaises(asyncio.TimeoutError):
+ self.loop.run_until_complete(f)
+
+ listener.close()
+ client = socket.socket()
+ client.setblocking(False)
+ f = self.loop.sock_connect(client, sockaddr)
+ with self.assertRaises(ConnectionRefusedError):
+ self.loop.run_until_complete(f)
+
+ client.close()
+
def test_create_connection_sock(self):
with test_utils.run_test_server() as httpd:
sock = None