]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Issue 14814: Ensure ordering semantics across all 3 entity types in ipaddress are...
authorNick Coghlan <ncoghlan@gmail.com>
Sat, 7 Jul 2012 14:45:33 +0000 (00:45 +1000)
committerNick Coghlan <ncoghlan@gmail.com>
Sat, 7 Jul 2012 14:45:33 +0000 (00:45 +1000)
Lib/ipaddress.py
Lib/test/test_ipaddress.py
Misc/NEWS

index b1e07fc992e5edecbcbe37428e04e9047cf6fccd..201900955172f036da020ae467d9a668aff8621e 100644 (file)
@@ -12,7 +12,7 @@ __version__ = '1.0'
 
 
 import struct
-
+import functools
 
 IPV4LENGTH = 32
 IPV6LENGTH = 128
@@ -405,7 +405,38 @@ def get_mixed_type_key(obj):
     return NotImplemented
 
 
-class _IPAddressBase:
+class _TotalOrderingMixin:
+    # Helper that derives the other comparison operations from
+    # __lt__ and __eq__
+    def __eq__(self, other):
+        raise NotImplementedError
+    def __ne__(self, other):
+        equal = self.__eq__(other)
+        if equal is NotImplemented:
+            return NotImplemented
+        return not equal
+    def __lt__(self, other):
+        raise NotImplementedError
+    def __le__(self, other):
+        less = self.__lt__(other)
+        if less is NotImplemented or not less:
+            return self.__eq__(other)
+        return less
+    def __gt__(self, other):
+        less = self.__lt__(other)
+        if less is NotImplemented:
+            return NotImplemented
+        equal = self.__eq__(other)
+        if equal is NotImplemented:
+            return NotImplemented
+        return not (less or equal)
+    def __ge__(self, other):
+        less = self.__lt__(other)
+        if less is NotImplemented:
+            return NotImplemented
+        return not less
+
+class _IPAddressBase(_TotalOrderingMixin):
 
     """The mother class."""
 
@@ -465,7 +496,6 @@ class _IPAddressBase:
             prefixlen = self._prefixlen
         return self._string_from_ip_int(self._ip_int_from_prefix(prefixlen))
 
-
 class _BaseAddress(_IPAddressBase):
 
     """A generic IP object.
@@ -493,24 +523,6 @@ class _BaseAddress(_IPAddressBase):
         except AttributeError:
             return NotImplemented
 
-    def __ne__(self, other):
-        eq = self.__eq__(other)
-        if eq is NotImplemented:
-            return NotImplemented
-        return not eq
-
-    def __le__(self, other):
-        gt = self.__gt__(other)
-        if gt is NotImplemented:
-            return NotImplemented
-        return not gt
-
-    def __ge__(self, other):
-        lt = self.__lt__(other)
-        if lt is NotImplemented:
-            return NotImplemented
-        return not lt
-
     def __lt__(self, other):
         if self._version != other._version:
             raise TypeError('%s and %s are not of the same version' % (
@@ -522,17 +534,6 @@ class _BaseAddress(_IPAddressBase):
             return self._ip < other._ip
         return False
 
-    def __gt__(self, other):
-        if self._version != other._version:
-            raise TypeError('%s and %s are not of the same version' % (
-                             self, other))
-        if not isinstance(other, _BaseAddress):
-            raise TypeError('%s and %s are not of the same type' % (
-                             self, other))
-        if self._ip != other._ip:
-            return self._ip > other._ip
-        return False
-
     # Shorthand for Integer addition and subtraction. This is not
     # meant to ever support addition/subtraction of addresses.
     def __add__(self, other):
@@ -625,31 +626,6 @@ class _BaseNetwork(_IPAddressBase):
             return self.netmask < other.netmask
         return False
 
-    def __gt__(self, other):
-        if self._version != other._version:
-            raise TypeError('%s and %s are not of the same version' % (
-                             self, other))
-        if not isinstance(other, _BaseNetwork):
-            raise TypeError('%s and %s are not of the same type' % (
-                             self, other))
-        if self.network_address != other.network_address:
-            return self.network_address > other.network_address
-        if self.netmask != other.netmask:
-            return self.netmask > other.netmask
-        return False
-
-    def __le__(self, other):
-        gt = self.__gt__(other)
-        if gt is NotImplemented:
-            return NotImplemented
-        return not gt
-
-    def __ge__(self, other):
-        lt = self.__lt__(other)
-        if lt is NotImplemented:
-            return NotImplemented
-        return not lt
-
     def __eq__(self, other):
         try:
             return (self._version == other._version and
@@ -658,12 +634,6 @@ class _BaseNetwork(_IPAddressBase):
         except AttributeError:
             return NotImplemented
 
-    def __ne__(self, other):
-        eq = self.__eq__(other)
-        if eq is NotImplemented:
-            return NotImplemented
-        return not eq
-
     def __hash__(self):
         return hash(int(self.network_address) ^ int(self.netmask))
 
@@ -1292,11 +1262,27 @@ class IPv4Interface(IPv4Address):
                           self.network.prefixlen)
 
     def __eq__(self, other):
+        address_equal = IPv4Address.__eq__(self, other)
+        if not address_equal or address_equal is NotImplemented:
+            return address_equal
         try:
-            return (IPv4Address.__eq__(self, other) and
-                    self.network == other.network)
+            return self.network == other.network
         except AttributeError:
+            # An interface with an associated network is NOT the
+            # same as an unassociated address. That's why the hash
+            # takes the extra info into account.
+            return False
+
+    def __lt__(self, other):
+        address_less = IPv4Address.__lt__(self, other)
+        if address_less is NotImplemented:
             return NotImplemented
+        try:
+            return self.network < other.network
+        except AttributeError:
+            # We *do* allow addresses and interfaces to be sorted. The
+            # unassociated address is considered less than all interfaces.
+            return False
 
     def __hash__(self):
         return self._ip ^ self._prefixlen ^ int(self.network.network_address)
@@ -1928,11 +1914,27 @@ class IPv6Interface(IPv6Address):
                           self.network.prefixlen)
 
     def __eq__(self, other):
+        address_equal = IPv6Address.__eq__(self, other)
+        if not address_equal or address_equal is NotImplemented:
+            return address_equal
         try:
-            return (IPv6Address.__eq__(self, other) and
-                    self.network == other.network)
+            return self.network == other.network
         except AttributeError:
+            # An interface with an associated network is NOT the
+            # same as an unassociated address. That's why the hash
+            # takes the extra info into account.
+            return False
+
+    def __lt__(self, other):
+        address_less = IPv6Address.__lt__(self, other)
+        if address_less is NotImplemented:
             return NotImplemented
+        try:
+            return self.network < other.network
+        except AttributeError:
+            # We *do* allow addresses and interfaces to be sorted. The
+            # unassociated address is considered less than all interfaces.
+            return False
 
     def __hash__(self):
         return self._ip ^ self._prefixlen ^ int(self.network.network_address)
index 417c98677f00d45c7cf9d244b6dcd4fd01347fa4..5aaf73674016c92b6a797d040d4db444f3557ecb 100644 (file)
@@ -415,6 +415,93 @@ class FactoryFunctionErrors(ErrorReporting):
         self.assertFactoryError(ipaddress.ip_network, "network")
 
 
+class ComparisonTests(unittest.TestCase):
+
+    v4addr = ipaddress.IPv4Address(1)
+    v4net = ipaddress.IPv4Network(1)
+    v4intf = ipaddress.IPv4Interface(1)
+    v6addr = ipaddress.IPv6Address(1)
+    v6net = ipaddress.IPv6Network(1)
+    v6intf = ipaddress.IPv6Interface(1)
+
+    v4_addresses = [v4addr, v4intf]
+    v4_objects = v4_addresses + [v4net]
+    v6_addresses = [v6addr, v6intf]
+    v6_objects = v6_addresses + [v6net]
+    objects = v4_objects + v6_objects
+
+    def test_foreign_type_equality(self):
+        # __eq__ should never raise TypeError directly
+        other = object()
+        for obj in self.objects:
+            self.assertNotEqual(obj, other)
+            self.assertFalse(obj == other)
+            self.assertEqual(obj.__eq__(other), NotImplemented)
+            self.assertEqual(obj.__ne__(other), NotImplemented)
+
+    def test_mixed_type_equality(self):
+        # Ensure none of the internal objects accidentally
+        # expose the right set of attributes to become "equal"
+        for lhs in self.objects:
+            for rhs in self.objects:
+                if lhs is rhs:
+                    continue
+                self.assertNotEqual(lhs, rhs)
+
+    def test_containment(self):
+        for obj in self.v4_addresses:
+            self.assertIn(obj, self.v4net)
+        for obj in self.v6_addresses:
+            self.assertIn(obj, self.v6net)
+        for obj in self.v4_objects + [self.v6net]:
+            self.assertNotIn(obj, self.v6net)
+        for obj in self.v6_objects + [self.v4net]:
+            self.assertNotIn(obj, self.v4net)
+
+    def test_mixed_type_ordering(self):
+        for lhs in self.objects:
+            for rhs in self.objects:
+                if isinstance(lhs, type(rhs)) or isinstance(rhs, type(lhs)):
+                    continue
+                self.assertRaises(TypeError, lambda: lhs < rhs)
+                self.assertRaises(TypeError, lambda: lhs > rhs)
+                self.assertRaises(TypeError, lambda: lhs <= rhs)
+                self.assertRaises(TypeError, lambda: lhs >= rhs)
+
+    def test_mixed_type_key(self):
+        # with get_mixed_type_key, you can sort addresses and network.
+        v4_ordered = [self.v4addr, self.v4net, self.v4intf]
+        v6_ordered = [self.v6addr, self.v6net, self.v6intf]
+        self.assertEqual(v4_ordered,
+                         sorted(self.v4_objects,
+                                key=ipaddress.get_mixed_type_key))
+        self.assertEqual(v6_ordered,
+                         sorted(self.v6_objects,
+                                key=ipaddress.get_mixed_type_key))
+        self.assertEqual(v4_ordered + v6_ordered,
+                         sorted(self.objects,
+                                key=ipaddress.get_mixed_type_key))
+        self.assertEqual(NotImplemented, ipaddress.get_mixed_type_key(object))
+
+    def test_incompatible_versions(self):
+        # These should always raise TypeError
+        v4addr = ipaddress.ip_address('1.1.1.1')
+        v4net = ipaddress.ip_network('1.1.1.1')
+        v6addr = ipaddress.ip_address('::1')
+        v6net = ipaddress.ip_address('::1')
+
+        self.assertRaises(TypeError, v4addr.__lt__, v6addr)
+        self.assertRaises(TypeError, v4addr.__gt__, v6addr)
+        self.assertRaises(TypeError, v4net.__lt__, v6net)
+        self.assertRaises(TypeError, v4net.__gt__, v6net)
+
+        self.assertRaises(TypeError, v6addr.__lt__, v4addr)
+        self.assertRaises(TypeError, v6addr.__gt__, v4addr)
+        self.assertRaises(TypeError, v6net.__lt__, v4net)
+        self.assertRaises(TypeError, v6net.__gt__, v4net)
+
+
+
 class IpaddrUnitTest(unittest.TestCase):
 
     def setUp(self):
@@ -495,67 +582,6 @@ class IpaddrUnitTest(unittest.TestCase):
         self.assertEqual(str(self.ipv6_network.hostmask),
                          '::ffff:ffff:ffff:ffff')
 
-    def testEqualityChecks(self):
-        # __eq__ should never raise TypeError directly
-        other = object()
-        def assertEqualityNotImplemented(instance):
-            self.assertEqual(instance.__eq__(other), NotImplemented)
-            self.assertEqual(instance.__ne__(other), NotImplemented)
-            self.assertFalse(instance == other)
-            self.assertTrue(instance != other)
-
-        assertEqualityNotImplemented(self.ipv4_address)
-        assertEqualityNotImplemented(self.ipv4_network)
-        assertEqualityNotImplemented(self.ipv4_interface)
-        assertEqualityNotImplemented(self.ipv6_address)
-        assertEqualityNotImplemented(self.ipv6_network)
-        assertEqualityNotImplemented(self.ipv6_interface)
-
-    def testBadVersionComparison(self):
-        # These should always raise TypeError
-        v4addr = ipaddress.ip_address('1.1.1.1')
-        v4net = ipaddress.ip_network('1.1.1.1')
-        v6addr = ipaddress.ip_address('::1')
-        v6net = ipaddress.ip_address('::1')
-
-        self.assertRaises(TypeError, v4addr.__lt__, v6addr)
-        self.assertRaises(TypeError, v4addr.__gt__, v6addr)
-        self.assertRaises(TypeError, v4net.__lt__, v6net)
-        self.assertRaises(TypeError, v4net.__gt__, v6net)
-
-        self.assertRaises(TypeError, v6addr.__lt__, v4addr)
-        self.assertRaises(TypeError, v6addr.__gt__, v4addr)
-        self.assertRaises(TypeError, v6net.__lt__, v4net)
-        self.assertRaises(TypeError, v6net.__gt__, v4net)
-
-    def testMixedTypeComparison(self):
-        v4addr = ipaddress.ip_address('1.1.1.1')
-        v4net = ipaddress.ip_network('1.1.1.1/32')
-        v6addr = ipaddress.ip_address('::1')
-        v6net = ipaddress.ip_network('::1/128')
-
-        self.assertFalse(v4net.__contains__(v6net))
-        self.assertFalse(v6net.__contains__(v4net))
-
-        self.assertRaises(TypeError, lambda: v4addr < v4net)
-        self.assertRaises(TypeError, lambda: v4addr > v4net)
-        self.assertRaises(TypeError, lambda: v4net < v4addr)
-        self.assertRaises(TypeError, lambda: v4net > v4addr)
-
-        self.assertRaises(TypeError, lambda: v6addr < v6net)
-        self.assertRaises(TypeError, lambda: v6addr > v6net)
-        self.assertRaises(TypeError, lambda: v6net < v6addr)
-        self.assertRaises(TypeError, lambda: v6net > v6addr)
-
-        # with get_mixed_type_key, you can sort addresses and network.
-        self.assertEqual([v4addr, v4net],
-                         sorted([v4net, v4addr],
-                                key=ipaddress.get_mixed_type_key))
-        self.assertEqual([v6addr, v6net],
-                         sorted([v6net, v6addr],
-                                key=ipaddress.get_mixed_type_key))
-        self.assertEqual(NotImplemented, ipaddress.get_mixed_type_key(object))
-
     def testIpFromInt(self):
         self.assertEqual(self.ipv4_interface._ip,
                          ipaddress.IPv4Interface(16909060)._ip)
@@ -1049,6 +1075,16 @@ class IpaddrUnitTest(unittest.TestCase):
         self.assertTrue(ipaddress.ip_address('::1') <=
                         ipaddress.ip_address('::2'))
 
+    def testInterfaceComparison(self):
+        self.assertTrue(ipaddress.ip_interface('1.1.1.1') <=
+                        ipaddress.ip_interface('1.1.1.1'))
+        self.assertTrue(ipaddress.ip_interface('1.1.1.1') <=
+                        ipaddress.ip_interface('1.1.1.2'))
+        self.assertTrue(ipaddress.ip_interface('::1') <=
+                        ipaddress.ip_interface('::1'))
+        self.assertTrue(ipaddress.ip_interface('::1') <=
+                        ipaddress.ip_interface('::2'))
+
     def testNetworkComparison(self):
         # ip1 and ip2 have the same network address
         ip1 = ipaddress.IPv4Network('1.1.1.0/24')
index a109baf039fe9e707cbba424f02c6a9f8bc5a50a..bb0211a45143dc09abfc42aeac32791ac95073fd 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -23,6 +23,9 @@ Core and Builtins
 Library
 -------
 
+- Issue #14814: implement more consistent ordering and sorting behaviour
+  for ipaddress objects
+
 - Issue #14814: ipaddress network objects correctly return NotImplemented
   when compared to arbitrary objects instead of raising TypeError