]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Implement dns.set.Set.symmetric_difference() Support. 765/head
authorJoshua M. Keyes <joshua.michael.keyes@gmail.com>
Sat, 22 Jan 2022 02:29:32 +0000 (18:29 -0800)
committerJoshua M. Keyes <joshua.michael.keyes@gmail.com>
Sat, 22 Jan 2022 03:32:33 +0000 (19:32 -0800)
dns/rdataset.py
dns/set.py
tests/test_set.py

index e69ee2325aee6a74675f02bede9d28ca737b0812..08f9bf5b8049464e6f29525a0f9252741442e47b 100644 (file)
@@ -386,6 +386,9 @@ class ImmutableRdataset(Rdataset):
     def difference(self, other):
         return ImmutableRdataset(super().difference(other))
 
+    def symmetric_difference(self, other):
+        return ImmutableRdataset(super().symmetric_difference(other))
+
 
 def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None,
                    origin=None, relativize=True, relativize_to=None):
index 1fd4d0ae79919716e17a3225025a7255aa1ea733..0ba15dfa6124553eda8b6b6ee9930f0f88cca4ca 100644 (file)
@@ -145,6 +145,18 @@ class Set:
             for item in other.items:
                 self.discard(item)
 
+    def symmetric_difference_update(self, other):
+        """Update the set, retaining only elements unique to both sets."""
+
+        if not isinstance(other, Set):
+            raise ValueError('other must be a Set instance')
+        if self is other:
+            self.items.clear()
+        else:
+            overlap = self.intersection(other)
+            self.union_update(other)
+            self.difference_update(overlap)
+
     def union(self, other):
         """Return a new set which is the union of ``self`` and ``other``.
 
@@ -177,6 +189,18 @@ class Set:
         obj.difference_update(other)
         return obj
 
+    def symmetric_difference(self, other):
+        """Return a new set which (``self`` - ``other``) | (``other``
+        - ``self), ie: the items in either ``self`` or ``other`` which
+        are not contained in their intersection.
+
+        Returns the same Set type as this set.
+        """
+
+        obj = self._clone()
+        obj.symmetric_difference_update(other)
+        return obj
+
     def __or__(self, other):
         return self.union(other)
 
@@ -189,6 +213,9 @@ class Set:
     def __sub__(self, other):
         return self.difference(other)
 
+    def __xor__(self, other):
+        return self.symmetric_difference(other)
+
     def __ior__(self, other):
         self.union_update(other)
         return self
@@ -205,6 +232,10 @@ class Set:
         self.difference_update(other)
         return self
 
+    def __ixor__(self, other):
+        self.symmetric_difference_update(other)
+        return self
+
     def update(self, other):
         """Update the set, adding any elements from other which are not
         already in the set.
index 8019d577b326282f069d5cfe0cbe0c59c11f4e27..6fd0860b346388d9b8b60500815bb93a61799941 100644 (file)
@@ -109,6 +109,30 @@ class SetTestCase(unittest.TestCase):
         e = S([])
         self.assertEqual(s1 - s2, e)
 
+    def testSymmetricDifference1(self):
+        s1 = S([1, 2, 3])
+        s2 = S([5, 4])
+        e = S([1, 2, 3, 4, 5])
+        self.assertEqual(s1 ^ s2, e)
+
+    def testSymmetricDifference2(self):
+        s1 = S([1, 2, 3])
+        s2 = S([])
+        e = S([1, 2, 3])
+        self.assertEqual(s1 ^ s2, e)
+
+    def testSymmetricDifference3(self):
+        s1 = S([1, 2, 3])
+        s2 = S([3, 2])
+        e = S([1])
+        self.assertEqual(s1 ^ s2, e)
+
+    def testSymmetricDifference4(self):
+        s1 = S([1, 2, 3])
+        s2 = S([3, 2, 1])
+        e = S([])
+        self.assertEqual(s1 ^ s2, e)
+
     def testSubset1(self):
         s1 = S([1, 2, 3])
         s2 = S([3, 2, 1])