From: Steven Silvester Date: Sat, 26 Jul 2025 15:10:06 +0000 (-0500) Subject: Use Windows API to get dns nameservers (#1196) X-Git-Tag: v2.8.0rc1~23 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3140ae85bededcaab4dbf6b700c8d803c3a24e10;p=thirdparty%2Fdnspython.git Use Windows API to get dns nameservers (#1196) Add an option to use the Windows API to get dns nameservers for better accuracy. --- diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 65342d22..e0026ec6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -63,3 +63,9 @@ jobs: - name: Test with pytest run: | pytest + + - name: Test with wmi + if: ${{ startsWith(matrix.os, 'windows') }} + run: | + python -m pip install ".[wmi]" + pytest diff --git a/dns/win32util.py b/dns/win32util.py index 9ed3f11b..18f8a3ef 100644 --- a/dns/win32util.py +++ b/dns/win32util.py @@ -1,15 +1,18 @@ +import os import sys import dns._features + if sys.platform == "win32": from typing import Any + from enum import IntEnum import dns.name - _prefer_wmi = True - import winreg # pylint: disable=import-error + import ctypes + import ctypes.wintypes as wintypes # Keep pylint quiet on non-windows. try: @@ -17,23 +20,10 @@ if sys.platform == "win32": except NameError: WindowsError = Exception - if dns._features.have("wmi"): - import threading - - import pythoncom # pylint: disable=import-error - import wmi # pylint: disable=import-error - - _have_wmi = True - else: - _have_wmi = False - - def _config_domain(domain): - # Sometimes DHCP servers add a '.' prefix to the default domain, and - # Windows just stores such values in the registry (see #687). - # Check for this and fix it. - if domain.startswith("."): - domain = domain[1:] - return dns.name.from_text(domain) + class ConfigMethod(IntEnum): + Registry = 1 + WMI = 2 + Win32 = 3 class DnsInfo: def __init__(self): @@ -41,7 +31,16 @@ if sys.platform == "win32": self.nameservers = [] self.search = [] - if _have_wmi: + _config_method = ConfigMethod.Registry + + if dns._features.have("wmi"): + import threading + + import pythoncom # pylint: disable=import-error + import wmi # pylint: disable=import-error + + # Prefer WMI by default if wmi is installed. + _config_method = ConfigMethod.WMI class _WMIGetter(threading.Thread): # pylint: disable=possibly-used-before-assignment @@ -73,12 +72,20 @@ if sys.platform == "win32": self.start() self.join() return self.info - + else: - class _WMIGetter: # type: ignore pass + + def _config_domain(domain): + # Sometimes DHCP servers add a '.' prefix to the default domain, and + # Windows just stores such values in the registry (see #687). + # Check for this and fix it. + if domain.startswith("."): + domain = domain[1:] + return dns.name.from_text(domain) + class _RegistryGetter: def __init__(self): self.info = DnsInfo() @@ -230,13 +237,193 @@ if sys.platform == "win32": lm.Close() return self.info - _getter_class: Any - if _have_wmi and _prefer_wmi: - _getter_class = _WMIGetter - else: - _getter_class = _RegistryGetter + class _Win32Getter(_RegistryGetter): + + def get(self): + """Get the attributes using the Windows API. + """ + # Load the IP Helper library + # # https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getadaptersaddresses + IPHLPAPI = ctypes.WinDLL('Iphlpapi.dll') + + # Constants + AF_UNSPEC = 0 + ERROR_SUCCESS = 0 + GAA_FLAG_INCLUDE_PREFIX = 0x00000010 + AF_INET = 2 + AF_INET6 = 23 + IF_TYPE_SOFTWARE_LOOPBACK = 24 + + # Define necessary structures + class SOCKADDRV4(ctypes.Structure): + _fields_ = [ + ("sa_family", wintypes.USHORT), + ("sa_data", ctypes.c_ubyte * 14) + ] + + class SOCKADDRV6(ctypes.Structure): + _fields_ = [ + ("sa_family", wintypes.USHORT), + ("sa_data", ctypes.c_ubyte * 16) + ] + + class SOCKET_ADDRESS(ctypes.Structure): + _fields_ = [ + ("lpSockaddr", ctypes.POINTER(SOCKADDRV4)), + ("iSockaddrLength", wintypes.INT) + ] + + class IP_ADAPTER_DNS_SERVER_ADDRESS(ctypes.Structure): + pass # Forward declaration + + IP_ADAPTER_DNS_SERVER_ADDRESS._fields_ = [ + ("Length", wintypes.ULONG), + ("Reserved", wintypes.DWORD), + ("Next", ctypes.POINTER(IP_ADAPTER_DNS_SERVER_ADDRESS)), + ("Address", SOCKET_ADDRESS) + ] + + class IF_LUID(ctypes.Structure): + _fields_ = [ + ("Value", ctypes.c_ulonglong) + ] + + + class NET_IF_NETWORK_GUID(ctypes.Structure): + _fields_ = [ + ("Value", ctypes.c_ubyte * 16) + ] + + + class IP_ADAPTER_PREFIX_XP(ctypes.Structure): + pass # Left undefined here for simplicity + + + class IP_ADAPTER_GATEWAY_ADDRESS_LH(ctypes.Structure): + pass # Left undefined here for simplicity + + + class IP_ADAPTER_DNS_SUFFIX(ctypes.Structure): + _fields_ = [("String", ctypes.c_wchar * 256), ("Next", ctypes.POINTER(ctypes.c_void_p))] + + + class IP_ADAPTER_UNICAST_ADDRESS_LH(ctypes.Structure): + pass # Left undefined here for simplicity + + + class IP_ADAPTER_MULTICAST_ADDRESS_XP(ctypes.Structure): + pass # Left undefined here for simplicity + + + class IP_ADAPTER_ANYCAST_ADDRESS_XP(ctypes.Structure): + pass # Left undefined here for simplicity + + + class IP_ADAPTER_DNS_SERVER_ADDRESS_XP(ctypes.Structure): + pass # Left undefined here for simplicity + + + class IP_ADAPTER_ADDRESSES(ctypes.Structure): + pass # Forward declaration + + IP_ADAPTER_ADDRESSES._fields_ = [ + ("Length", wintypes.ULONG), + ("IfIndex", wintypes.DWORD), + ("Next", ctypes.POINTER(IP_ADAPTER_ADDRESSES)), + ("AdapterName", ctypes.c_char_p), + ("FirstUnicastAddress", ctypes.POINTER(SOCKET_ADDRESS)), + ("FirstAnycastAddress", ctypes.POINTER(SOCKET_ADDRESS)), + ("FirstMulticastAddress", ctypes.POINTER(SOCKET_ADDRESS)), + ("FirstDnsServerAddress", ctypes.POINTER(IP_ADAPTER_DNS_SERVER_ADDRESS)), + ("DnsSuffix", wintypes.LPWSTR), + ("Description", wintypes.LPWSTR), + ("FriendlyName", wintypes.LPWSTR), + ("PhysicalAddress", ctypes.c_ubyte * 8), + ("PhysicalAddressLength", wintypes.ULONG), + ("Flags", wintypes.ULONG), + ("Mtu", wintypes.ULONG), + ("IfType", wintypes.ULONG), + ("OperStatus", ctypes.c_uint), + # Remaining fields removed for brevity + ] + + def format_ipv4(sockaddr_in): + return ".".join(map(str, sockaddr_in.sa_data[2:6])) + + def format_ipv6(sockaddr_in6): + parts = [sockaddr_in6.sa_data[i] << 8 | sockaddr_in6.sa_data[i+1] for i in range(0, 16, 2)] + return ":".join(f"{part:04x}" for part in parts) + + buffer_size = ctypes.c_ulong(15000) + while True: + buffer = ctypes.create_string_buffer(buffer_size.value) + + ret_val = IPHLPAPI.GetAdaptersAddresses( + AF_UNSPEC, GAA_FLAG_INCLUDE_PREFIX, None, buffer, ctypes.byref(buffer_size) + ) + + if ret_val == ERROR_SUCCESS: + break + elif ret_val != 0x6F: # ERROR_BUFFER_OVERFLOW + print(f"Error retrieving adapter information: {ret_val}") + return + + adapter_addresses = ctypes.cast(buffer, ctypes.POINTER(IP_ADAPTER_ADDRESSES)) + + current_adapter = adapter_addresses + while current_adapter: + + # Skip non-operational adapters. + oper_status = current_adapter.contents.OperStatus + if oper_status != 1: + continue + + # Exclude loopback adapters. + if current_adapter.contents.IfType == IF_TYPE_SOFTWARE_LOOPBACK: + current_adapter = current_adapter.contents.Next + continue + + # Get the domain from the DnsSuffix attribute. + dns_suffix = current_adapter.contents.DnsSuffix + if dns_suffix: + self.info.domain = dns_suffix + + current_dns_server = current_adapter.contents.FirstDnsServerAddress + while current_dns_server: + sockaddr = current_dns_server.contents.Address.lpSockaddr + sockaddr_family = sockaddr.contents.sa_family + + ip = None + if sockaddr_family == AF_INET: # IPv4 + ip = format_ipv4(sockaddr.contents) + elif sockaddr_family == AF_INET6: # IPv6 + sockaddr = ctypes.cast(sockaddr, ctypes.POINTER(SOCKADDRV6)) + ip = format_ipv6(sockaddr.contents) + + if ip: + if ip not in self.info.nameservers: + self.info.nameservers.append(ip) + + current_dns_server = current_dns_server.contents.Next + + current_adapter = current_adapter.contents.Next + + # Use the registry getter to get the search info, since it is set at the system level. + registry_getter = _RegistryGetter() + info = registry_getter.get() + self.info.search = info.search + return self.info + + def set_config_method(method: ConfigMethod) -> None: + global _config_method + _config_method = method - def get_dns_info(): + def get_dns_info() -> DnsInfo: """Extract resolver configuration.""" - getter = _getter_class() + if _config_method == ConfigMethod.Win32: + getter = _Win32Getter() + elif _config_method == ConfigMethod.WMI: + getter = _WMIGetter() + else: + getter = _RegistryGetter() return getter.get() diff --git a/doc/whatsnew.rst b/doc/whatsnew.rst index 306f28dd..522dc814 100644 --- a/doc/whatsnew.rst +++ b/doc/whatsnew.rst @@ -6,7 +6,9 @@ What's New in dnspython 2.8.0 (in development) ---------------------- -* TBD +* dns/win32util.py now supports explicitly setting the configuration method used to get + system dns info, using the set_config_method() function. There is a new configuration + method that uses the Win32 API, which can be set using set_config_method(ConfigMethod.Win32). 2.7.0 ----- diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 75290641..0c50ca71 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -23,6 +23,7 @@ import unittest from io import StringIO from unittest.mock import patch +import dns.win32util import pytest import dns.e164 @@ -996,6 +997,11 @@ class ResolverMiscTestCase(unittest.TestCase): self.assertEqual(n, dns.win32util._config_domain("home")) self.assertEqual(n, dns.win32util._config_domain(".home")) + def test_set_config_method(self): + from dns.win32util import set_config_method, ConfigMethod + self.assertNotEqual(dns.win32util._config_method, dns.win32util.ConfigMethod.Win32) + dns.win32util.set_config_method(dns.win32util.ConfigMethod.Win32) + self.assertEqual(dns.win32util._config_method, dns.win32util.ConfigMethod.Win32) class ResolverNameserverValidTypeTestCase(unittest.TestCase): def test_set_nameservers_to_list(self):