]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
use classmethod for Gateway factories
authorBob Halley <halley@dnspython.org>
Wed, 26 Aug 2020 13:49:56 +0000 (06:49 -0700)
committerBob Halley <halley@dnspython.org>
Wed, 26 Aug 2020 13:49:56 +0000 (06:49 -0700)
dns/rdtypes/ANY/AMTRELAY.py
dns/rdtypes/IN/IPSECKEY.py
dns/rdtypes/util.py
tests/test_rdata.py

index 5a7eb9145141689afc4c4698b508c52b6be8bea9..9f093deedfb3faf6596ef42c790462ce41ce30e2 100644 (file)
@@ -25,6 +25,11 @@ import dns.rdtypes.util
 class Relay(dns.rdtypes.util.Gateway):
     name = 'AMTRELAY relay'
 
+    @property
+    def relay(self):
+        return self.gateway
+
+
 @dns.immutable.immutable
 class AMTRELAY(dns.rdata.Rdata):
 
@@ -37,11 +42,11 @@ class AMTRELAY(dns.rdata.Rdata):
     def __init__(self, rdclass, rdtype, precedence, discovery_optional,
                  relay_type, relay):
         super().__init__(rdclass, rdtype)
-        Relay(relay_type, relay).check()
+        relay = Relay(relay_type, relay)
         self.precedence = self._as_uint8(precedence)
         self.discovery_optional = self._as_bool(discovery_optional)
-        self.relay_type = self._as_uint8(relay_type)
-        self.relay = relay
+        self.relay_type = relay.type
+        self.relay = relay.relay
 
     def to_text(self, origin=None, relativize=True, **kw):
         relay = Relay(self.relay_type, self.relay).to_text(origin, relativize)
@@ -59,10 +64,10 @@ class AMTRELAY(dns.rdata.Rdata):
         relay_type = tok.get_uint8()
         if relay_type > 0x7f:
             raise dns.exception.SyntaxError('expecting an integer <= 127')
-        relay = Relay(relay_type).from_text(tok, origin, relativize,
-                                            relativize_to)
+        relay = Relay.from_text(relay_type, tok, origin, relativize,
+                                relativize_to)
         return cls(rdclass, rdtype, precedence, discovery_optional, relay_type,
-                   relay)
+                   relay.relay)
 
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
         relay_type = self.relay_type | (self.discovery_optional << 7)
@@ -76,6 +81,6 @@ class AMTRELAY(dns.rdata.Rdata):
         (precedence, relay_type) = parser.get_struct('!BB')
         discovery_optional = bool(relay_type >> 7)
         relay_type &= 0x7f
-        relay = Relay(relay_type).from_wire_parser(parser, origin)
+        relay = Relay.from_wire_parser(relay_type, parser, origin)
         return cls(rdclass, rdtype, precedence, discovery_optional, relay_type,
-                   relay)
+                   relay.relay)
index b2dc2bee38014ff6aeb9ad7b6f0b5719f973fa28..ce899577174000aa9e0c3da85f809147c9cb748d 100644 (file)
@@ -38,11 +38,11 @@ class IPSECKEY(dns.rdata.Rdata):
     def __init__(self, rdclass, rdtype, precedence, gateway_type, algorithm,
                  gateway, key):
         super().__init__(rdclass, rdtype)
-        Gateway(gateway_type, gateway).check()
+        gateway = Gateway(gateway_type, gateway)
         self.precedence = self._as_uint8(precedence)
-        self.gateway_type = self._as_uint8(gateway_type)
+        self.gateway_type = gateway.type
         self.algorithm = self._as_uint8(algorithm)
-        self.gateway = gateway
+        self.gateway = gateway.gateway
         self.key = self._as_bytes(key)
 
     def to_text(self, origin=None, relativize=True, **kw):
@@ -58,12 +58,12 @@ class IPSECKEY(dns.rdata.Rdata):
         precedence = tok.get_uint8()
         gateway_type = tok.get_uint8()
         algorithm = tok.get_uint8()
-        gateway = Gateway(gateway_type).from_text(tok, origin, relativize,
-                                                  relativize_to)
+        gateway = Gateway.from_text(gateway_type, tok, origin, relativize,
+                                    relativize_to)
         b64 = tok.concatenate_remaining_identifiers().encode()
         key = base64.b64decode(b64)
         return cls(rdclass, rdtype, precedence, gateway_type, algorithm,
-                   gateway, key)
+                   gateway.gateway, key)
 
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
         header = struct.pack("!BBB", self.precedence, self.gateway_type,
@@ -77,7 +77,7 @@ class IPSECKEY(dns.rdata.Rdata):
     def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
         header = parser.get_struct('!BBB')
         gateway_type = header[1]
-        gateway = Gateway(gateway_type).from_wire_parser(parser, origin)
+        gateway = Gateway.from_wire_parser(gateway_type, parser, origin)
         key = parser.get_remaining()
         return cls(rdclass, rdtype, header[0], gateway_type, header[2],
-                   gateway, key)
+                   gateway.gateway, key)
index aed67d76340f8279ea0afc5e1d7726b6fd3b9227..30be37dcc01d905a8de879494a3807e32647203f 100644 (file)
 import struct
 
 import dns.exception
-import dns.name
 import dns.ipv4
 import dns.ipv6
+import dns.name
+import dns.rdata
+
 
 class Gateway:
     """A helper class for the IPSECKEY gateway and AMTRELAY relay fields"""
     name = ""
 
     def __init__(self, type, gateway=None):
-        self.type = type
+        self.type = dns.rdata.Rdata._as_uint8(type)
         self.gateway = gateway
+        self._check()
 
-    def _invalid_type(self):
-        return f"invalid {self.name} type: {self.type}"
+    @classmethod
+    def _invalid_type(cls, gateway_type):
+        return f"invalid {cls.name} type: {gateway_type}"
 
-    def check(self):
+    def _check(self):
         if self.type == 0:
             if self.gateway not in (".", None):
                 raise SyntaxError(f"invalid {self.name} for type 0")
@@ -48,7 +52,7 @@ class Gateway:
             if not isinstance(self.gateway, dns.name.Name):
                 raise SyntaxError(f"invalid {self.name}; not a name")
         else:
-            raise SyntaxError(self._invalid_type())
+            raise SyntaxError(self._invalid_type(self.type))
 
     def to_text(self, origin=None, relativize=True):
         if self.type == 0:
@@ -58,15 +62,19 @@ class Gateway:
         elif self.type == 3:
             return str(self.gateway.choose_relativity(origin, relativize))
         else:
-            raise ValueError(self._invalid_type())
+            raise ValueError(self._invalid_type(self.type))  # pragma: no cover
 
-    def from_text(self, tok, origin=None, relativize=True, relativize_to=None):
-        if self.type in (0, 1, 2):
-            return tok.get_string()
-        elif self.type == 3:
-            return tok.get_name(origin, relativize, relativize_to)
+    @classmethod
+    def from_text(cls, gateway_type, tok, origin=None, relativize=True,
+                  relativize_to=None):
+        if gateway_type in (0, 1, 2):
+            gateway = tok.get_string()
+        elif gateway_type == 3:
+            gateway = tok.get_name(origin, relativize, relativize_to)
         else:
-            raise dns.exception.SyntaxError(self._invalid_type())
+            raise dns.exception.SyntaxError(
+                cls._invalid_type(gateway_type))  # pragma: no cover
+        return cls(gateway_type, gateway)
 
     # pylint: disable=unused-argument
     def to_wire(self, file, compress=None, origin=None, canonicalize=False):
@@ -79,20 +87,23 @@ class Gateway:
         elif self.type == 3:
             self.gateway.to_wire(file, None, origin, False)
         else:
-            raise ValueError(self._invalid_type())
+            raise ValueError(self._invalid_type(self.type))  # pragma: no cover
     # pylint: enable=unused-argument
 
-    def from_wire_parser(self, parser, origin=None):
-        if self.type == 0:
-            return None
-        elif self.type == 1:
-            return dns.ipv4.inet_ntoa(parser.get_bytes(4))
-        elif self.type == 2:
-            return dns.ipv6.inet_ntoa(parser.get_bytes(16))
-        elif self.type == 3:
-            return parser.get_name(origin)
+    @classmethod
+    def from_wire_parser(cls, gateway_type, parser, origin=None):
+        if gateway_type == 0:
+            gateway = None
+        elif gateway_type == 1:
+            gateway = dns.ipv4.inet_ntoa(parser.get_bytes(4))
+        elif gateway_type == 2:
+            gateway = dns.ipv6.inet_ntoa(parser.get_bytes(16))
+        elif gateway_type == 3:
+            gateway = parser.get_name(origin)
         else:
-            raise dns.exception.FormError(self._invalid_type())
+            raise dns.exception.FormError(cls._invalid_type(gateway_type))
+        return cls(gateway_type, gateway)
+
 
 class Bitmap:
     """A helper class for the NSEC/NSEC3/CSYNC type bitmaps"""
index 956bec0d9682b1fd3e4f3b368a953dba91c00813..72ed15a1f8f4dab3ebe8b2fa8902ba9e5670b9c4 100644 (file)
@@ -642,29 +642,18 @@ class RdataTestCase(unittest.TestCase):
 class UtilTestCase(unittest.TestCase):
 
     def test_Gateway_bad_type0(self):
-        g = dns.rdtypes.util.Gateway(0, 'bad.')
         with self.assertRaises(SyntaxError):
-            g.check()
+            dns.rdtypes.util.Gateway(0, 'bad.')
 
     def test_Gateway_bad_type3(self):
-        g = dns.rdtypes.util.Gateway(3, 'bad.')
         with self.assertRaises(SyntaxError):
-            g.check()
+            dns.rdtypes.util.Gateway(3, 'bad.')
 
     def test_Gateway_type4(self):
-        g = dns.rdtypes.util.Gateway(4)
         with self.assertRaises(SyntaxError):
-            g.check()
-        with self.assertRaises(ValueError):
-            g.to_text()
-        with self.assertRaises(dns.exception.SyntaxError):
-            tok = dns.tokenizer.Tokenizer('bogus')
-            g.from_text(tok)
-        with self.assertRaises(ValueError):
-            f = io.BytesIO()
-            g.to_wire(f)
+            dns.rdtypes.util.Gateway(4)
         with self.assertRaises(dns.exception.FormError):
-            g.from_wire_parser(None)
+            dns.rdtypes.util.Gateway.from_wire_parser(4, None)
 
     def test_Bitmap(self):
         b = dns.rdtypes.util.Bitmap