From 903ecff93df2d5836d5051a6ee1ee317c7c459d6 Mon Sep 17 00:00:00 2001 From: "Joshua M. Keyes" Date: Fri, 21 Jan 2022 18:29:32 -0800 Subject: [PATCH] Implement dns.set.Set.symmetric_difference() Support. --- dns/rdataset.py | 3 +++ dns/set.py | 31 +++++++++++++++++++++++++++++++ tests/test_set.py | 24 ++++++++++++++++++++++++ 3 files changed, 58 insertions(+) diff --git a/dns/rdataset.py b/dns/rdataset.py index e69ee232..08f9bf5b 100644 --- a/dns/rdataset.py +++ b/dns/rdataset.py @@ -386,6 +386,9 @@ class ImmutableRdataset(Rdataset): def difference(self, other): return ImmutableRdataset(super().difference(other)) + def symmetric_difference(self, other): + return ImmutableRdataset(super().symmetric_difference(other)) + def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None, origin=None, relativize=True, relativize_to=None): diff --git a/dns/set.py b/dns/set.py index 1fd4d0ae..0ba15dfa 100644 --- a/dns/set.py +++ b/dns/set.py @@ -145,6 +145,18 @@ class Set: for item in other.items: self.discard(item) + def symmetric_difference_update(self, other): + """Update the set, retaining only elements unique to both sets.""" + + if not isinstance(other, Set): + raise ValueError('other must be a Set instance') + if self is other: + self.items.clear() + else: + overlap = self.intersection(other) + self.union_update(other) + self.difference_update(overlap) + def union(self, other): """Return a new set which is the union of ``self`` and ``other``. @@ -177,6 +189,18 @@ class Set: obj.difference_update(other) return obj + def symmetric_difference(self, other): + """Return a new set which (``self`` - ``other``) | (``other`` + - ``self), ie: the items in either ``self`` or ``other`` which + are not contained in their intersection. + + Returns the same Set type as this set. + """ + + obj = self._clone() + obj.symmetric_difference_update(other) + return obj + def __or__(self, other): return self.union(other) @@ -189,6 +213,9 @@ class Set: def __sub__(self, other): return self.difference(other) + def __xor__(self, other): + return self.symmetric_difference(other) + def __ior__(self, other): self.union_update(other) return self @@ -205,6 +232,10 @@ class Set: self.difference_update(other) return self + def __ixor__(self, other): + self.symmetric_difference_update(other) + return self + def update(self, other): """Update the set, adding any elements from other which are not already in the set. diff --git a/tests/test_set.py b/tests/test_set.py index 8019d577..6fd0860b 100644 --- a/tests/test_set.py +++ b/tests/test_set.py @@ -109,6 +109,30 @@ class SetTestCase(unittest.TestCase): e = S([]) self.assertEqual(s1 - s2, e) + def testSymmetricDifference1(self): + s1 = S([1, 2, 3]) + s2 = S([5, 4]) + e = S([1, 2, 3, 4, 5]) + self.assertEqual(s1 ^ s2, e) + + def testSymmetricDifference2(self): + s1 = S([1, 2, 3]) + s2 = S([]) + e = S([1, 2, 3]) + self.assertEqual(s1 ^ s2, e) + + def testSymmetricDifference3(self): + s1 = S([1, 2, 3]) + s2 = S([3, 2]) + e = S([1]) + self.assertEqual(s1 ^ s2, e) + + def testSymmetricDifference4(self): + s1 = S([1, 2, 3]) + s2 = S([3, 2, 1]) + e = S([]) + self.assertEqual(s1 ^ s2, e) + def testSubset1(self): s1 = S([1, 2, 3]) s2 = S([3, 2, 1]) -- 2.47.3