]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
make sure Resolver.nameservers is a list or str
authorkimbo <kimballleavitt@gmail.com>
Thu, 26 Dec 2019 21:54:31 +0000 (14:54 -0700)
committerkimbo <kimballleavitt@gmail.com>
Thu, 26 Dec 2019 21:54:31 +0000 (14:54 -0700)
validate if assignment of Resolver.nameservers is a list, a str (in
which case it will be converted to a list), or None

dns/resolver.py
dns/resolver.pyi
tests/test_resolver.py

index c49598fe00627a102b598580167c970bce18c198..2ed0ebca76c3628a3196fd5739b93c48d26a3ec2 100644 (file)
@@ -1079,6 +1079,27 @@ class Resolver(object):
 
         self.flags = flags
 
+    @property
+    def nameservers(self):
+        return self._nameservers
+
+    @nameservers.setter
+    def nameservers(self, nameservers):
+        """
+        :param nameservers: can be a ``str``, ``list``, or None.
+        If it's a ``str``, it will converted to a list.
+        :raise ValueError: if `nameservers` is anything other than \
+        ``str``, ``list``, or None.
+        """
+        if isinstance(nameservers, str):
+            self._nameservers = [nameservers]
+        elif isinstance(nameservers, list):
+            self._nameservers = nameservers
+        elif nameservers is None:
+            self._nameservers = None
+        else:
+            raise ValueError('nameservers must be either a str, a list, or None'
+                             ' (not a {})'.format(type(nameservers)))
 
 #: The default resolver.
 default_resolver = None
index 06742fe5fa0e890d31f3a06ee2b3bd1fed41fc67..c68d04ae8ed58a47548002e8f606de3bbaac9cd7 100644 (file)
@@ -33,7 +33,7 @@ def zone_for_name(name, rdclass : int = rdataclass.IN, tcp=False, resolver : Opt
     ...
 
 class Resolver:
-    def __init__(self, configure):
+    def __init__(self, filename : Optional[str] = '/etc/resolv.conf', configure : Optional[bool] = True):
         self.nameservers : List[str]
     def query(self, qname : str, rdtype : Union[int,str] = rdatatype.A, rdclass : Union[int,str] = rdataclass.IN,
               tcp : bool = False, source : Optional[str] = None, raise_on_no_answer=True, source_port : int = 0):
index 1f788396876a6fc98ebce1bdbe3f093b3a22a195..ccfb04e96590c478f34af44b7418549123ced1a1 100644 (file)
@@ -404,5 +404,28 @@ class NXDOMAINExceptionTestCase(unittest.TestCase):
         self.assertTrue(e2.canonical_name == dns.name.from_text(cname2))
 
 
+class ResolverNameserverValidTypeTestCase(unittest.TestCase):
+    def test_set_nameserver_to_string(self):
+        resolver = dns.resolver.Resolver()
+        resolver.nameservers = '1.2.3.4'
+        self.assertEqual(resolver.nameservers, ['1.2.3.4'])
+
+    def test_set_nameserver_to_list(self):
+        resolver = dns.resolver.Resolver()
+        resolver.nameservers = ['1.2.3.4']
+        self.assertEqual(resolver.nameservers, ['1.2.3.4'])
+
+    def test_set_nameserver_to_None(self):
+        resolver = dns.resolver.Resolver()
+        resolver.nameservers = None
+        self.assertEqual(resolver.nameservers, None)
+
+    def test_set_nameserver_invalid_type(self):
+        resolver = dns.resolver.Resolver()
+        invalid_nameservers = [1234, (1, 2, 3, 4), {'invalid': 'nameserver'}]
+        for invalid_nameserver in invalid_nameservers:
+            with self.assertRaises(ValueError):
+                resolver.nameservers = invalid_nameserver
+
 if __name__ == '__main__':
     unittest.main()