From: Brian Wellington Date: Wed, 8 Jul 2020 22:11:19 +0000 (-0700) Subject: Add support for receiving UDP queries. X-Git-Tag: v2.0.0~7^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7a5e59707b395454db2cb650371bbc2e800e7be4;p=thirdparty%2Fdnspython.git Add support for receiving UDP queries. The existing receive_udp() methods are only usable for receiving responses, as they require an expected destination and check that the message is from that destination. This change makes the expected destination (and hence the check) optional, and returns the address that the message was received from (in the sync case, this is only done if no destination is provided, for backwards compatibility). New tests are added, which required adding generic getsockname() support to the async backends. --- diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py index ba7c2e72..3af34ff8 100644 --- a/dns/_asyncio_backend.py +++ b/dns/_asyncio_backend.py @@ -75,6 +75,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): async def getpeername(self): return self.transport.get_extra_info('peername') + async def getsockname(self): + return self.transport.get_extra_info('sockname') + class StreamSocket(dns._asyncbackend.DatagramSocket): def __init__(self, af, reader, writer): @@ -102,6 +105,9 @@ class StreamSocket(dns._asyncbackend.DatagramSocket): async def getpeername(self): return self.writer.get_extra_info('peername') + async def getsockname(self): + return self.writer.get_extra_info('sockname') + class Backend(dns._asyncbackend.Backend): def name(self): diff --git a/dns/_curio_backend.py b/dns/_curio_backend.py index dca966df..300e1b89 100644 --- a/dns/_curio_backend.py +++ b/dns/_curio_backend.py @@ -43,6 +43,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): async def getpeername(self): return self.socket.getpeername() + async def getsockname(self): + return self.socket.getsockname() + class StreamSocket(dns._asyncbackend.DatagramSocket): def __init__(self, socket): @@ -65,6 +68,9 @@ class StreamSocket(dns._asyncbackend.DatagramSocket): async def getpeername(self): return self.socket.getpeername() + async def getsockname(self): + return self.socket.getsockname() + class Backend(dns._asyncbackend.Backend): def name(self): diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py index 0f1378f3..92ea8796 100644 --- a/dns/_trio_backend.py +++ b/dns/_trio_backend.py @@ -43,6 +43,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): async def getpeername(self): return self.socket.getpeername() + async def getsockname(self): + return self.socket.getsockname() + class StreamSocket(dns._asyncbackend.DatagramSocket): def __init__(self, family, stream, tls=False): @@ -69,6 +72,12 @@ class StreamSocket(dns._asyncbackend.DatagramSocket): else: return self.stream.socket.getpeername() + async def getsockname(self): + if self.tls: + return self.stream.transport_stream.socket.getsockname() + else: + return self.stream.socket.getsockname() + class Backend(dns._asyncbackend.Backend): def name(self): diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 4afe7bcc..b7926480 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -30,8 +30,7 @@ import dns.rcode import dns.rdataclass import dns.rdatatype -from dns.query import _addresses_equal, _compute_times, UnexpectedSource, \ - BadResponse, ssl +from dns.query import _compute_times, _matches_destination, BadResponse, ssl # for brevity @@ -87,7 +86,7 @@ async def send_udp(sock, what, destination, expiration=None): return (n, sent_time) -async def receive_udp(sock, destination, expiration=None, +async def receive_udp(sock, destination=None, expiration=None, ignore_unexpected=False, one_rr_per_rrset=False, keyring=None, request_mac=b'', ignore_trailing=False, raise_on_truncation=False): @@ -96,7 +95,9 @@ async def receive_udp(sock, destination, expiration=None, *sock*, a ``dns.asyncbackend.DatagramSocket``. *destination*, a destination tuple appropriate for the address family - of the socket, specifying where the associated query was sent. + of the socket, specifying where the message is expected to arrive from. + When receiving a response, this would be where the associated query was + sent. *expiration*, a ``float`` or ``None``, the absolute time at which a timeout exception should be raised. If ``None``, no timeout will @@ -121,27 +122,22 @@ async def receive_udp(sock, destination, expiration=None, Raises if the message is malformed, if network errors occur, of if there is a timeout. - Returns a ``(dns.message.Message, float)`` tuple of the received message - and the received time. + Returns a ``(dns.message.Message, float, tuple)`` tuple of the received + message, the received time, and the address where the message arrived from. """ wire = b'' while 1: (wire, from_address) = await sock.recvfrom(65535, _timeout(expiration)) - if _addresses_equal(sock.family, from_address, destination) or \ - (dns.inet.is_multicast(destination[0]) and - from_address[1:] == destination[1:]): + if _matches_destination(sock.family, from_address, destination, + ignore_unexpected): break - if not ignore_unexpected: - raise UnexpectedSource('got a response from ' - '%s instead of %s' % (from_address, - destination)) received_time = time.time() r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, one_rr_per_rrset=one_rr_per_rrset, ignore_trailing=ignore_trailing, raise_on_truncation=raise_on_truncation) - return (r, received_time) + return (r, received_time, from_address) async def udp(q, where, timeout=None, port=53, source=None, source_port=0, ignore_unexpected=False, one_rr_per_rrset=False, @@ -202,12 +198,12 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0, stuple = _source_tuple(af, source, source_port) s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple) await send_udp(s, wire, destination, expiration) - (r, received_time) = await receive_udp(s, destination, expiration, - ignore_unexpected, - one_rr_per_rrset, - q.keyring, q.mac, - ignore_trailing, - raise_on_truncation) + (r, received_time, _) = await receive_udp(s, destination, expiration, + ignore_unexpected, + one_rr_per_rrset, + q.keyring, q.mac, + ignore_trailing, + raise_on_truncation) r.time = received_time - begin_time if not q.is_response(r): raise BadResponse diff --git a/dns/query.py b/dns/query.py index 13c82461..7df565d8 100644 --- a/dns/query.py +++ b/dns/query.py @@ -201,6 +201,21 @@ def _addresses_equal(af, a1, a2): return n1 == n2 and a1[1:] == a2[1:] +def _matches_destination(af, from_address, destination, ignore_unexpected): + # Check that from_address is appropriate for a response to a query + # sent to destination. + if not destination: + return True + if _addresses_equal(af, from_address, destination) or \ + (dns.inet.is_multicast(destination[0]) and + from_address[1:] == destination[1:]): + return True + elif ignore_unexpected: + return False + raise UnexpectedSource(f'got a response from {from_address} instead of ' + f'{destination}') + + def _destination_and_source(where, port, source, source_port, where_must_be_address=True): # Apply defaults and compute destination and source tuples @@ -397,7 +412,7 @@ def send_udp(sock, what, destination, expiration=None): return (n, sent_time) -def receive_udp(sock, destination, expiration=None, +def receive_udp(sock, destination=None, expiration=None, ignore_unexpected=False, one_rr_per_rrset=False, keyring=None, request_mac=b'', ignore_trailing=False, raise_on_truncation=False): @@ -406,7 +421,9 @@ def receive_udp(sock, destination, expiration=None, *sock*, a ``socket``. *destination*, a destination tuple appropriate for the address family - of the socket, specifying where the associated query was sent. + of the socket, specifying where the message is expected to arrive from. + When receiving a response, this would be where the associated query was + sent. *expiration*, a ``float`` or ``None``, the absolute time at which a timeout exception should be raised. If ``None``, no timeout will @@ -431,28 +448,31 @@ def receive_udp(sock, destination, expiration=None, Raises if the message is malformed, if network errors occur, of if there is a timeout. - Returns a ``(dns.message.Message, float)`` tuple of the received message - and the received time. + If *destination* is not ``None``, returns a ``(dns.message.Message, float)`` + tuple of the received message and the received time. + + If *destination* is ``None``, returns a + ``(dns.message.Message, float, tuple)`` + tuple of the received message, the received time, and the address where + the message arrived from. """ wire = b'' while 1: _wait_for_readable(sock, expiration) (wire, from_address) = sock.recvfrom(65535) - if _addresses_equal(sock.family, from_address, destination) or \ - (dns.inet.is_multicast(destination[0]) and - from_address[1:] == destination[1:]): + if _matches_destination(sock.family, from_address, destination, + ignore_unexpected): break - if not ignore_unexpected: - raise UnexpectedSource('got a response from ' - '%s instead of %s' % (from_address, - destination)) received_time = time.time() r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, one_rr_per_rrset=one_rr_per_rrset, ignore_trailing=ignore_trailing, raise_on_truncation=raise_on_truncation) - return (r, received_time) + if destination: + return (r, received_time) + else: + return (r, received_time, from_address) def udp(q, where, timeout=None, port=53, source=None, source_port=0, ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False, diff --git a/tests/test_async.py b/tests/test_async.py index 2d254340..5faaa6e9 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -343,6 +343,25 @@ class AsyncTests(unittest.TestCase): (_, tcp) = self.async_run(run) self.assertFalse(tcp) + def testUDPReceiveQuery(self): + async def run(): + async with await self.backend.make_socket( + socket.AF_INET, socket.SOCK_DGRAM, + source=('127.0.0.1', 0)) as listener: + listener_address = await listener.getsockname() + async with await self.backend.make_socket( + socket.AF_INET, socket.SOCK_DGRAM, + source=('127.0.0.1', 0)) as sender: + sender_address = await sender.getsockname() + q = dns.message.make_query('dns.google', dns.rdatatype.A) + await dns.asyncquery.send_udp(sender, q, listener_address) + expiration = time.time() + 2 + (_, _, recv_address) = await dns.asyncquery.receive_udp( + listener, expiration=expiration) + return (sender_address, recv_address) + (sender_address, recv_address) = self.async_run(run) + self.assertEqual(sender_address, recv_address) + def testUDPReceiveTimeout(self): async def arun(): async with await self.backend.make_socket(socket.AF_INET, diff --git a/tests/test_query.py b/tests/test_query.py index f1ec55ca..498128d2 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -191,6 +191,18 @@ class QueryTests(unittest.TestCase): (_, tcp) = dns.query.udp_with_fallback(q, address) self.assertFalse(tcp) + def testUDPReceiveQuery(self): + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as listener: + listener.bind(('127.0.0.1', 0)) + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sender: + sender.bind(('127.0.0.1', 0)) + q = dns.message.make_query('dns.google', dns.rdatatype.A) + dns.query.send_udp(sender, q, listener.getsockname()) + expiration = time.time() + 2 + (q, _, addr) = dns.query.receive_udp(listener, + expiration=expiration) + self.assertEqual(addr, sender.getsockname()) + # for brevity _d_and_s = dns.query._destination_and_source