]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Change parameter order of low_level_address_tuple; add test coverage. 504/head
authorBob Halley <halley@dnspython.org>
Sat, 13 Jun 2020 18:40:54 +0000 (11:40 -0700)
committerBob Halley <halley@dnspython.org>
Sat, 13 Jun 2020 18:40:54 +0000 (11:40 -0700)
dns/_curio_backend.py
dns/_trio_backend.py
dns/asyncquery.py
dns/inet.py
tests/test_ntoaaton.py

index 836273b382fa938ab4010a94ab28c4c5f11bb082..d5eba68d8e54ba4494a2c3c78445a65849beb690 100644 (file)
@@ -77,14 +77,14 @@ class Backend(dns._asyncbackend.Backend):
             s = curio.socket.socket(af, socktype, proto)
             try:
                 if source:
-                    s.bind(_lltuple(af, source))
+                    s.bind(_lltuple(source, af))
             except Exception:
                 await s.close()
                 raise
             return DatagramSocket(s)
         elif socktype == socket.SOCK_STREAM:
             if source:
-                source_addr = (_lltuple(af, source))
+                source_addr = _lltuple(source, af)
             else:
                 source_addr = None
             async with _maybe_timeout(timeout):
index 418639cb7786556ecd8e0f45ae6c4f43453133df..cfb0e1d18aad7f911c54428a85ae14d0d8906f31 100644 (file)
@@ -81,10 +81,10 @@ class Backend(dns._asyncbackend.Backend):
         stream = None
         try:
             if source:
-                await s.bind(_lltuple(af, source))
+                await s.bind(_lltuple(source, af))
             if socktype == socket.SOCK_STREAM:
                 with _maybe_timeout(timeout):
-                    await s.connect(_lltuple(af, destination))
+                    await s.connect(_lltuple(destination, af))
         except Exception:
             s.close()
             raise
index 38141feb8d383eec0c2df293a9fd3867283d2f5a..709c246ab75baeccbab039b256215bf092da8f84 100644 (file)
@@ -199,7 +199,7 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
             af = dns.inet.af_for_address(where)
             stuple = _source_tuple(af, source, source_port)
             s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple)
-            destination = _lltuple(af, (where, port))
+            destination = _lltuple((where, port), af)
         await send_udp(s, wire, destination, expiration)
         (r, received_time) = await receive_udp(s, destination, expiration,
                                                ignore_unexpected,
index 7960e9f7dbc92ef8c1241c7e5dfcd810f8eee622..71782ac3d25ee20b3ca71eae5b32f2eeb8810f7e 100644 (file)
@@ -141,15 +141,22 @@ def is_address(text):
             return False
 
 
-def low_level_address_tuple(af, high_tuple):
-    """Given an address family and a "high-level" address tuple, i.e.
+def low_level_address_tuple(high_tuple, af=None):
+    """Given a "high-level" address tuple, i.e.
     an (address, port) return the appropriate "low-level" address tuple
     suitable for use in socket calls.
+
+    If an *af* other than ``None`` is provided, it is assumed the
+    address in the high-level tuple is valid and has that af.  If af
+    is ``None``, then af_for_address will be called.
+
     """
     address, port = high_tuple
-    if af == dns.inet.AF_INET:
+    if af is None:
+        af = af_for_address(address)
+    if af == AF_INET:
         return (address, port)
-    elif af == dns.inet.AF_INET6:
+    elif af == AF_INET6:
         ai_flags = socket.AI_NUMERICHOST
         ((*_, tup), *_) = socket.getaddrinfo(address, port, flags=ai_flags)
         return tup
index 36107e1b5c95e0d8b730173ac391be198f196886..3a72891fb526455a8128468039f51f0491a9f492 100644 (file)
@@ -17,6 +17,7 @@
 
 import unittest
 import binascii
+import socket
 
 import dns.exception
 import dns.ipv4
@@ -274,5 +275,19 @@ class NtoAAtoNTestCase(unittest.TestCase):
                        ('2001:db8:0:1:1:1:1:q1', False)]:
             self.assertEqual(dns.inet.is_address(t), e)
 
+    def test_low_level_address_tuple(self):
+        t = dns.inet.low_level_address_tuple(('1.2.3.4', 53))
+        self.assertEqual(t, ('1.2.3.4', 53))
+        t = dns.inet.low_level_address_tuple(('2600::1', 53))
+        self.assertEqual(t, ('2600::1', 53, 0, 0))
+        t = dns.inet.low_level_address_tuple(('1.2.3.4', 53), socket.AF_INET)
+        self.assertEqual(t, ('1.2.3.4', 53))
+        t = dns.inet.low_level_address_tuple(('2600::1', 53), socket.AF_INET6)
+        self.assertEqual(t, ('2600::1', 53, 0, 0))
+        def bad():
+            bogus = socket.AF_INET + socket.AF_INET6 + 1
+            t = dns.inet.low_level_address_tuple(('2600::1', 53), bogus)
+        self.assertRaises(NotImplementedError, bad)
+
 if __name__ == '__main__':
     unittest.main()