]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-127183: Add `_ctypes.CopyComPointer` tests (GH-127184)
authorJun Komoda <45822440+junkmd@users.noreply.github.com>
Mon, 25 Nov 2024 13:55:07 +0000 (22:55 +0900)
committerGitHub <noreply@github.com>
Mon, 25 Nov 2024 13:55:07 +0000 (14:55 +0100)
* Make `create_shelllink_persist` top level function.

* Add `CopyComPointerTests`.

* Add more tests.

* Update tests.

* Add assertions for `Release`'s return value.

Lib/test/test_ctypes/test_win32_com_foreign_func.py

index 651c9277d59af93629e4dadca165cfcdc6c42b38..8d217fc17efa02c50b59e1fad25bb8ebfce30171 100644 (file)
@@ -9,7 +9,7 @@ if sys.platform != "win32":
     raise unittest.SkipTest("Windows-specific test")
 
 
-from _ctypes import COMError
+from _ctypes import COMError, CopyComPointer
 from ctypes import HRESULT
 
 
@@ -78,6 +78,19 @@ proto_get_class_id = create_proto_com_method(
 )
 
 
+def create_shelllink_persist(typ):
+    ppst = typ()
+    # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance
+    ole32.CoCreateInstance(
+        byref(CLSID_ShellLink),
+        None,
+        CLSCTX_SERVER,
+        byref(IID_IPersist),
+        byref(ppst),
+    )
+    return ppst
+
+
 class ForeignFunctionsThatWillCallComMethodsTests(unittest.TestCase):
     def setUp(self):
         # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-coinitializeex
@@ -88,19 +101,6 @@ class ForeignFunctionsThatWillCallComMethodsTests(unittest.TestCase):
         ole32.CoUninitialize()
         gc.collect()
 
-    @staticmethod
-    def create_shelllink_persist(typ):
-        ppst = typ()
-        # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance
-        ole32.CoCreateInstance(
-            byref(CLSID_ShellLink),
-            None,
-            CLSCTX_SERVER,
-            byref(IID_IPersist),
-            byref(ppst),
-        )
-        return ppst
-
     def test_without_paramflags_and_iid(self):
         class IUnknown(c_void_p):
             QueryInterface = proto_query_interface()
@@ -110,7 +110,7 @@ class ForeignFunctionsThatWillCallComMethodsTests(unittest.TestCase):
         class IPersist(IUnknown):
             GetClassID = proto_get_class_id()
 
-        ppst = self.create_shelllink_persist(IPersist)
+        ppst = create_shelllink_persist(IPersist)
 
         clsid = GUID()
         hr_getclsid = ppst.GetClassID(byref(clsid))
@@ -142,7 +142,7 @@ class ForeignFunctionsThatWillCallComMethodsTests(unittest.TestCase):
         class IPersist(IUnknown):
             GetClassID = proto_get_class_id(((OUT, "pClassID"),))
 
-        ppst = self.create_shelllink_persist(IPersist)
+        ppst = create_shelllink_persist(IPersist)
 
         clsid = ppst.GetClassID()
         self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))
@@ -167,7 +167,7 @@ class ForeignFunctionsThatWillCallComMethodsTests(unittest.TestCase):
         class IPersist(IUnknown):
             GetClassID = proto_get_class_id(((OUT, "pClassID"),), IID_IPersist)
 
-        ppst = self.create_shelllink_persist(IPersist)
+        ppst = create_shelllink_persist(IPersist)
 
         clsid = ppst.GetClassID()
         self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))
@@ -184,5 +184,103 @@ class ForeignFunctionsThatWillCallComMethodsTests(unittest.TestCase):
         self.assertEqual(0, ppst.Release())
 
 
+class CopyComPointerTests(unittest.TestCase):
+    def setUp(self):
+        ole32.CoInitializeEx(None, COINIT_APARTMENTTHREADED)
+
+        class IUnknown(c_void_p):
+            QueryInterface = proto_query_interface(None, IID_IUnknown)
+            AddRef = proto_add_ref()
+            Release = proto_release()
+
+        class IPersist(IUnknown):
+            GetClassID = proto_get_class_id(((OUT, "pClassID"),), IID_IPersist)
+
+        self.IUnknown = IUnknown
+        self.IPersist = IPersist
+
+    def tearDown(self):
+        ole32.CoUninitialize()
+        gc.collect()
+
+    def test_both_are_null(self):
+        src = self.IPersist()
+        dst = self.IPersist()
+
+        hr = CopyComPointer(src, byref(dst))
+
+        self.assertEqual(S_OK, hr)
+
+        self.assertIsNone(src.value)
+        self.assertIsNone(dst.value)
+
+    def test_src_is_nonnull_and_dest_is_null(self):
+        # The reference count of the COM pointer created by `CoCreateInstance`
+        # is initially 1.
+        src = create_shelllink_persist(self.IPersist)
+        dst = self.IPersist()
+
+        # `CopyComPointer` calls `AddRef` explicitly in the C implementation.
+        # The refcount of `src` is incremented from 1 to 2 here.
+        hr = CopyComPointer(src, byref(dst))
+
+        self.assertEqual(S_OK, hr)
+        self.assertEqual(src.value, dst.value)
+
+        # This indicates that the refcount was 2 before the `Release` call.
+        self.assertEqual(1, src.Release())
+
+        clsid = dst.GetClassID()
+        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))
+
+        self.assertEqual(0, dst.Release())
+
+    def test_src_is_null_and_dest_is_nonnull(self):
+        src = self.IPersist()
+        dst_orig = create_shelllink_persist(self.IPersist)
+        dst = self.IPersist()
+        CopyComPointer(dst_orig, byref(dst))
+        self.assertEqual(1, dst_orig.Release())
+
+        clsid = dst.GetClassID()
+        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))
+
+        # This does NOT affects the refcount of `dst_orig`.
+        hr = CopyComPointer(src, byref(dst))
+
+        self.assertEqual(S_OK, hr)
+        self.assertIsNone(dst.value)
+
+        with self.assertRaises(ValueError):
+            dst.GetClassID()  # NULL COM pointer access
+
+        # This indicates that the refcount was 1 before the `Release` call.
+        self.assertEqual(0, dst_orig.Release())
+
+    def test_both_are_nonnull(self):
+        src = create_shelllink_persist(self.IPersist)
+        dst_orig = create_shelllink_persist(self.IPersist)
+        dst = self.IPersist()
+        CopyComPointer(dst_orig, byref(dst))
+        self.assertEqual(1, dst_orig.Release())
+
+        self.assertEqual(dst.value, dst_orig.value)
+        self.assertNotEqual(src.value, dst.value)
+
+        hr = CopyComPointer(src, byref(dst))
+
+        self.assertEqual(S_OK, hr)
+        self.assertEqual(src.value, dst.value)
+        self.assertNotEqual(dst.value, dst_orig.value)
+
+        self.assertEqual(1, src.Release())
+
+        clsid = dst.GetClassID()
+        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))
+
+        self.assertEqual(0, dst.Release())
+        self.assertEqual(0, dst_orig.Release())
+
+
 if __name__ == '__main__':
     unittest.main()