]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add support for receiving UDP queries.
authorBrian Wellington <bwelling@xbill.org>
Wed, 8 Jul 2020 22:11:19 +0000 (15:11 -0700)
committerBrian Wellington <bwelling@xbill.org>
Wed, 8 Jul 2020 22:15:31 +0000 (15:15 -0700)
The existing receive_udp() methods are only usable for receiving
responses, as they require an expected destination and check that the
message is from that destination.

This change makes the expected destination (and hence the check)
optional, and returns the address that the message was received from (in
the sync case, this is only done if no destination is provided, for
backwards compatibility).

New tests are added, which required adding generic getsockname() support
to the async backends.

dns/_asyncio_backend.py
dns/_curio_backend.py
dns/_trio_backend.py
dns/asyncquery.py
dns/query.py
tests/test_async.py
tests/test_query.py

index ba7c2e72ccacfaa25bb130078c923307cb3a8da9..3af34ff89cc4a7def762a3c0e8a6db98da5e8707 100644 (file)
@@ -75,6 +75,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
     async def getpeername(self):
         return self.transport.get_extra_info('peername')
 
+    async def getsockname(self):
+        return self.transport.get_extra_info('sockname')
+
 
 class StreamSocket(dns._asyncbackend.DatagramSocket):
     def __init__(self, af, reader, writer):
@@ -102,6 +105,9 @@ class StreamSocket(dns._asyncbackend.DatagramSocket):
     async def getpeername(self):
         return self.writer.get_extra_info('peername')
 
+    async def getsockname(self):
+        return self.writer.get_extra_info('sockname')
+
 
 class Backend(dns._asyncbackend.Backend):
     def name(self):
index dca966df56865e4a9dbe235376a0b87c433e2a7d..300e1b89e409f71d125c69da0b7637ae1891a7e4 100644 (file)
@@ -43,6 +43,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
     async def getpeername(self):
         return self.socket.getpeername()
 
+    async def getsockname(self):
+        return self.socket.getsockname()
+
 
 class StreamSocket(dns._asyncbackend.DatagramSocket):
     def __init__(self, socket):
@@ -65,6 +68,9 @@ class StreamSocket(dns._asyncbackend.DatagramSocket):
     async def getpeername(self):
         return self.socket.getpeername()
 
+    async def getsockname(self):
+        return self.socket.getsockname()
+
 
 class Backend(dns._asyncbackend.Backend):
     def name(self):
index 0f1378f319b2a05c917f0d89f06527e3c47cfcfd..92ea87960cd9a042fc6755fd86807abac92b5fa0 100644 (file)
@@ -43,6 +43,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
     async def getpeername(self):
         return self.socket.getpeername()
 
+    async def getsockname(self):
+        return self.socket.getsockname()
+
 
 class StreamSocket(dns._asyncbackend.DatagramSocket):
     def __init__(self, family, stream, tls=False):
@@ -69,6 +72,12 @@ class StreamSocket(dns._asyncbackend.DatagramSocket):
         else:
             return self.stream.socket.getpeername()
 
+    async def getsockname(self):
+        if self.tls:
+            return self.stream.transport_stream.socket.getsockname()
+        else:
+            return self.stream.socket.getsockname()
+
 
 class Backend(dns._asyncbackend.Backend):
     def name(self):
index 4afe7bcc8813eb071c265bd86907b750d76de211..b792648067a4f3a1bc36643db6b90c57d1e08651 100644 (file)
@@ -30,8 +30,7 @@ import dns.rcode
 import dns.rdataclass
 import dns.rdatatype
 
-from dns.query import _addresses_equal, _compute_times, UnexpectedSource, \
-    BadResponse, ssl
+from dns.query import _compute_times, _matches_destination, BadResponse, ssl
 
 
 # for brevity
@@ -87,7 +86,7 @@ async def send_udp(sock, what, destination, expiration=None):
     return (n, sent_time)
 
 
-async def receive_udp(sock, destination, expiration=None,
+async def receive_udp(sock, destination=None, expiration=None,
                       ignore_unexpected=False, one_rr_per_rrset=False,
                       keyring=None, request_mac=b'', ignore_trailing=False,
                       raise_on_truncation=False):
@@ -96,7 +95,9 @@ async def receive_udp(sock, destination, expiration=None,
     *sock*, a ``dns.asyncbackend.DatagramSocket``.
 
     *destination*, a destination tuple appropriate for the address family
-    of the socket, specifying where the associated query was sent.
+    of the socket, specifying where the message is expected to arrive from.
+    When receiving a response, this would be where the associated query was
+    sent.
 
     *expiration*, a ``float`` or ``None``, the absolute time at which
     a timeout exception should be raised.  If ``None``, no timeout will
@@ -121,27 +122,22 @@ async def receive_udp(sock, destination, expiration=None,
     Raises if the message is malformed, if network errors occur, of if
     there is a timeout.
 
-    Returns a ``(dns.message.Message, float)`` tuple of the received message
-    and the received time.
+    Returns a ``(dns.message.Message, float, tuple)`` tuple of the received
+    message, the received time, and the address where the message arrived from.
     """
 
     wire = b''
     while 1:
         (wire, from_address) = await sock.recvfrom(65535, _timeout(expiration))
-        if _addresses_equal(sock.family, from_address, destination) or \
-           (dns.inet.is_multicast(destination[0]) and
-            from_address[1:] == destination[1:]):
+        if _matches_destination(sock.family, from_address, destination,
+                                ignore_unexpected):
             break
-        if not ignore_unexpected:
-            raise UnexpectedSource('got a response from '
-                                   '%s instead of %s' % (from_address,
-                                                         destination))
     received_time = time.time()
     r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
                               one_rr_per_rrset=one_rr_per_rrset,
                               ignore_trailing=ignore_trailing,
                               raise_on_truncation=raise_on_truncation)
-    return (r, received_time)
+    return (r, received_time, from_address)
 
 async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
               ignore_unexpected=False, one_rr_per_rrset=False,
@@ -202,12 +198,12 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
             stuple = _source_tuple(af, source, source_port)
             s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple)
         await send_udp(s, wire, destination, expiration)
-        (r, received_time) = await receive_udp(s, destination, expiration,
-                                               ignore_unexpected,
-                                               one_rr_per_rrset,
-                                               q.keyring, q.mac,
-                                               ignore_trailing,
-                                               raise_on_truncation)
+        (r, received_time, _) = await receive_udp(s, destination, expiration,
+                                                  ignore_unexpected,
+                                                  one_rr_per_rrset,
+                                                  q.keyring, q.mac,
+                                                  ignore_trailing,
+                                                  raise_on_truncation)
         r.time = received_time - begin_time
         if not q.is_response(r):
             raise BadResponse
index 13c824614e4dc4778287a12f5165c7a7c4a4be28..7df565d851f2c07aeb8babdbe5195353c8d529a8 100644 (file)
@@ -201,6 +201,21 @@ def _addresses_equal(af, a1, a2):
     return n1 == n2 and a1[1:] == a2[1:]
 
 
+def _matches_destination(af, from_address, destination, ignore_unexpected):
+    # Check that from_address is appropriate for a response to a query
+    # sent to destination.
+    if not destination:
+        return True
+    if _addresses_equal(af, from_address, destination) or \
+       (dns.inet.is_multicast(destination[0]) and
+        from_address[1:] == destination[1:]):
+        return True
+    elif ignore_unexpected:
+        return False
+    raise UnexpectedSource(f'got a response from {from_address} instead of '
+                           f'{destination}')
+
+
 def _destination_and_source(where, port, source, source_port,
                             where_must_be_address=True):
     # Apply defaults and compute destination and source tuples
@@ -397,7 +412,7 @@ def send_udp(sock, what, destination, expiration=None):
     return (n, sent_time)
 
 
-def receive_udp(sock, destination, expiration=None,
+def receive_udp(sock, destination=None, expiration=None,
                 ignore_unexpected=False, one_rr_per_rrset=False,
                 keyring=None, request_mac=b'', ignore_trailing=False,
                 raise_on_truncation=False):
@@ -406,7 +421,9 @@ def receive_udp(sock, destination, expiration=None,
     *sock*, a ``socket``.
 
     *destination*, a destination tuple appropriate for the address family
-    of the socket, specifying where the associated query was sent.
+    of the socket, specifying where the message is expected to arrive from.
+    When receiving a response, this would be where the associated query was
+    sent.
 
     *expiration*, a ``float`` or ``None``, the absolute time at which
     a timeout exception should be raised.  If ``None``, no timeout will
@@ -431,28 +448,31 @@ def receive_udp(sock, destination, expiration=None,
     Raises if the message is malformed, if network errors occur, of if
     there is a timeout.
 
-    Returns a ``(dns.message.Message, float)`` tuple of the received message
-    and the received time.
+    If *destination* is not ``None``, returns a ``(dns.message.Message, float)``
+    tuple of the received message and the received time.
+
+    If *destination* is ``None``, returns a
+    ``(dns.message.Message, float, tuple)``
+    tuple of the received message, the received time, and the address where
+    the message arrived from.
     """
 
     wire = b''
     while 1:
         _wait_for_readable(sock, expiration)
         (wire, from_address) = sock.recvfrom(65535)
-        if _addresses_equal(sock.family, from_address, destination) or \
-           (dns.inet.is_multicast(destination[0]) and
-            from_address[1:] == destination[1:]):
+        if _matches_destination(sock.family, from_address, destination,
+                                ignore_unexpected):
             break
-        if not ignore_unexpected:
-            raise UnexpectedSource('got a response from '
-                                   '%s instead of %s' % (from_address,
-                                                         destination))
     received_time = time.time()
     r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
                               one_rr_per_rrset=one_rr_per_rrset,
                               ignore_trailing=ignore_trailing,
                               raise_on_truncation=raise_on_truncation)
-    return (r, received_time)
+    if destination:
+        return (r, received_time)
+    else:
+        return (r, received_time, from_address)
 
 def udp(q, where, timeout=None, port=53, source=None, source_port=0,
         ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
index 2d2543400fef3e8d5ed9aee0fc15e2e8e1b4cfcc..5faaa6e9e01befce215f13c0d0c45878756b5cf6 100644 (file)
@@ -343,6 +343,25 @@ class AsyncTests(unittest.TestCase):
             (_, tcp) = self.async_run(run)
             self.assertFalse(tcp)
 
+    def testUDPReceiveQuery(self):
+        async def run():
+            async with await self.backend.make_socket(
+                    socket.AF_INET, socket.SOCK_DGRAM,
+                    source=('127.0.0.1', 0)) as listener:
+                listener_address = await listener.getsockname()
+                async with await self.backend.make_socket(
+                        socket.AF_INET, socket.SOCK_DGRAM,
+                        source=('127.0.0.1', 0)) as sender:
+                    sender_address = await sender.getsockname()
+                    q = dns.message.make_query('dns.google', dns.rdatatype.A)
+                    await dns.asyncquery.send_udp(sender, q, listener_address)
+                    expiration = time.time() + 2
+                    (_, _, recv_address) = await dns.asyncquery.receive_udp(
+                            listener, expiration=expiration)
+                    return (sender_address, recv_address)
+        (sender_address, recv_address) = self.async_run(run)
+        self.assertEqual(sender_address, recv_address)
+
     def testUDPReceiveTimeout(self):
         async def arun():
             async with await self.backend.make_socket(socket.AF_INET,
index f1ec55cafa68a9bcfdfc7366f8716277898f3837..498128d2f9bb016a319c1b64e4ba6185a8cc73af 100644 (file)
@@ -191,6 +191,18 @@ class QueryTests(unittest.TestCase):
             (_, tcp) = dns.query.udp_with_fallback(q, address)
             self.assertFalse(tcp)
 
+    def testUDPReceiveQuery(self):
+        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as listener:
+            listener.bind(('127.0.0.1', 0))
+            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sender:
+                sender.bind(('127.0.0.1', 0))
+                q = dns.message.make_query('dns.google', dns.rdatatype.A)
+                dns.query.send_udp(sender, q, listener.getsockname())
+                expiration = time.time() + 2
+                (q, _, addr) = dns.query.receive_udp(listener,
+                                                     expiration=expiration)
+                self.assertEqual(addr, sender.getsockname())
+
 
 # for brevity
 _d_and_s = dns.query._destination_and_source