From: Daan De Meyer Date: Sat, 27 May 2023 08:54:20 +0000 (+0200) Subject: Make sure we handle any exceptions thrown in MkosiAsyncioThread() X-Git-Tag: v15~141^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5682da07c6bb91dc0db6c8f200ea23c89590d2fe;p=thirdparty%2Fmkosi.git Make sure we handle any exceptions thrown in MkosiAsyncioThread() Any unhandled exceptions from run() are by default just printed to stderr. Let's make sure that these exceptions cause mkosi itself to fail by catching them and re-raising them when the thread is joined. --- diff --git a/mkosi/qemu.py b/mkosi/qemu.py index b38b1d6ff..175bea43a 100644 --- a/mkosi/qemu.py +++ b/mkosi/qemu.py @@ -173,19 +173,15 @@ def vsock_notify_handler() -> Iterator[tuple[str, dict[str, str]]]: async def notify() -> None: loop = asyncio.get_running_loop() - try: - while True: - s, _ = await loop.sock_accept(vsock) + while True: + s, _ = await loop.sock_accept(vsock) - for msg in (await loop.sock_recv(s, 4096)).decode().split("\n"): - if not msg: - continue + for msg in (await loop.sock_recv(s, 4096)).decode().split("\n"): + if not msg: + continue - k, _, v = msg.partition("=") - messages[k] = v - - except asyncio.CancelledError: - pass + k, _, v = msg.partition("=") + messages[k] = v with MkosiAsyncioThread(notify()): yield f"vsock:{socket.VMADDR_CID_HOST}:{vsock.getsockname()[1]}", messages diff --git a/mkosi/run.py b/mkosi/run.py index e3f46b464..c60dfb531 100644 --- a/mkosi/run.py +++ b/mkosi/run.py @@ -420,13 +420,14 @@ class MkosiAsyncioThread(threading.Thread): The default threading.Thread() is not interruptable, so we make our own version by using the concurrency feature in python that is interruptable, namely asyncio. - Additionally, we store the result of the coroutine in the result variable so it can be accessed easily - after the thread finishes. + Additionally, we store any exception that the coroutine raises and re-raise it in join() if no other + exception was raised before. """ def __init__(self, target: Awaitable[Any], *args: Any, **kwargs: Any) -> None: self.target = target self.loop: queue.SimpleQueue[asyncio.AbstractEventLoop] = queue.SimpleQueue() + self.exc: queue.SimpleQueue[BaseException] = queue.SimpleQueue() super().__init__(*args, **kwargs) def run(self) -> None: @@ -434,7 +435,12 @@ class MkosiAsyncioThread(threading.Thread): self.loop.put(asyncio.get_running_loop()) await self.target - asyncio.run(wrapper()) + try: + asyncio.run(wrapper()) + except asyncio.CancelledError: + pass + except BaseException as e: + self.exc.put(e) def cancel(self) -> None: loop = self.loop.get() @@ -454,3 +460,9 @@ class MkosiAsyncioThread(threading.Thread): ) -> None: self.cancel() self.join() + + if type is None: + try: + raise self.exc.get_nowait() + except queue.Empty: + pass