async def getpeername(self):
return self.transport.get_extra_info('peername')
+ async def getsockname(self):
+ return self.transport.get_extra_info('sockname')
+
class StreamSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, af, reader, writer):
async def getpeername(self):
return self.writer.get_extra_info('peername')
+ async def getsockname(self):
+ return self.writer.get_extra_info('sockname')
+
class Backend(dns._asyncbackend.Backend):
def name(self):
async def getpeername(self):
return self.socket.getpeername()
+ async def getsockname(self):
+ return self.socket.getsockname()
+
class StreamSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, socket):
async def getpeername(self):
return self.socket.getpeername()
+ async def getsockname(self):
+ return self.socket.getsockname()
+
class Backend(dns._asyncbackend.Backend):
def name(self):
async def getpeername(self):
return self.socket.getpeername()
+ async def getsockname(self):
+ return self.socket.getsockname()
+
class StreamSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, family, stream, tls=False):
else:
return self.stream.socket.getpeername()
+ async def getsockname(self):
+ if self.tls:
+ return self.stream.transport_stream.socket.getsockname()
+ else:
+ return self.stream.socket.getsockname()
+
class Backend(dns._asyncbackend.Backend):
def name(self):
import dns.rdataclass
import dns.rdatatype
-from dns.query import _addresses_equal, _compute_times, UnexpectedSource, \
- BadResponse, ssl
+from dns.query import _compute_times, _matches_destination, BadResponse, ssl
# for brevity
return (n, sent_time)
-async def receive_udp(sock, destination, expiration=None,
+async def receive_udp(sock, destination=None, expiration=None,
ignore_unexpected=False, one_rr_per_rrset=False,
keyring=None, request_mac=b'', ignore_trailing=False,
raise_on_truncation=False):
*sock*, a ``dns.asyncbackend.DatagramSocket``.
*destination*, a destination tuple appropriate for the address family
- of the socket, specifying where the associated query was sent.
+ of the socket, specifying where the message is expected to arrive from.
+ When receiving a response, this would be where the associated query was
+ sent.
*expiration*, a ``float`` or ``None``, the absolute time at which
a timeout exception should be raised. If ``None``, no timeout will
Raises if the message is malformed, if network errors occur, of if
there is a timeout.
- Returns a ``(dns.message.Message, float)`` tuple of the received message
- and the received time.
+ Returns a ``(dns.message.Message, float, tuple)`` tuple of the received
+ message, the received time, and the address where the message arrived from.
"""
wire = b''
while 1:
(wire, from_address) = await sock.recvfrom(65535, _timeout(expiration))
- if _addresses_equal(sock.family, from_address, destination) or \
- (dns.inet.is_multicast(destination[0]) and
- from_address[1:] == destination[1:]):
+ if _matches_destination(sock.family, from_address, destination,
+ ignore_unexpected):
break
- if not ignore_unexpected:
- raise UnexpectedSource('got a response from '
- '%s instead of %s' % (from_address,
- destination))
received_time = time.time()
r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation)
- return (r, received_time)
+ return (r, received_time, from_address)
async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
ignore_unexpected=False, one_rr_per_rrset=False,
stuple = _source_tuple(af, source, source_port)
s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple)
await send_udp(s, wire, destination, expiration)
- (r, received_time) = await receive_udp(s, destination, expiration,
- ignore_unexpected,
- one_rr_per_rrset,
- q.keyring, q.mac,
- ignore_trailing,
- raise_on_truncation)
+ (r, received_time, _) = await receive_udp(s, destination, expiration,
+ ignore_unexpected,
+ one_rr_per_rrset,
+ q.keyring, q.mac,
+ ignore_trailing,
+ raise_on_truncation)
r.time = received_time - begin_time
if not q.is_response(r):
raise BadResponse
return n1 == n2 and a1[1:] == a2[1:]
+def _matches_destination(af, from_address, destination, ignore_unexpected):
+ # Check that from_address is appropriate for a response to a query
+ # sent to destination.
+ if not destination:
+ return True
+ if _addresses_equal(af, from_address, destination) or \
+ (dns.inet.is_multicast(destination[0]) and
+ from_address[1:] == destination[1:]):
+ return True
+ elif ignore_unexpected:
+ return False
+ raise UnexpectedSource(f'got a response from {from_address} instead of '
+ f'{destination}')
+
+
def _destination_and_source(where, port, source, source_port,
where_must_be_address=True):
# Apply defaults and compute destination and source tuples
return (n, sent_time)
-def receive_udp(sock, destination, expiration=None,
+def receive_udp(sock, destination=None, expiration=None,
ignore_unexpected=False, one_rr_per_rrset=False,
keyring=None, request_mac=b'', ignore_trailing=False,
raise_on_truncation=False):
*sock*, a ``socket``.
*destination*, a destination tuple appropriate for the address family
- of the socket, specifying where the associated query was sent.
+ of the socket, specifying where the message is expected to arrive from.
+ When receiving a response, this would be where the associated query was
+ sent.
*expiration*, a ``float`` or ``None``, the absolute time at which
a timeout exception should be raised. If ``None``, no timeout will
Raises if the message is malformed, if network errors occur, of if
there is a timeout.
- Returns a ``(dns.message.Message, float)`` tuple of the received message
- and the received time.
+ If *destination* is not ``None``, returns a ``(dns.message.Message, float)``
+ tuple of the received message and the received time.
+
+ If *destination* is ``None``, returns a
+ ``(dns.message.Message, float, tuple)``
+ tuple of the received message, the received time, and the address where
+ the message arrived from.
"""
wire = b''
while 1:
_wait_for_readable(sock, expiration)
(wire, from_address) = sock.recvfrom(65535)
- if _addresses_equal(sock.family, from_address, destination) or \
- (dns.inet.is_multicast(destination[0]) and
- from_address[1:] == destination[1:]):
+ if _matches_destination(sock.family, from_address, destination,
+ ignore_unexpected):
break
- if not ignore_unexpected:
- raise UnexpectedSource('got a response from '
- '%s instead of %s' % (from_address,
- destination))
received_time = time.time()
r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation)
- return (r, received_time)
+ if destination:
+ return (r, received_time)
+ else:
+ return (r, received_time, from_address)
def udp(q, where, timeout=None, port=53, source=None, source_port=0,
ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
(_, tcp) = self.async_run(run)
self.assertFalse(tcp)
+ def testUDPReceiveQuery(self):
+ async def run():
+ async with await self.backend.make_socket(
+ socket.AF_INET, socket.SOCK_DGRAM,
+ source=('127.0.0.1', 0)) as listener:
+ listener_address = await listener.getsockname()
+ async with await self.backend.make_socket(
+ socket.AF_INET, socket.SOCK_DGRAM,
+ source=('127.0.0.1', 0)) as sender:
+ sender_address = await sender.getsockname()
+ q = dns.message.make_query('dns.google', dns.rdatatype.A)
+ await dns.asyncquery.send_udp(sender, q, listener_address)
+ expiration = time.time() + 2
+ (_, _, recv_address) = await dns.asyncquery.receive_udp(
+ listener, expiration=expiration)
+ return (sender_address, recv_address)
+ (sender_address, recv_address) = self.async_run(run)
+ self.assertEqual(sender_address, recv_address)
+
def testUDPReceiveTimeout(self):
async def arun():
async with await self.backend.make_socket(socket.AF_INET,
(_, tcp) = dns.query.udp_with_fallback(q, address)
self.assertFalse(tcp)
+ def testUDPReceiveQuery(self):
+ with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as listener:
+ listener.bind(('127.0.0.1', 0))
+ with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sender:
+ sender.bind(('127.0.0.1', 0))
+ q = dns.message.make_query('dns.google', dns.rdatatype.A)
+ dns.query.send_udp(sender, q, listener.getsockname())
+ expiration = time.time() + 2
+ (q, _, addr) = dns.query.receive_udp(listener,
+ expiration=expiration)
+ self.assertEqual(addr, sender.getsockname())
+
# for brevity
_d_and_s = dns.query._destination_and_source