+import os
import sys
import dns._features
+
if sys.platform == "win32":
from typing import Any
+ from enum import IntEnum
import dns.name
- _prefer_wmi = True
-
import winreg # pylint: disable=import-error
+ import ctypes
+ import ctypes.wintypes as wintypes
# Keep pylint quiet on non-windows.
try:
except NameError:
WindowsError = Exception
- if dns._features.have("wmi"):
- import threading
-
- import pythoncom # pylint: disable=import-error
- import wmi # pylint: disable=import-error
-
- _have_wmi = True
- else:
- _have_wmi = False
-
- def _config_domain(domain):
- # Sometimes DHCP servers add a '.' prefix to the default domain, and
- # Windows just stores such values in the registry (see #687).
- # Check for this and fix it.
- if domain.startswith("."):
- domain = domain[1:]
- return dns.name.from_text(domain)
+ class ConfigMethod(IntEnum):
+ Registry = 1
+ WMI = 2
+ Win32 = 3
class DnsInfo:
def __init__(self):
self.nameservers = []
self.search = []
- if _have_wmi:
+ _config_method = ConfigMethod.Registry
+
+ if dns._features.have("wmi"):
+ import threading
+
+ import pythoncom # pylint: disable=import-error
+ import wmi # pylint: disable=import-error
+
+ # Prefer WMI by default if wmi is installed.
+ _config_method = ConfigMethod.WMI
class _WMIGetter(threading.Thread):
# pylint: disable=possibly-used-before-assignment
self.start()
self.join()
return self.info
-
+
else:
-
class _WMIGetter: # type: ignore
pass
+
+ def _config_domain(domain):
+ # Sometimes DHCP servers add a '.' prefix to the default domain, and
+ # Windows just stores such values in the registry (see #687).
+ # Check for this and fix it.
+ if domain.startswith("."):
+ domain = domain[1:]
+ return dns.name.from_text(domain)
+
class _RegistryGetter:
def __init__(self):
self.info = DnsInfo()
lm.Close()
return self.info
- _getter_class: Any
- if _have_wmi and _prefer_wmi:
- _getter_class = _WMIGetter
- else:
- _getter_class = _RegistryGetter
+ class _Win32Getter(_RegistryGetter):
+
+ def get(self):
+ """Get the attributes using the Windows API.
+ """
+ # Load the IP Helper library
+ # # https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getadaptersaddresses
+ IPHLPAPI = ctypes.WinDLL('Iphlpapi.dll')
+
+ # Constants
+ AF_UNSPEC = 0
+ ERROR_SUCCESS = 0
+ GAA_FLAG_INCLUDE_PREFIX = 0x00000010
+ AF_INET = 2
+ AF_INET6 = 23
+ IF_TYPE_SOFTWARE_LOOPBACK = 24
+
+ # Define necessary structures
+ class SOCKADDRV4(ctypes.Structure):
+ _fields_ = [
+ ("sa_family", wintypes.USHORT),
+ ("sa_data", ctypes.c_ubyte * 14)
+ ]
+
+ class SOCKADDRV6(ctypes.Structure):
+ _fields_ = [
+ ("sa_family", wintypes.USHORT),
+ ("sa_data", ctypes.c_ubyte * 16)
+ ]
+
+ class SOCKET_ADDRESS(ctypes.Structure):
+ _fields_ = [
+ ("lpSockaddr", ctypes.POINTER(SOCKADDRV4)),
+ ("iSockaddrLength", wintypes.INT)
+ ]
+
+ class IP_ADAPTER_DNS_SERVER_ADDRESS(ctypes.Structure):
+ pass # Forward declaration
+
+ IP_ADAPTER_DNS_SERVER_ADDRESS._fields_ = [
+ ("Length", wintypes.ULONG),
+ ("Reserved", wintypes.DWORD),
+ ("Next", ctypes.POINTER(IP_ADAPTER_DNS_SERVER_ADDRESS)),
+ ("Address", SOCKET_ADDRESS)
+ ]
+
+ class IF_LUID(ctypes.Structure):
+ _fields_ = [
+ ("Value", ctypes.c_ulonglong)
+ ]
+
+
+ class NET_IF_NETWORK_GUID(ctypes.Structure):
+ _fields_ = [
+ ("Value", ctypes.c_ubyte * 16)
+ ]
+
+
+ class IP_ADAPTER_PREFIX_XP(ctypes.Structure):
+ pass # Left undefined here for simplicity
+
+
+ class IP_ADAPTER_GATEWAY_ADDRESS_LH(ctypes.Structure):
+ pass # Left undefined here for simplicity
+
+
+ class IP_ADAPTER_DNS_SUFFIX(ctypes.Structure):
+ _fields_ = [("String", ctypes.c_wchar * 256), ("Next", ctypes.POINTER(ctypes.c_void_p))]
+
+
+ class IP_ADAPTER_UNICAST_ADDRESS_LH(ctypes.Structure):
+ pass # Left undefined here for simplicity
+
+
+ class IP_ADAPTER_MULTICAST_ADDRESS_XP(ctypes.Structure):
+ pass # Left undefined here for simplicity
+
+
+ class IP_ADAPTER_ANYCAST_ADDRESS_XP(ctypes.Structure):
+ pass # Left undefined here for simplicity
+
+
+ class IP_ADAPTER_DNS_SERVER_ADDRESS_XP(ctypes.Structure):
+ pass # Left undefined here for simplicity
+
+
+ class IP_ADAPTER_ADDRESSES(ctypes.Structure):
+ pass # Forward declaration
+
+ IP_ADAPTER_ADDRESSES._fields_ = [
+ ("Length", wintypes.ULONG),
+ ("IfIndex", wintypes.DWORD),
+ ("Next", ctypes.POINTER(IP_ADAPTER_ADDRESSES)),
+ ("AdapterName", ctypes.c_char_p),
+ ("FirstUnicastAddress", ctypes.POINTER(SOCKET_ADDRESS)),
+ ("FirstAnycastAddress", ctypes.POINTER(SOCKET_ADDRESS)),
+ ("FirstMulticastAddress", ctypes.POINTER(SOCKET_ADDRESS)),
+ ("FirstDnsServerAddress", ctypes.POINTER(IP_ADAPTER_DNS_SERVER_ADDRESS)),
+ ("DnsSuffix", wintypes.LPWSTR),
+ ("Description", wintypes.LPWSTR),
+ ("FriendlyName", wintypes.LPWSTR),
+ ("PhysicalAddress", ctypes.c_ubyte * 8),
+ ("PhysicalAddressLength", wintypes.ULONG),
+ ("Flags", wintypes.ULONG),
+ ("Mtu", wintypes.ULONG),
+ ("IfType", wintypes.ULONG),
+ ("OperStatus", ctypes.c_uint),
+ # Remaining fields removed for brevity
+ ]
+
+ def format_ipv4(sockaddr_in):
+ return ".".join(map(str, sockaddr_in.sa_data[2:6]))
+
+ def format_ipv6(sockaddr_in6):
+ parts = [sockaddr_in6.sa_data[i] << 8 | sockaddr_in6.sa_data[i+1] for i in range(0, 16, 2)]
+ return ":".join(f"{part:04x}" for part in parts)
+
+ buffer_size = ctypes.c_ulong(15000)
+ while True:
+ buffer = ctypes.create_string_buffer(buffer_size.value)
+
+ ret_val = IPHLPAPI.GetAdaptersAddresses(
+ AF_UNSPEC, GAA_FLAG_INCLUDE_PREFIX, None, buffer, ctypes.byref(buffer_size)
+ )
+
+ if ret_val == ERROR_SUCCESS:
+ break
+ elif ret_val != 0x6F: # ERROR_BUFFER_OVERFLOW
+ print(f"Error retrieving adapter information: {ret_val}")
+ return
+
+ adapter_addresses = ctypes.cast(buffer, ctypes.POINTER(IP_ADAPTER_ADDRESSES))
+
+ current_adapter = adapter_addresses
+ while current_adapter:
+
+ # Skip non-operational adapters.
+ oper_status = current_adapter.contents.OperStatus
+ if oper_status != 1:
+ continue
+
+ # Exclude loopback adapters.
+ if current_adapter.contents.IfType == IF_TYPE_SOFTWARE_LOOPBACK:
+ current_adapter = current_adapter.contents.Next
+ continue
+
+ # Get the domain from the DnsSuffix attribute.
+ dns_suffix = current_adapter.contents.DnsSuffix
+ if dns_suffix:
+ self.info.domain = dns_suffix
+
+ current_dns_server = current_adapter.contents.FirstDnsServerAddress
+ while current_dns_server:
+ sockaddr = current_dns_server.contents.Address.lpSockaddr
+ sockaddr_family = sockaddr.contents.sa_family
+
+ ip = None
+ if sockaddr_family == AF_INET: # IPv4
+ ip = format_ipv4(sockaddr.contents)
+ elif sockaddr_family == AF_INET6: # IPv6
+ sockaddr = ctypes.cast(sockaddr, ctypes.POINTER(SOCKADDRV6))
+ ip = format_ipv6(sockaddr.contents)
+
+ if ip:
+ if ip not in self.info.nameservers:
+ self.info.nameservers.append(ip)
+
+ current_dns_server = current_dns_server.contents.Next
+
+ current_adapter = current_adapter.contents.Next
+
+ # Use the registry getter to get the search info, since it is set at the system level.
+ registry_getter = _RegistryGetter()
+ info = registry_getter.get()
+ self.info.search = info.search
+ return self.info
+
+ def set_config_method(method: ConfigMethod) -> None:
+ global _config_method
+ _config_method = method
- def get_dns_info():
+ def get_dns_info() -> DnsInfo:
"""Extract resolver configuration."""
- getter = _getter_class()
+ if _config_method == ConfigMethod.Win32:
+ getter = _Win32Getter()
+ elif _config_method == ConfigMethod.WMI:
+ getter = _WMIGetter()
+ else:
+ getter = _RegistryGetter()
return getter.get()