]> git.ipfire.org Git - thirdparty/mkosi.git/commitdiff
Make sure we handle any exceptions thrown in MkosiAsyncioThread()
authorDaan De Meyer <daan.j.demeyer@gmail.com>
Sat, 27 May 2023 08:54:20 +0000 (10:54 +0200)
committerDaan De Meyer <daan.j.demeyer@gmail.com>
Sat, 27 May 2023 08:56:21 +0000 (10:56 +0200)
Any unhandled exceptions from run() are by default just printed to
stderr. Let's make sure that these exceptions cause mkosi itself to
fail by catching them and re-raising them when the thread is joined.

mkosi/qemu.py
mkosi/run.py

index b38b1d6ff6fa24c7bff60c31733d04c9ea153762..175bea43af5a0f6a18bdd8b3040bd4449de09c90 100644 (file)
@@ -173,19 +173,15 @@ def vsock_notify_handler() -> Iterator[tuple[str, dict[str, str]]]:
         async def notify() -> None:
             loop = asyncio.get_running_loop()
 
-            try:
-                while True:
-                    s, _ = await loop.sock_accept(vsock)
+            while True:
+                s, _ = await loop.sock_accept(vsock)
 
-                    for msg in (await loop.sock_recv(s, 4096)).decode().split("\n"):
-                        if not msg:
-                            continue
+                for msg in (await loop.sock_recv(s, 4096)).decode().split("\n"):
+                    if not msg:
+                        continue
 
-                        k, _, v = msg.partition("=")
-                        messages[k] = v
-
-            except asyncio.CancelledError:
-                pass
+                    k, _, v = msg.partition("=")
+                    messages[k] = v
 
         with MkosiAsyncioThread(notify()):
             yield f"vsock:{socket.VMADDR_CID_HOST}:{vsock.getsockname()[1]}", messages
index e3f46b46498a8f0eee4442ffeb8d9beb5e4e65c2..c60dfb53148b86f8e55d72eef26296e099684331 100644 (file)
@@ -420,13 +420,14 @@ class MkosiAsyncioThread(threading.Thread):
     The default threading.Thread() is not interruptable, so we make our own version by using the concurrency
     feature in python that is interruptable, namely asyncio.
 
-    Additionally, we store the result of the coroutine in the result variable so it can be accessed easily
-    after the thread finishes.
+    Additionally, we store any exception that the coroutine raises and re-raise it in join() if no other
+    exception was raised before.
     """
 
     def __init__(self, target: Awaitable[Any], *args: Any, **kwargs: Any) -> None:
         self.target = target
         self.loop: queue.SimpleQueue[asyncio.AbstractEventLoop] = queue.SimpleQueue()
+        self.exc: queue.SimpleQueue[BaseException] = queue.SimpleQueue()
         super().__init__(*args, **kwargs)
 
     def run(self) -> None:
@@ -434,7 +435,12 @@ class MkosiAsyncioThread(threading.Thread):
             self.loop.put(asyncio.get_running_loop())
             await self.target
 
-        asyncio.run(wrapper())
+        try:
+            asyncio.run(wrapper())
+        except asyncio.CancelledError:
+            pass
+        except BaseException as e:
+            self.exc.put(e)
 
     def cancel(self) -> None:
         loop = self.loop.get()
@@ -454,3 +460,9 @@ class MkosiAsyncioThread(threading.Thread):
     ) -> None:
         self.cancel()
         self.join()
+
+        if type is None:
+            try:
+                raise self.exc.get_nowait()
+            except queue.Empty:
+                pass