]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add WMI-based method for finding resolver info on Windows. 722/head
authorBob Halley <halley@dnspython.org>
Tue, 9 Nov 2021 15:36:30 +0000 (07:36 -0800)
committerBob Halley <halley@dnspython.org>
Tue, 9 Nov 2021 15:36:30 +0000 (07:36 -0800)
dns/win32util.py [new file with mode: 0755]
tests/test_resolver.py

diff --git a/dns/win32util.py b/dns/win32util.py
new file mode 100755 (executable)
index 0000000..b2300dd
--- /dev/null
@@ -0,0 +1,248 @@
+import sys
+
+if sys.platform == 'win32':
+
+    import dns.name
+
+    _prefer_wmi = True
+
+    import winreg
+
+    try:
+        try:
+            import threading as _threading
+        except ImportError:  # pragma: no cover
+            import dummy_threading as _threading    # type: ignore
+        import pythoncom
+        import wmi
+        _have_wmi = True
+    except Exception as e:
+        _have_wmi = False
+
+    def _config_domain(domain):
+        # Sometimes DHCP servers add a '.' prefix to the default domain, and
+        # Windows just stores such values in the registry (see #687).
+        # Check for this and fix it.
+        if domain.startswith('.'):
+             domain = domain[1:]
+        return dns.name.from_text(domain)
+
+    class DnsInfo:
+        def __init__(self):
+            self.domain = None
+            self.nameservers = []
+            self.search = []
+
+    if _have_wmi:
+        class _WMIGetter(_threading.Thread):
+            def __init__(self):
+                super().__init__()
+                self.info = DnsInfo()
+
+            def run(self):
+                pythoncom.CoInitialize()
+                try:
+                    system = wmi.WMI()
+                    for interface in system.Win32_NetworkAdapterConfiguration():
+                        if interface.IPEnabled:
+                            self.info.domain = _config_domain(interface.DNSDomain)
+                            self.info.nameservers = list(interface.DNSServerSearchOrder)
+                            self.info.search = [dns.name.from_text(x) for x in
+                                                interface.DNSDomainSuffixSearchOrder]
+                            break
+                finally:
+                    pythoncom.CoUninitialize()
+
+            def get(self):
+                # We always run in a separate thread to avoid any issues with
+                # the COM threading model.
+                self.start()
+                self.join()
+                return self.info
+                
+
+        def get_dns_info_from_wmi():
+            getter = _WMIGetter()
+            return getter.get()
+    else:
+        class _WMIGetter:
+            pass
+        def get(self):
+            return None
+
+
+    class _RegistryGetter:
+        def __init__(self):
+            self.info = DnsInfo()
+
+        def _determine_split_char(self, entry):
+            #
+            # The windows registry irritatingly changes the list element
+            # delimiter in between ' ' and ',' (and vice-versa) in various
+            # versions of windows.
+            #
+            if entry.find(' ') >= 0:
+                split_char = ' '
+            elif entry.find(',') >= 0:
+                split_char = ','
+            else:
+                # probably a singleton; treat as a space-separated list.
+                split_char = ' '
+            return split_char
+
+        def _config_nameservers(self, nameservers):
+            split_char = self._determine_split_char(nameservers)
+            ns_list = nameservers.split(split_char)
+            for ns in ns_list:
+                if ns not in self.info.nameservers:
+                    self.info.nameservers.append(ns)
+
+        def _config_search(self, search):
+            split_char = self._determine_split_char(search)
+            search_list = search.split(split_char)
+            for s in search_list:
+                s = dns.name.from_text(s)
+                if s not in self.info.search:
+                    self.info.search.append(s)
+                
+        def _config_fromkey(self, key, always_try_domain):
+            try:
+                servers, _ = winreg.QueryValueEx(key, 'NameServer')
+            except WindowsError:
+                servers = None
+            if servers:
+                self._config_nameservers(servers)
+            if servers or always_try_domain:
+                try:
+                    dom, _ = winreg.QueryValueEx(key, 'Domain')
+                    if dom:
+                        self.info.domain = _config_domain(dom)
+                except WindowsError:
+                    pass
+            else:
+                try:
+                    servers, _ = winreg.QueryValueEx(key, 'DhcpNameServer')
+                except WindowsError:
+                    servers = None
+                if servers:
+                    self._config_nameservers(servers)
+                    try:
+                        dom, _ = winreg.QueryValueEx(key, 'DhcpDomain')
+                        if dom:
+                            self.info.domain = _config_domain(dom)
+                    except WindowsError:
+                        pass
+            try:
+                search, _ = winreg.QueryValueEx(key, 'SearchList')
+            except WindowsError:
+                search = None
+            if search is None:
+                try:
+                    search, _ = winreg.QueryValueEx(key, 'DhcpSearchList')
+                except WindowsError:
+                    search = None
+            if search:
+                self._config_search(search)
+
+        def _is_nic_enabled(self, lm, guid):
+            # Look in the Windows Registry to determine whether the network
+            # interface corresponding to the given guid is enabled.
+            #
+            # (Code contributed by Paul Marks, thanks!)
+            #
+            try:
+                # This hard-coded location seems to be consistent, at least
+                # from Windows 2000 through Vista.
+                connection_key = winreg.OpenKey(
+                    lm,
+                    r'SYSTEM\CurrentControlSet\Control\Network'
+                    r'\{4D36E972-E325-11CE-BFC1-08002BE10318}'
+                    r'\%s\Connection' % guid)
+
+                try:
+                    # The PnpInstanceID points to a key inside Enum
+                    (pnp_id, ttype) = winreg.QueryValueEx(
+                        connection_key, 'PnpInstanceID')
+
+                    if ttype != winreg.REG_SZ:
+                        raise ValueError  # pragma: no cover
+
+                    device_key = winreg.OpenKey(
+                        lm, r'SYSTEM\CurrentControlSet\Enum\%s' % pnp_id)
+
+                    try:
+                        # Get ConfigFlags for this device
+                        (flags, ttype) = winreg.QueryValueEx(
+                            device_key, 'ConfigFlags')
+
+                        if ttype != winreg.REG_DWORD:
+                            raise ValueError  # pragma: no cover
+
+                        # Based on experimentation, bit 0x1 indicates that the
+                        # device is disabled.
+                        #
+                        # XXXRTH I suspect we really want to & with 0x03 so
+                        # that CONFIGFLAGS_REMOVED devices are also ignored,
+                        # but we're shifting to WMI as ConfigFlags is not
+                        # supposed to be used.
+                        return not flags & 0x1
+
+                    finally:
+                        device_key.Close()
+                finally:
+                    connection_key.Close()
+            except Exception:  # pragma: no cover
+                return False
+
+        def get(self):
+            """Extract resolver configuration from the Windows registry."""
+
+            lm = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE)
+            try:
+                tcp_params = winreg.OpenKey(lm,
+                                            r'SYSTEM\CurrentControlSet'
+                                            r'\Services\Tcpip\Parameters')
+                try:
+                    self._config_fromkey(tcp_params, True)
+                finally:
+                    tcp_params.Close()
+                interfaces = winreg.OpenKey(lm,
+                                            r'SYSTEM\CurrentControlSet'
+                                            r'\Services\Tcpip\Parameters'
+                                            r'\Interfaces')
+                try:
+                    i = 0
+                    while True:
+                        try:
+                            guid = winreg.EnumKey(interfaces, i)
+                            i += 1
+                            key = winreg.OpenKey(interfaces, guid)
+                            try:
+                                if not self._is_nic_enabled(lm, guid):
+                                    continue
+                                self._config_fromkey(key, False)
+                            finally:
+                                key.Close()
+                        except EnvironmentError:
+                            break
+                finally:
+                    interfaces.Close()
+            finally:
+                lm.Close()
+            return self.info
+
+    def get_dns_info_from_registry():
+        """Extract resolver configuration from the Windows registry."""
+        getter = _RegistryGetter()
+        return getter.get()
+
+    if _have_wmi and _prefer_wmi:
+        _getter_class = _WMIGetter
+    else:
+        _getter_class = _RegistryGetter
+
+    def get_dns_info():
+        """Extract resolver configuration."""
+        getter = _getter_class()
+        return getter.get()
+
index eb8389328a92b5b07f4d3acd9f7e6f738fa03c5d..75822d9e68fbbf2038167006c2e12623ab535dde 100644 (file)
@@ -886,15 +886,13 @@ class ResolverMiscTestCase(unittest.TestCase):
         # not raising is the test
         res._compute_timeout(now + 0.5)
 
-    def test_configure_win32_domain(self):
-        # This is a win32-related test but it works on all platforms so we
-        # test it that way to make coverage analysis easier.
-        n = dns.name.from_text('home.')
-        res = dns.resolver.Resolver(configure=False)
-        res._config_win32_domain('home')
-        self.assertEqual(res.domain, n)
-        res._config_win32_domain('.home')
-        self.assertEqual(res.domain, n)
+    if sys.platform == 'win32':
+        def test_configure_win32_domain(self):
+            # This is a win32-related test but it works on all platforms so we
+            # test it that way to make coverage analysis easier.
+            n = dns.name.from_text('home.')
+            self.assertEqual(n, dns.win32util._config_domain('home'))
+            self.assertEqual(n, dns.win32util._config_domain('.home'))
 
 
 class ResolverNameserverValidTypeTestCase(unittest.TestCase):