]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Use Windows API to get dns nameservers (#1196)
authorSteven Silvester <steve.silvester@mongodb.com>
Sat, 26 Jul 2025 15:10:06 +0000 (10:10 -0500)
committerGitHub <noreply@github.com>
Sat, 26 Jul 2025 15:10:06 +0000 (08:10 -0700)
Add an option to use the Windows API to get dns nameservers for better accuracy.

.github/workflows/ci.yml
dns/win32util.py
doc/whatsnew.rst
tests/test_resolver.py

index 65342d22193c1c8acf791609b0d77dd70c09543a..e0026ec652ce177e91cf5b678c222040769fe704 100644 (file)
@@ -63,3 +63,9 @@ jobs:
     - name: Test with pytest
       run: |
         pytest
+
+    - name: Test with wmi
+      if: ${{ startsWith(matrix.os, 'windows') }}
+      run: |
+        python -m pip install ".[wmi]"
+        pytest
index 9ed3f11bcba03fc57de939aca45a82a1d067bae0..18f8a3ef9c7e7ac74b1afc08583df1d6427357fc 100644 (file)
@@ -1,15 +1,18 @@
+import os
 import sys
 
 import dns._features
 
+
 if sys.platform == "win32":
     from typing import Any
+    from enum import IntEnum
 
     import dns.name
 
-    _prefer_wmi = True
-
     import winreg  # pylint: disable=import-error
+    import ctypes  
+    import ctypes.wintypes as wintypes  
 
     # Keep pylint quiet on non-windows.
     try:
@@ -17,23 +20,10 @@ if sys.platform == "win32":
     except NameError:
         WindowsError = Exception
 
-    if dns._features.have("wmi"):
-        import threading
-
-        import pythoncom  # pylint: disable=import-error
-        import wmi  # pylint: disable=import-error
-
-        _have_wmi = True
-    else:
-        _have_wmi = False
-
-    def _config_domain(domain):
-        # Sometimes DHCP servers add a '.' prefix to the default domain, and
-        # Windows just stores such values in the registry (see #687).
-        # Check for this and fix it.
-        if domain.startswith("."):
-            domain = domain[1:]
-        return dns.name.from_text(domain)
+    class ConfigMethod(IntEnum):
+        Registry = 1
+        WMI = 2
+        Win32 = 3
 
     class DnsInfo:
         def __init__(self):
@@ -41,7 +31,16 @@ if sys.platform == "win32":
             self.nameservers = []
             self.search = []
 
-    if _have_wmi:
+    _config_method = ConfigMethod.Registry
+
+    if dns._features.have("wmi"):
+        import threading
+
+        import pythoncom  # pylint: disable=import-error
+        import wmi  # pylint: disable=import-error
+
+        # Prefer WMI by default if wmi is installed.
+        _config_method = ConfigMethod.WMI
 
         class _WMIGetter(threading.Thread):
             # pylint: disable=possibly-used-before-assignment
@@ -73,12 +72,20 @@ if sys.platform == "win32":
                 self.start()
                 self.join()
                 return self.info
-
+    
     else:
-
         class _WMIGetter:  # type: ignore
             pass
 
+
+    def _config_domain(domain):
+        # Sometimes DHCP servers add a '.' prefix to the default domain, and
+        # Windows just stores such values in the registry (see #687).
+        # Check for this and fix it.
+        if domain.startswith("."):
+            domain = domain[1:]
+        return dns.name.from_text(domain)
+        
     class _RegistryGetter:
         def __init__(self):
             self.info = DnsInfo()
@@ -230,13 +237,193 @@ if sys.platform == "win32":
                 lm.Close()
             return self.info
 
-    _getter_class: Any
-    if _have_wmi and _prefer_wmi:
-        _getter_class = _WMIGetter
-    else:
-        _getter_class = _RegistryGetter
+    class _Win32Getter(_RegistryGetter):
+
+        def get(self):
+            """Get the attributes using the Windows API.
+            """
+            # Load the IP Helper library  
+            # # https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getadaptersaddresses
+            IPHLPAPI = ctypes.WinDLL('Iphlpapi.dll')  
+
+            # Constants  
+            AF_UNSPEC = 0  
+            ERROR_SUCCESS = 0  
+            GAA_FLAG_INCLUDE_PREFIX = 0x00000010  
+            AF_INET = 2  
+            AF_INET6 = 23  
+            IF_TYPE_SOFTWARE_LOOPBACK = 24  
+            
+            # Define necessary structures  
+            class SOCKADDRV4(ctypes.Structure):  
+                _fields_ = [  
+                    ("sa_family", wintypes.USHORT),  
+                    ("sa_data", ctypes.c_ubyte * 14)  
+                ]  
+
+            class SOCKADDRV6(ctypes.Structure):  
+                _fields_ = [  
+                    ("sa_family", wintypes.USHORT),  
+                    ("sa_data", ctypes.c_ubyte * 16)  
+                ]  
+            
+            class SOCKET_ADDRESS(ctypes.Structure):  
+                _fields_ = [  
+                    ("lpSockaddr", ctypes.POINTER(SOCKADDRV4)),  
+                    ("iSockaddrLength", wintypes.INT)  
+                ]  
+            
+            class IP_ADAPTER_DNS_SERVER_ADDRESS(ctypes.Structure):  
+                pass  # Forward declaration  
+            
+            IP_ADAPTER_DNS_SERVER_ADDRESS._fields_ = [  
+                ("Length", wintypes.ULONG),  
+                ("Reserved", wintypes.DWORD),  
+                ("Next", ctypes.POINTER(IP_ADAPTER_DNS_SERVER_ADDRESS)),  
+                ("Address", SOCKET_ADDRESS)  
+            ]  
+
+            class IF_LUID(ctypes.Structure):  
+                _fields_ = [  
+                    ("Value", ctypes.c_ulonglong)  
+                ]  
+            
+            
+            class NET_IF_NETWORK_GUID(ctypes.Structure):  
+                _fields_ = [  
+                    ("Value", ctypes.c_ubyte * 16)  
+                ]  
+            
+            
+            class IP_ADAPTER_PREFIX_XP(ctypes.Structure):  
+                pass  # Left undefined here for simplicity  
+            
+            
+            class IP_ADAPTER_GATEWAY_ADDRESS_LH(ctypes.Structure):  
+                pass  # Left undefined here for simplicity  
+            
+            
+            class IP_ADAPTER_DNS_SUFFIX(ctypes.Structure):  
+                _fields_ = [("String", ctypes.c_wchar * 256), ("Next", ctypes.POINTER(ctypes.c_void_p))]  
+            
+            
+            class IP_ADAPTER_UNICAST_ADDRESS_LH(ctypes.Structure):  
+                pass  # Left undefined here for simplicity  
+            
+            
+            class IP_ADAPTER_MULTICAST_ADDRESS_XP(ctypes.Structure):  
+                pass  # Left undefined here for simplicity  
+            
+            
+            class IP_ADAPTER_ANYCAST_ADDRESS_XP(ctypes.Structure):  
+                pass  # Left undefined here for simplicity  
+            
+            
+            class IP_ADAPTER_DNS_SERVER_ADDRESS_XP(ctypes.Structure):  
+                pass  # Left undefined here for simplicity  
+            
+                        
+            class IP_ADAPTER_ADDRESSES(ctypes.Structure):  
+                pass  # Forward declaration  
+            
+            IP_ADAPTER_ADDRESSES._fields_ = [  
+                ("Length", wintypes.ULONG),  
+                ("IfIndex", wintypes.DWORD),  
+                ("Next", ctypes.POINTER(IP_ADAPTER_ADDRESSES)),  
+                ("AdapterName", ctypes.c_char_p),  
+                ("FirstUnicastAddress", ctypes.POINTER(SOCKET_ADDRESS)),  
+                ("FirstAnycastAddress", ctypes.POINTER(SOCKET_ADDRESS)),  
+                ("FirstMulticastAddress", ctypes.POINTER(SOCKET_ADDRESS)),  
+                ("FirstDnsServerAddress", ctypes.POINTER(IP_ADAPTER_DNS_SERVER_ADDRESS)),  
+                ("DnsSuffix", wintypes.LPWSTR),  
+                ("Description", wintypes.LPWSTR),  
+                ("FriendlyName", wintypes.LPWSTR),  
+                ("PhysicalAddress", ctypes.c_ubyte * 8),  
+                ("PhysicalAddressLength", wintypes.ULONG),  
+                ("Flags", wintypes.ULONG),  
+                ("Mtu", wintypes.ULONG),  
+                ("IfType", wintypes.ULONG),  
+                ("OperStatus", ctypes.c_uint),  
+                # Remaining fields removed for brevity  
+            ]  
+            
+            def format_ipv4(sockaddr_in):  
+                return ".".join(map(str, sockaddr_in.sa_data[2:6]))  
+            
+            def format_ipv6(sockaddr_in6):  
+                parts = [sockaddr_in6.sa_data[i] << 8 | sockaddr_in6.sa_data[i+1] for i in range(0, 16, 2)]  
+                return ":".join(f"{part:04x}" for part in parts)  
+            
+            buffer_size = ctypes.c_ulong(15000)  
+            while True:  
+                buffer = ctypes.create_string_buffer(buffer_size.value)  
+                
+                ret_val = IPHLPAPI.GetAdaptersAddresses(  
+                    AF_UNSPEC, GAA_FLAG_INCLUDE_PREFIX, None, buffer, ctypes.byref(buffer_size)  
+                )  
+                
+                if ret_val == ERROR_SUCCESS:  
+                    break  
+                elif ret_val != 0x6F:  # ERROR_BUFFER_OVERFLOW  
+                    print(f"Error retrieving adapter information: {ret_val}")  
+                    return  
+            
+            adapter_addresses = ctypes.cast(buffer, ctypes.POINTER(IP_ADAPTER_ADDRESSES))  
+            
+            current_adapter = adapter_addresses  
+            while current_adapter:  
+
+                # Skip non-operational adapters.
+                oper_status = current_adapter.contents.OperStatus  
+                if oper_status != 1:
+                    continue  
+
+                # Exclude loopback adapters.
+                if current_adapter.contents.IfType == IF_TYPE_SOFTWARE_LOOPBACK: 
+                    current_adapter = current_adapter.contents.Next
+                    continue
+
+                # Get the domain from the DnsSuffix attribute.
+                dns_suffix = current_adapter.contents.DnsSuffix  
+                if dns_suffix: 
+                    self.info.domain = dns_suffix 
+
+                current_dns_server = current_adapter.contents.FirstDnsServerAddress  
+                while current_dns_server:  
+                    sockaddr = current_dns_server.contents.Address.lpSockaddr  
+                    sockaddr_family = sockaddr.contents.sa_family  
+                        
+                    ip = None  
+                    if sockaddr_family == AF_INET:  # IPv4  
+                        ip = format_ipv4(sockaddr.contents)  
+                    elif sockaddr_family == AF_INET6:  # IPv6  
+                        sockaddr = ctypes.cast(sockaddr, ctypes.POINTER(SOCKADDRV6))
+                        ip = format_ipv6(sockaddr.contents)  
+                        
+                    if ip:  
+                        if ip not in self.info.nameservers:
+                            self.info.nameservers.append(ip)
+                        
+                    current_dns_server = current_dns_server.contents.Next  
+        
+                current_adapter = current_adapter.contents.Next  
+
+            # Use the registry getter to get the search info, since it is set at the system level.
+            registry_getter = _RegistryGetter()
+            info = registry_getter.get()
+            self.info.search = info.search
+            return self.info
+
+    def set_config_method(method: ConfigMethod) -> None:
+        global _config_method
+        _config_method = method
 
-    def get_dns_info():
+    def get_dns_info() -> DnsInfo:
         """Extract resolver configuration."""
-        getter = _getter_class()
+        if _config_method == ConfigMethod.Win32:
+            getter = _Win32Getter()
+        elif _config_method == ConfigMethod.WMI:
+            getter = _WMIGetter()
+        else:
+            getter = _RegistryGetter()
         return getter.get()
index 306f28dd43279695398e01b2a88c370655ab4710..522dc814c70365b2323e0d1a40ad4b7b2ab7ad01 100644 (file)
@@ -6,7 +6,9 @@ What's New in dnspython
 2.8.0 (in development)
 ----------------------
 
-* TBD
+* dns/win32util.py now supports explicitly setting the configuration method used to get 
+  system dns info, using the set_config_method() function.   There is a new configuration
+  method that uses the Win32 API, which can be set using set_config_method(ConfigMethod.Win32).
 
 2.7.0
 -----
index 752906415f148c9e49b16b1efb6f5fa4b40995fe..0c50ca7122cd41e4617440cfe5ea4aebeced2173 100644 (file)
@@ -23,6 +23,7 @@ import unittest
 from io import StringIO
 from unittest.mock import patch
 
+import dns.win32util
 import pytest
 
 import dns.e164
@@ -996,6 +997,11 @@ class ResolverMiscTestCase(unittest.TestCase):
             self.assertEqual(n, dns.win32util._config_domain("home"))
             self.assertEqual(n, dns.win32util._config_domain(".home"))
 
+        def test_set_config_method(self):
+            from dns.win32util import set_config_method, ConfigMethod
+            self.assertNotEqual(dns.win32util._config_method, dns.win32util.ConfigMethod.Win32)
+            dns.win32util.set_config_method(dns.win32util.ConfigMethod.Win32)
+            self.assertEqual(dns.win32util._config_method, dns.win32util.ConfigMethod.Win32)
 
 class ResolverNameserverValidTypeTestCase(unittest.TestCase):
     def test_set_nameservers_to_list(self):