]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
finish type constructor type checking
authorBob Halley <halley@dnspython.org>
Sat, 22 Aug 2020 17:42:12 +0000 (10:42 -0700)
committerBob Halley <halley@dnspython.org>
Sat, 22 Aug 2020 17:42:12 +0000 (10:42 -0700)
16 files changed:
dns/rdata.py
dns/rdtypes/ANY/AMTRELAY.py
dns/rdtypes/ANY/CAA.py
dns/rdtypes/ANY/CERT.py
dns/rdtypes/ANY/GPOS.py
dns/rdtypes/ANY/HIP.py
dns/rdtypes/ANY/LOC.py
dns/rdtypes/ANY/NSEC3PARAM.py
dns/rdtypes/ANY/OPT.py
dns/rdtypes/ANY/RRSIG.py
dns/rdtypes/ANY/SSHFP.py
dns/rdtypes/ANY/TKEY.py
dns/rdtypes/ANY/TLSA.py
dns/rdtypes/ANY/TSIG.py
dns/rdtypes/ANY/URI.py
dns/rdtypes/txtbase.py

index ee26ceba787437b49cd380e3e9fce125401bc677..042623daae57c72dee5256d2b6d44ca7d7f690c7 100644 (file)
@@ -335,11 +335,6 @@ class Rdata:
             object.__setattr__(rd, 'rdcomment', rdcomment)
         return rd
 
-    def as_value(self, value):
-        # This is the "additional type checking" placeholder that actually
-        # doesn't do any additional checking.
-        return value
-
     # Type checking and conversion helpers.  These are class methods as
     # they don't touch object state and may be useful to others.
 
@@ -396,6 +391,14 @@ class Rdata:
             raise ValueError('not a uint32')
         return value
 
+    @classmethod
+    def _as_uint48(cls, value):
+        if not isinstance(value, int):
+            raise ValueError('not an integer')
+        if value < 0 or value > 281474976710655:
+            raise ValueError('not a uint48')
+        return value
+
     @classmethod
     def _as_int(cls, value, low=None, high=None):
         if not isinstance(value, int):
index de6e99eb3c27c9c769d74438d5003c971694f491..5a7eb9145141689afc4c4698b508c52b6be8bea9 100644 (file)
@@ -38,10 +38,10 @@ class AMTRELAY(dns.rdata.Rdata):
                  relay_type, relay):
         super().__init__(rdclass, rdtype)
         Relay(relay_type, relay).check()
-        self.precedence = self.as_value(precedence)
-        self.discovery_optional = self.as_value(discovery_optional)
-        self.relay_type = self.as_value(relay_type)
-        self.relay = self.as_value(relay)
+        self.precedence = self._as_uint8(precedence)
+        self.discovery_optional = self._as_bool(discovery_optional)
+        self.relay_type = self._as_uint8(relay_type)
+        self.relay = relay
 
     def to_text(self, origin=None, relativize=True, **kw):
         relay = Relay(self.relay_type, self.relay).to_text(origin, relativize)
index 7c6dd0194eec7456361ef1aac86e35f1fd92750e..c86b45ea278b0e7ed91d01b8d2a57996f4220497 100644 (file)
@@ -34,9 +34,11 @@ class CAA(dns.rdata.Rdata):
 
     def __init__(self, rdclass, rdtype, flags, tag, value):
         super().__init__(rdclass, rdtype)
-        self.flags = self.as_value(flags)
-        self.tag = self.as_value(tag)
-        self.value = self.as_value(value)
+        self.flags = self._as_uint8(flags)
+        self.tag = self._as_bytes(tag, True, 255)
+        if not tag.isalnum():
+            raise ValueError("tag is not alphanumeric")
+        self.value = self._as_bytes(value)
 
     def to_text(self, origin=None, relativize=True, **kw):
         return '%u %s "%s"' % (self.flags,
@@ -48,10 +50,6 @@ class CAA(dns.rdata.Rdata):
                   relativize_to=None):
         flags = tok.get_uint8()
         tag = tok.get_string().encode()
-        if len(tag) > 255:
-            raise dns.exception.SyntaxError("tag too long")
-        if not tag.isalnum():
-            raise dns.exception.SyntaxError("tag is not alphanumeric")
         value = tok.get_string().encode()
         return cls(rdclass, rdtype, flags, tag, value)
 
index c78322a74294146f99da9312b6c90eebf48fcc98..6d663cc9ea0d9560ca315a294f50c1000c4e3403 100644 (file)
@@ -67,10 +67,10 @@ class CERT(dns.rdata.Rdata):
     def __init__(self, rdclass, rdtype, certificate_type, key_tag, algorithm,
                  certificate):
         super().__init__(rdclass, rdtype)
-        self.certificate_type = self.as_value(certificate_type)
-        self.key_tag = self.as_value(key_tag)
-        self.algorithm = self.as_value(algorithm)
-        self.certificate = self.as_value(certificate)
+        self.certificate_type = self._as_uint16(certificate_type)
+        self.key_tag = self._as_uint16(key_tag)
+        self.algorithm = self._as_uint8(algorithm)
+        self.certificate = self._as_bytes(certificate)
 
     def to_text(self, origin=None, relativize=True, **kw):
         certificate_type = _ctype_to_text(self.certificate_type)
index f9e3ed8defc52d6edb8a7c359404beaf7925ffa4..29fa8f8b031706b1695f866f23a561ba7f241533 100644 (file)
@@ -42,12 +42,6 @@ def _validate_float_string(what):
         raise dns.exception.FormError
 
 
-def _sanitize(value):
-    if isinstance(value, str):
-        return value.encode()
-    return value
-
-
 @dns.immutable.immutable
 class GPOS(dns.rdata.Rdata):
 
@@ -68,15 +62,15 @@ class GPOS(dns.rdata.Rdata):
         if isinstance(altitude, float) or \
            isinstance(altitude, int):
             altitude = str(altitude)
-        latitude = _sanitize(latitude)
-        longitude = _sanitize(longitude)
-        altitude = _sanitize(altitude)
+        latitude = self._as_bytes(latitude, True, 255)
+        longitude = self._as_bytes(longitude, True, 255)
+        altitude = self._as_bytes(altitude, True, 255)
         _validate_float_string(latitude)
         _validate_float_string(longitude)
         _validate_float_string(altitude)
-        self.latitude = self.as_value(latitude)
-        self.longitude = self.as_value(longitude)
-        self.altitude = self.as_value(altitude)
+        self.latitude = latitude
+        self.longitude = longitude
+        self.altitude = altitude
         flat = self.float_latitude
         if flat < -90.0 or flat > 90.0:
             raise dns.exception.FormError('bad latitude')
index 4ed350737fd82ae68b500e39c3c2d1d3f441191e..610260d4024e5682a91f58d0d67ab77346ee8a19 100644 (file)
@@ -36,10 +36,11 @@ class HIP(dns.rdata.Rdata):
 
     def __init__(self, rdclass, rdtype, hit, algorithm, key, servers):
         super().__init__(rdclass, rdtype)
-        self.hit = self.as_value(hit)
-        self.algorithm = self.as_value(algorithm)
-        self.key = self.as_value(key)
-        self.servers = self.as_value(dns.rdata._constify(servers))
+        self.hit = self._as_bytes(hit, True, 255)
+        self.algorithm = self._as_uint8(algorithm)
+        self.key = self._as_bytes(key, True)
+        self.servers = dns.rdata._constify([dns.rdata.Rdata._as_name(s)
+                                            for s in servers])
 
     def to_text(self, origin=None, relativize=True, **kw):
         hit = binascii.hexlify(self.hit).decode()
@@ -57,8 +58,6 @@ class HIP(dns.rdata.Rdata):
                   relativize_to=None):
         algorithm = tok.get_uint8()
         hit = binascii.unhexlify(tok.get_string().encode())
-        if len(hit) > 255:
-            raise dns.exception.SyntaxError("HIT too long")
         key = base64.b64decode(tok.get_string().encode())
         servers = []
         for token in tok.get_remaining():
index d2a7783c4c62465e482c513173e3f5c185d6589b..60b10b917f9320e73bf46dd1c50f44e9a9eccf95 100644 (file)
@@ -91,6 +91,19 @@ def _decode_size(what, desc):
     return base * pow(10, exponent)
 
 
+def _check_coordinate_list(value, low, high):
+    if value[0] < low or value[0] > high:
+        raise ValueError(f'not in range [{low}, {high}]')
+    if value[1] < 0 or value[1] > 59:
+        raise ValueError('bad minutes value')
+    if value[2] < 0 or value[2] > 59:
+        raise ValueError('bad seconds value')
+    if value[3] < 0 or value[3] > 999:
+        raise ValueError('bad milliseconds value')
+    if value[4] != 1 and value[4] != -1:
+        raise ValueError('bad hemisphere value')
+
+
 @dns.immutable.immutable
 class LOC(dns.rdata.Rdata):
 
@@ -117,16 +130,18 @@ class LOC(dns.rdata.Rdata):
             latitude = float(latitude)
         if isinstance(latitude, float):
             latitude = _float_to_tuple(latitude)
-        self.latitude = self.as_value(dns.rdata._constify(latitude))
+        _check_coordinate_list(latitude, -90, 90)
+        self.latitude = dns.rdata._constify(latitude)
         if isinstance(longitude, int):
             longitude = float(longitude)
         if isinstance(longitude, float):
             longitude = _float_to_tuple(longitude)
-        self.longitude = self.as_value(dns.rdata._constify(longitude))
-        self.altitude = self.as_value(float(altitude))
-        self.size = self.as_value(float(size))
-        self.horizontal_precision = self.as_value(float(hprec))
-        self.vertical_precision = self.as_value(float(vprec))
+        _check_coordinate_list(longitude, -180, 180)
+        self.longitude = dns.rdata._constify(longitude)
+        self.altitude = float(altitude)
+        self.size = float(size)
+        self.horizontal_precision = float(hprec)
+        self.vertical_precision = float(vprec)
 
     def to_text(self, origin=None, relativize=True, **kw):
         if self.latitude[4] > 0:
@@ -165,13 +180,9 @@ class LOC(dns.rdata.Rdata):
         vprec = _default_vprec
 
         latitude[0] = tok.get_int()
-        if latitude[0] > 90:
-            raise dns.exception.SyntaxError('latitude >= 90')
         t = tok.get_string()
         if t.isdigit():
             latitude[1] = int(t)
-            if latitude[1] >= 60:
-                raise dns.exception.SyntaxError('latitude minutes >= 60')
             t = tok.get_string()
             if '.' in t:
                 (seconds, milliseconds) = t.split('.')
@@ -179,8 +190,6 @@ class LOC(dns.rdata.Rdata):
                     raise dns.exception.SyntaxError(
                         'bad latitude seconds value')
                 latitude[2] = int(seconds)
-                if latitude[2] >= 60:
-                    raise dns.exception.SyntaxError('latitude seconds >= 60')
                 l = len(milliseconds)
                 if l == 0 or l > 3 or not milliseconds.isdigit():
                     raise dns.exception.SyntaxError(
@@ -202,13 +211,9 @@ class LOC(dns.rdata.Rdata):
             raise dns.exception.SyntaxError('bad latitude hemisphere value')
 
         longitude[0] = tok.get_int()
-        if longitude[0] > 180:
-            raise dns.exception.SyntaxError('longitude > 180')
         t = tok.get_string()
         if t.isdigit():
             longitude[1] = int(t)
-            if longitude[1] >= 60:
-                raise dns.exception.SyntaxError('longitude minutes >= 60')
             t = tok.get_string()
             if '.' in t:
                 (seconds, milliseconds) = t.split('.')
@@ -216,8 +221,6 @@ class LOC(dns.rdata.Rdata):
                     raise dns.exception.SyntaxError(
                         'bad longitude seconds value')
                 longitude[2] = int(seconds)
-                if longitude[2] >= 60:
-                    raise dns.exception.SyntaxError('longitude seconds >= 60')
                 l = len(milliseconds)
                 if l == 0 or l > 3 or not milliseconds.isdigit():
                     raise dns.exception.SyntaxError(
index d31116fc4e900302ae322fed07524baec400bd3e..299bf6ed169fa55fcf2048ed9efcac121a143a3f 100644 (file)
@@ -32,13 +32,10 @@ class NSEC3PARAM(dns.rdata.Rdata):
 
     def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt):
         super().__init__(rdclass, rdtype)
-        self.algorithm = self.as_value(algorithm)
-        self.flags = self.as_value(flags)
-        self.iterations = self.as_value(iterations)
-        if isinstance(salt, str):
-            self.salt = self.as_value(salt.encode())
-        else:
-            self.salt = self.as_value(salt)
+        self.algorithm = self._as_uint8(algorithm)
+        self.flags = self._as_uint8(flags)
+        self.iterations = self._as_uint16(iterations)
+        self.salt = self._as_bytes(salt, True, 255)
 
     def to_text(self, origin=None, relativize=True, **kw):
         if self.salt == b'':
index d962689405c6197e3585802bd526ef3e9c52cb93..1968ce26264f08cf6ddabb12c3a9e3796d3aa29a 100644 (file)
@@ -45,7 +45,10 @@ class OPT(dns.rdata.Rdata):
         """
 
         super().__init__(rdclass, rdtype)
-        self.options = self.as_value(dns.rdata._constify(options))
+        for option in options:
+            if not isinstance(option, dns.edns.Option):
+                raise ValueError('option is not a dns.edns.option')
+        self.options = dns.rdata._constify(options)
 
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
         for opt in self.options:
index 53cc55a6069663e8bffec3e6e0c2484eda12959c..93c7f105072ef369233456f620725e984241646e 100644 (file)
@@ -64,15 +64,15 @@ class RRSIG(dns.rdata.Rdata):
                  original_ttl, expiration, inception, key_tag, signer,
                  signature):
         super().__init__(rdclass, rdtype)
-        self.type_covered = self.as_value(type_covered)
-        self.algorithm = self.as_value(algorithm)
-        self.labels = self.as_value(labels)
-        self.original_ttl = self.as_value(original_ttl)
-        self.expiration = self.as_value(expiration)
-        self.inception = self.as_value(inception)
-        self.key_tag = self.as_value(key_tag)
-        self.signer = self.as_value(signer)
-        self.signature = self.as_value(signature)
+        self.type_covered = self._as_rdatatype(type_covered)
+        self.algorithm = dns.dnssec.Algorithm.make(algorithm)
+        self.labels = self._as_uint8(labels)
+        self.original_ttl = self._as_ttl(original_ttl)
+        self.expiration = self._as_uint32(expiration)
+        self.inception = self._as_uint32(inception)
+        self.key_tag = self._as_uint16(key_tag)
+        self.signer = self._as_name(signer)
+        self.signature = self._as_bytes(signature)
 
     def covers(self):
         return self.type_covered
index dd222b4f50a405443263dda53e24a4f69b3fa1d4..4fd917c39be4ffd4c92bb4e8f70ebfed09fdcdca 100644 (file)
@@ -35,9 +35,9 @@ class SSHFP(dns.rdata.Rdata):
     def __init__(self, rdclass, rdtype, algorithm, fp_type,
                  fingerprint):
         super().__init__(rdclass, rdtype)
-        self.algorithm = self.as_value(algorithm)
-        self.fp_type = self.as_value(fp_type)
-        self.fingerprint = self.as_value(fingerprint)
+        self.algorithm = self._as_uint8(algorithm)
+        self.fp_type = self._as_uint8(fp_type)
+        self.fingerprint = self._as_bytes(fingerprint, True)
 
     def to_text(self, origin=None, relativize=True, **kw):
         return '%d %d %s' % (self.algorithm,
index 871578a2ee3b8f92d43689ba3e1fb56a7380d36e..f8c47372338f5497df8d5f368b78fca4ed194b1b 100644 (file)
@@ -35,13 +35,13 @@ class TKEY(dns.rdata.Rdata):
     def __init__(self, rdclass, rdtype, algorithm, inception, expiration,
                  mode, error, key, other=b''):
         super().__init__(rdclass, rdtype)
-        self.algorithm = self.as_value(algorithm)
-        self.inception = self.as_value(inception)
-        self.expiration = self.as_value(expiration)
-        self.mode = self.as_value(mode)
-        self.error = self.as_value(error)
-        self.key = self.as_value(dns.rdata._constify(key))
-        self.other = self.as_value(dns.rdata._constify(other))
+        self.algorithm = self._as_name(algorithm)
+        self.inception = self._as_uint32(inception)
+        self.expiration = self._as_uint32(expiration)
+        self.mode = self._as_uint16(mode)
+        self.error = self._as_uint16(error)
+        self.key = self._as_bytes(key)
+        self.other = self._as_bytes(other)
 
     def to_text(self, origin=None, relativize=True, **kw):
         _algorithm = self.algorithm.choose_relativity(origin, relativize)
index 5e7dc1917510c0df6f43f30a23035352ed3c77fb..ad8dc8d9ff4927a1f237f6e2fbd3862b13eb7a10 100644 (file)
@@ -35,10 +35,10 @@ class TLSA(dns.rdata.Rdata):
     def __init__(self, rdclass, rdtype, usage, selector,
                  mtype, cert):
         super().__init__(rdclass, rdtype)
-        self.usage = self.as_value(usage)
-        self.selector = self.as_value(selector)
-        self.mtype = self.as_value(mtype)
-        self.cert = self.as_value(cert)
+        self.usage = self._as_uint8(usage)
+        self.selector = self._as_uint8(selector)
+        self.mtype = self._as_uint8(mtype)
+        self.cert = self._as_bytes(cert)
 
     def to_text(self, origin=None, relativize=True, **kw):
         return '%d %d %d %s' % (self.usage,
index e179d620a15b2480c64e037787b8d8a262b576d6..e49bf73a7086da7a5228b405638a7e85f423ece4 100644 (file)
@@ -20,6 +20,7 @@ import struct
 
 import dns.exception
 import dns.immutable
+import dns.rcode
 import dns.rdata
 
 
@@ -55,13 +56,13 @@ class TSIG(dns.rdata.Rdata):
         """
 
         super().__init__(rdclass, rdtype)
-        self.algorithm = self.as_value(algorithm)
-        self.time_signed = self.as_value(time_signed)
-        self.fudge = self.as_value(fudge)
-        self.mac = self.as_value(dns.rdata._constify(mac))
-        self.original_id = self.as_value(original_id)
-        self.error = self.as_value(error)
-        self.other = self.as_value(dns.rdata._constify(other))
+        self.algorithm = self._as_name(algorithm)
+        self.time_signed = self._as_uint48(time_signed)
+        self.fudge = self._as_uint16(fudge)
+        self.mac = dns.rdata._constify(self._as_bytes(mac))
+        self.original_id = self._as_uint16(original_id)
+        self.error = dns.rcode.Rcode.make(error)
+        self.other = self._as_bytes(other)
 
     def to_text(self, origin=None, relativize=True, **kw):
         algorithm = self.algorithm.choose_relativity(origin, relativize)
index 0892bd8b298f177549faeb0bc7081c998f15ddcd..60a43c88b05326878ba39cd62e8b7b2fbb83cb08 100644 (file)
@@ -35,14 +35,11 @@ class URI(dns.rdata.Rdata):
 
     def __init__(self, rdclass, rdtype, priority, weight, target):
         super().__init__(rdclass, rdtype)
-        self.priority = self.as_value(priority)
-        self.weight = self.as_value(weight)
-        if len(target) < 1:
+        self.priority = self._as_uint16(priority)
+        self.weight = self._as_uint16(weight)
+        self.target = self._as_bytes(target, True)
+        if len(self.target) == 0:
             raise dns.exception.SyntaxError("URI target cannot be empty")
-        if isinstance(target, str):
-            self.target = self.as_value(target.encode())
-        else:
-            self.target = self.as_value(target)
 
     def to_text(self, origin=None, relativize=True, **kw):
         return '%d %d "%s"' % (self.priority, self.weight,
index 6539c5a4effc558b02266ee6ab77d484efd9041a..a170ced36551b683c914c68006dd9374ddf39611 100644 (file)
@@ -46,12 +46,9 @@ class TXTBase(dns.rdata.Rdata):
             strings = (strings,)
         encoded_strings = []
         for string in strings:
-            if isinstance(string, str):
-                string = string.encode()
-            else:
-                string = dns.rdata._constify(string)
+            string = self._as_bytes(string, True, 255)
             encoded_strings.append(string)
-        self.strings = self.as_value(tuple(encoded_strings))
+        self.strings = dns.rdata._constify(encoded_strings)
 
     def to_text(self, origin=None, relativize=True, **kw):
         txt = ''