raise unittest.SkipTest("Windows-specific test")
-from _ctypes import COMError
+from _ctypes import COMError, CopyComPointer
from ctypes import HRESULT
)
+def create_shelllink_persist(typ):
+ ppst = typ()
+ # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance
+ ole32.CoCreateInstance(
+ byref(CLSID_ShellLink),
+ None,
+ CLSCTX_SERVER,
+ byref(IID_IPersist),
+ byref(ppst),
+ )
+ return ppst
+
+
class ForeignFunctionsThatWillCallComMethodsTests(unittest.TestCase):
def setUp(self):
# https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-coinitializeex
ole32.CoUninitialize()
gc.collect()
- @staticmethod
- def create_shelllink_persist(typ):
- ppst = typ()
- # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance
- ole32.CoCreateInstance(
- byref(CLSID_ShellLink),
- None,
- CLSCTX_SERVER,
- byref(IID_IPersist),
- byref(ppst),
- )
- return ppst
-
def test_without_paramflags_and_iid(self):
class IUnknown(c_void_p):
QueryInterface = proto_query_interface()
class IPersist(IUnknown):
GetClassID = proto_get_class_id()
- ppst = self.create_shelllink_persist(IPersist)
+ ppst = create_shelllink_persist(IPersist)
clsid = GUID()
hr_getclsid = ppst.GetClassID(byref(clsid))
class IPersist(IUnknown):
GetClassID = proto_get_class_id(((OUT, "pClassID"),))
- ppst = self.create_shelllink_persist(IPersist)
+ ppst = create_shelllink_persist(IPersist)
clsid = ppst.GetClassID()
self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))
class IPersist(IUnknown):
GetClassID = proto_get_class_id(((OUT, "pClassID"),), IID_IPersist)
- ppst = self.create_shelllink_persist(IPersist)
+ ppst = create_shelllink_persist(IPersist)
clsid = ppst.GetClassID()
self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))
self.assertEqual(0, ppst.Release())
+class CopyComPointerTests(unittest.TestCase):
+ def setUp(self):
+ ole32.CoInitializeEx(None, COINIT_APARTMENTTHREADED)
+
+ class IUnknown(c_void_p):
+ QueryInterface = proto_query_interface(None, IID_IUnknown)
+ AddRef = proto_add_ref()
+ Release = proto_release()
+
+ class IPersist(IUnknown):
+ GetClassID = proto_get_class_id(((OUT, "pClassID"),), IID_IPersist)
+
+ self.IUnknown = IUnknown
+ self.IPersist = IPersist
+
+ def tearDown(self):
+ ole32.CoUninitialize()
+ gc.collect()
+
+ def test_both_are_null(self):
+ src = self.IPersist()
+ dst = self.IPersist()
+
+ hr = CopyComPointer(src, byref(dst))
+
+ self.assertEqual(S_OK, hr)
+
+ self.assertIsNone(src.value)
+ self.assertIsNone(dst.value)
+
+ def test_src_is_nonnull_and_dest_is_null(self):
+ # The reference count of the COM pointer created by `CoCreateInstance`
+ # is initially 1.
+ src = create_shelllink_persist(self.IPersist)
+ dst = self.IPersist()
+
+ # `CopyComPointer` calls `AddRef` explicitly in the C implementation.
+ # The refcount of `src` is incremented from 1 to 2 here.
+ hr = CopyComPointer(src, byref(dst))
+
+ self.assertEqual(S_OK, hr)
+ self.assertEqual(src.value, dst.value)
+
+ # This indicates that the refcount was 2 before the `Release` call.
+ self.assertEqual(1, src.Release())
+
+ clsid = dst.GetClassID()
+ self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))
+
+ self.assertEqual(0, dst.Release())
+
+ def test_src_is_null_and_dest_is_nonnull(self):
+ src = self.IPersist()
+ dst_orig = create_shelllink_persist(self.IPersist)
+ dst = self.IPersist()
+ CopyComPointer(dst_orig, byref(dst))
+ self.assertEqual(1, dst_orig.Release())
+
+ clsid = dst.GetClassID()
+ self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))
+
+ # This does NOT affects the refcount of `dst_orig`.
+ hr = CopyComPointer(src, byref(dst))
+
+ self.assertEqual(S_OK, hr)
+ self.assertIsNone(dst.value)
+
+ with self.assertRaises(ValueError):
+ dst.GetClassID() # NULL COM pointer access
+
+ # This indicates that the refcount was 1 before the `Release` call.
+ self.assertEqual(0, dst_orig.Release())
+
+ def test_both_are_nonnull(self):
+ src = create_shelllink_persist(self.IPersist)
+ dst_orig = create_shelllink_persist(self.IPersist)
+ dst = self.IPersist()
+ CopyComPointer(dst_orig, byref(dst))
+ self.assertEqual(1, dst_orig.Release())
+
+ self.assertEqual(dst.value, dst_orig.value)
+ self.assertNotEqual(src.value, dst.value)
+
+ hr = CopyComPointer(src, byref(dst))
+
+ self.assertEqual(S_OK, hr)
+ self.assertEqual(src.value, dst.value)
+ self.assertNotEqual(dst.value, dst_orig.value)
+
+ self.assertEqual(1, src.Release())
+
+ clsid = dst.GetClassID()
+ self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))
+
+ self.assertEqual(0, dst.Release())
+ self.assertEqual(0, dst_orig.Release())
+
+
if __name__ == '__main__':
unittest.main()