]> git.ipfire.org Git - thirdparty/bind9.git/commitdiff
Ensure uncaught exceptions kill custom servers
authorMichał Kępień <michal@isc.org>
Fri, 11 Apr 2025 14:14:57 +0000 (09:14 -0500)
committerMichał Kępień <michal@isc.org>
Fri, 11 Apr 2025 14:14:57 +0000 (09:14 -0500)
Uncaught exceptions raised by tasks running on event loops are not
handled by Python's default exception handler, so they do not cause
scripts to die immediately with a non-zero exit code.  Set up an
exception handler for AsyncServer code that makes any uncaught exception
the result of the Future that the top-level coroutine awaits.  This
ensures that any uncaught exceptions cause scripts based on AsyncServer
to immediately exit with an error, enabling the system test framework to
fail tests in which custom servers encounter unforeseen problems.

bin/tests/system/isctest/asyncserver.py

index dd2cf0c04cb025364dc1014e0cff7949f5ef8914..9d1a1d3cacef86c472d196f92db70e9458a2f699 100644 (file)
@@ -160,6 +160,7 @@ class AsyncServer:
             loop.run_until_complete(coroutine())
 
     async def _run(self) -> None:
+        self._setup_exception_handler()
         self._setup_signals()
         assert self._work_done
         await self._listen_udp()
@@ -177,9 +178,20 @@ class AsyncServer:
             loop = asyncio.get_event_loop()
         return loop
 
-    def _setup_signals(self) -> None:
+    def _setup_exception_handler(self) -> None:
         loop = self._get_asyncio_loop()
         self._work_done = loop.create_future()
+        loop.set_exception_handler(self._handle_exception)
+
+    def _handle_exception(
+        self, _: asyncio.AbstractEventLoop, context: Dict[str, Any]
+    ) -> None:
+        assert self._work_done
+        exception = context.get("exception", RuntimeError(context["message"]))
+        self._work_done.set_exception(exception)
+
+    def _setup_signals(self) -> None:
+        loop = self._get_asyncio_loop()
         loop.add_signal_handler(signal.SIGINT, functools.partial(self._signal_done))
         loop.add_signal_handler(signal.SIGTERM, functools.partial(self._signal_done))