]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-112278: Improve error handling in wmi module and tests (GH-117818)
authorSteve Dower <steve.dower@python.org>
Mon, 15 Apr 2024 15:43:28 +0000 (16:43 +0100)
committerGitHub <noreply@github.com>
Mon, 15 Apr 2024 15:43:28 +0000 (16:43 +0100)
Lib/test/test_wmi.py
PC/_wmimodule.cpp

index 3445702846d8a044cac8a54e96def03f75ee31cc..f667926d1f8ddfac4788626149d1d4e5d8a569ef 100644 (file)
@@ -1,17 +1,31 @@
 # Test the internal _wmi module on Windows
 # This is used by the platform module, and potentially others
 
+import time
 import unittest
-from test.support import import_helper, requires_resource
+from test.support import import_helper, requires_resource, LOOPBACK_TIMEOUT
 
 
 # Do this first so test will be skipped if module doesn't exist
 _wmi = import_helper.import_module('_wmi', required_on=['win'])
 
 
+def wmi_exec_query(query):
+    # gh-112278: WMI maybe slow response when first call.
+    try:
+        return _wmi.exec_query(query)
+    except BrokenPipeError:
+        pass
+    except WindowsError as e:
+        if e.winerror != 258:
+            raise
+    time.sleep(LOOPBACK_TIMEOUT)
+    return _wmi.exec_query(query)
+
+
 class WmiTests(unittest.TestCase):
     def test_wmi_query_os_version(self):
-        r = _wmi.exec_query("SELECT Version FROM Win32_OperatingSystem").split("\0")
+        r = wmi_exec_query("SELECT Version FROM Win32_OperatingSystem").split("\0")
         self.assertEqual(1, len(r))
         k, eq, v = r[0].partition("=")
         self.assertEqual("=", eq, r[0])
@@ -28,7 +42,7 @@ class WmiTests(unittest.TestCase):
     def test_wmi_query_error(self):
         # Invalid queries fail with OSError
         try:
-            _wmi.exec_query("SELECT InvalidColumnName FROM InvalidTableName")
+            wmi_exec_query("SELECT InvalidColumnName FROM InvalidTableName")
         except OSError as ex:
             if ex.winerror & 0xFFFFFFFF == 0x80041010:
                 # This is the expected error code. All others should fail the test
@@ -42,7 +56,7 @@ class WmiTests(unittest.TestCase):
     def test_wmi_query_not_select(self):
         # Queries other than SELECT are blocked to avoid potential exploits
         with self.assertRaises(ValueError):
-            _wmi.exec_query("not select, just in case someone tries something")
+            wmi_exec_query("not select, just in case someone tries something")
 
     @requires_resource('cpu')
     def test_wmi_query_overflow(self):
@@ -50,11 +64,11 @@ class WmiTests(unittest.TestCase):
         # Test multiple times to ensure consistency
         for _ in range(2):
             with self.assertRaises(OSError):
-                _wmi.exec_query("SELECT * FROM CIM_DataFile")
+                wmi_exec_query("SELECT * FROM CIM_DataFile")
 
     def test_wmi_query_multiple_rows(self):
         # Multiple instances should have an extra null separator
-        r = _wmi.exec_query("SELECT ProcessId FROM Win32_Process WHERE ProcessId < 1000")
+        r = wmi_exec_query("SELECT ProcessId FROM Win32_Process WHERE ProcessId < 1000")
         self.assertFalse(r.startswith("\0"), r)
         self.assertFalse(r.endswith("\0"), r)
         it = iter(r.split("\0"))
@@ -69,6 +83,6 @@ class WmiTests(unittest.TestCase):
         from concurrent.futures import ThreadPoolExecutor
         query = "SELECT ProcessId FROM Win32_Process WHERE ProcessId < 1000"
         with ThreadPoolExecutor(4) as pool:
-            task = [pool.submit(_wmi.exec_query, query) for _ in range(32)]
+            task = [pool.submit(wmi_exec_query, query) for _ in range(32)]
             for t in task:
                 self.assertRegex(t.result(), "ProcessId=")
index 310aa86d94d9b6b07afa75b82efec5b4c155a870..22ed05276e6f0716972068397c5bb81e23755743 100644 (file)
@@ -8,6 +8,11 @@
 // Version history
 //  2022-08: Initial contribution (Steve Dower)
 
+// clinic/_wmimodule.cpp.h uses internal pycore_modsupport.h API
+#ifndef Py_BUILD_CORE_BUILTIN
+#  define Py_BUILD_CORE_MODULE 1
+#endif
+
 #define _WIN32_DCOM
 #include <Windows.h>
 #include <comdef.h>
@@ -39,6 +44,8 @@ struct _query_data {
     LPCWSTR query;
     HANDLE writePipe;
     HANDLE readPipe;
+    HANDLE initEvent;
+    HANDLE connectEvent;
 };
 
 
@@ -75,12 +82,18 @@ _query_thread(LPVOID param)
             IID_IWbemLocator, (LPVOID *)&locator
         );
     }
+    if (SUCCEEDED(hr) && !SetEvent(data->initEvent)) {
+        hr = HRESULT_FROM_WIN32(GetLastError());
+    }
     if (SUCCEEDED(hr)) {
         hr = locator->ConnectServer(
             bstr_t(L"ROOT\\CIMV2"),
             NULL, NULL, 0, NULL, 0, 0, &services
         );
     }
+    if (SUCCEEDED(hr) && !SetEvent(data->connectEvent)) {
+        hr = HRESULT_FROM_WIN32(GetLastError());
+    }
     if (SUCCEEDED(hr)) {
         hr = CoSetProxyBlanket(
             services, RPC_C_AUTHN_WINNT, RPC_C_AUTHZ_NONE, NULL,
@@ -184,6 +197,24 @@ _query_thread(LPVOID param)
 }
 
 
+static DWORD
+wait_event(HANDLE event, DWORD timeout)
+{
+    DWORD err = 0;
+    switch (WaitForSingleObject(event, timeout)) {
+    case WAIT_OBJECT_0:
+        break;
+    case WAIT_TIMEOUT:
+        err = WAIT_TIMEOUT;
+        break;
+    default:
+        err = GetLastError();
+        break;
+    }
+    return err;
+}
+
+
 /*[clinic input]
 _wmi.exec_query
 
@@ -226,7 +257,11 @@ _wmi_exec_query_impl(PyObject *module, PyObject *query)
 
     Py_BEGIN_ALLOW_THREADS
 
-    if (!CreatePipe(&data.readPipe, &data.writePipe, NULL, 0)) {
+    data.initEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
+    data.connectEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
+    if (!data.initEvent || !data.connectEvent ||
+        !CreatePipe(&data.readPipe, &data.writePipe, NULL, 0))
+    {
         err = GetLastError();
     } else {
         hThread = CreateThread(NULL, 0, _query_thread, (LPVOID*)&data, 0, NULL);
@@ -238,6 +273,19 @@ _wmi_exec_query_impl(PyObject *module, PyObject *query)
         }
     }
 
+    // gh-112278: If current user doesn't have permission to query the WMI, the
+    // function IWbemLocator::ConnectServer will hang for 5 seconds, and there
+    // is no way to specify the timeout. So we use an Event object to simulate
+    // a timeout.  The initEvent will be set after COM initialization, it will
+    // take a longer time when first initialized.  The connectEvent will be set
+    // after connected to WMI.
+    if (!err) {
+        err = wait_event(data.initEvent, 1000);
+        if (!err) {
+            err = wait_event(data.connectEvent, 100);
+        }
+    }
+
     while (!err) {
         if (ReadFile(
             data.readPipe,
@@ -259,28 +307,35 @@ _wmi_exec_query_impl(PyObject *module, PyObject *query)
         CloseHandle(data.readPipe);
     }
 
-    // Allow the thread some time to clean up
-    switch (WaitForSingleObject(hThread, 1000)) {
-    case WAIT_OBJECT_0:
-        // Thread ended cleanly
-        if (!GetExitCodeThread(hThread, (LPDWORD)&err)) {
-            err = GetLastError();
-        }
-        break;
-    case WAIT_TIMEOUT:
-        // Probably stuck - there's not much we can do, unfortunately
-        if (err == 0 || err == ERROR_BROKEN_PIPE) {
-            err = WAIT_TIMEOUT;
+    if (hThread) {
+        // Allow the thread some time to clean up
+        int thread_err;
+        switch (WaitForSingleObject(hThread, 100)) {
+        case WAIT_OBJECT_0:
+            // Thread ended cleanly
+            if (!GetExitCodeThread(hThread, (LPDWORD)&thread_err)) {
+                thread_err = GetLastError();
+            }
+            break;
+        case WAIT_TIMEOUT:
+            // Probably stuck - there's not much we can do, unfortunately
+            thread_err = WAIT_TIMEOUT;
+            break;
+        default:
+            thread_err = GetLastError();
+            break;
         }
-        break;
-    default:
+        // An error on our side is more likely to be relevant than one from
+        // the thread, but if we don't have one on our side we'll take theirs.
         if (err == 0 || err == ERROR_BROKEN_PIPE) {
-            err = GetLastError();
+            err = thread_err;
         }
-        break;
+
+        CloseHandle(hThread);
     }
 
-    CloseHandle(hThread);
+    CloseHandle(data.initEvent);
+    CloseHandle(data.connectEvent);
     hThread = NULL;
 
     Py_END_ALLOW_THREADS