From ac2bd6f2b386b3ea2b486189506ceae13ad2108e Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Fri, 30 Sep 2016 15:20:08 -0700 Subject: [PATCH] Factor out core send and receive functionalty from dns.query.udp() and dns.query.tcp(), helping applications that want more control over the socket. --- ChangeLog | 14 ++++- dns/query.py | 161 ++++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 139 insertions(+), 36 deletions(-) diff --git a/ChangeLog b/ChangeLog index 91351887..16f44d05 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,7 +1,19 @@ -2016-09-29 Bob Halley +2016-09-30 Bob Halley + + * Much of the internals of dns.query.udp() and dns.query.tcp() + have been factored out into dns.query.send_udp(), + dns.query.receive_udp(), dns.query.send_tcp(), and + dns.query.receive_tcp(). Applications which want more control + over the socket may find the new routines helpful; for example + it would be easy to send multiple queries over a single TCP + connection. + +2016-09-30 Bob Halley * (Version 1.15.0 released) +2016-09-29 Bob Halley + * IDNA 2008 support is now available if the "idna" module has been installed and IDNA 2008 is requested. The default IDNA behavior is still IDNA 2003. The new IDNA codec mechanism is currently diff --git a/dns/query.py b/dns/query.py index bfecd43e..2ae1d89c 100644 --- a/dns/query.py +++ b/dns/query.py @@ -193,6 +193,72 @@ def _destination_and_source(af, where, port, source, source_port): return (af, destination, source) +def send_udp(sock, what, destination, expiration=None): + """Send a DNS message to the specified UDP socket. + + @param sock: the socket + @type sock: socket.socket + @param what: the message to send + @type what: wire format bytes or a dns.message.Message + @param destination: where to send the query + @type destination: tuple appropriate for the address family of the socket + @param expiration: The absolute time at which a timeout exception should + be raised. + @type expiration: float + @rtype: (int, double) tuple of bytes sent and the sent time. + """ + if isinstance(what, dns.message.Message): + what = what.to_wire() + _wait_for_writable(sock, expiration) + sent_time = time.time() + n = sock.sendto(what, destination) + return (n, sent_time) + + +def receive_udp(sock, destination, expiration=None, af=None, + ignore_unexpected=False, one_rr_per_rrset=False, + keyring=None, request_mac=b''): + """Read a DNS message from a UDP socket. + + @param sock: the socket + @type sock: socket.socket + @param af: the address family to use. The default is None, which + causes the address family to use to be inferred from the form of + destination. If the inference attempt fails, AF_INET is used. + @type af: int + @param ignore_unexpected: If True, ignore responses from unexpected + sources. The default is False. + @type ignore_unexpected: bool + @param one_rr_per_rrset: Put each RR into its own RRset + @type one_rr_per_rrset: bool + @param keyring: the keyring to use for TSIG + @type keyring: keyring dict + @param request_mac: the MAC of the request (for TSIG) + @type request_mac: bytes + @rtype: dns.message.Message object + """ + if af is None: + try: + af = dns.inet.af_for_address(destination[0]) + except Exception: + af = dns.inet.AF_INET + wire = b'' + while 1: + _wait_for_readable(sock, expiration) + (wire, from_address) = sock.recvfrom(65535) + if _addresses_equal(af, from_address, destination) or \ + (dns.inet.is_multicast(destination[0]) and + from_address[1:] == destination[1:]): + break + if not ignore_unexpected: + raise UnexpectedSource('got a response from ' + '%s instead of %s' % (from_address, + destination)) + received_time = time.time() + r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset) + return (r, received_time) + def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, ignore_unexpected=False, one_rr_per_rrset=False): """Return the response obtained after sending a query via UDP. @@ -210,7 +276,6 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, causes the address family to use to be inferred from the form of where. If the inference attempt fails, AF_INET is used. @type af: int - @rtype: dns.message.Message object @param source: source address. The default is the wildcard address. @type source: string @param source_port: The port from which to send the message. @@ -221,40 +286,30 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, @type ignore_unexpected: bool @param one_rr_per_rrset: Put each RR into its own RRset @type one_rr_per_rrset: bool + @rtype: dns.message.Message object """ wire = q.to_wire() (af, destination, source) = _destination_and_source(af, where, port, source, source_port) s = socket_factory(af, socket.SOCK_DGRAM, 0) - begin_time = None + received_time = None + sent_time = None try: expiration = _compute_expiration(timeout) s.setblocking(0) if source is not None: s.bind(source) - _wait_for_writable(s, expiration) - begin_time = time.time() - s.sendto(wire, destination) - while 1: - _wait_for_readable(s, expiration) - (wire, from_address) = s.recvfrom(65535) - if _addresses_equal(af, from_address, destination) or \ - (dns.inet.is_multicast(where) and - from_address[1:] == destination[1:]): - break - if not ignore_unexpected: - raise UnexpectedSource('got a response from ' - '%s instead of %s' % (from_address, - destination)) + (_, sent_time) = send_udp(s, wire, destination, expiration) + (r, received_time) = receive_udp(s, destination, expiration, af, + ignore_unexpected, one_rr_per_rrset, + q.keyring, q.request_mac) finally: - if begin_time is None: + if sent_time is None or received_time is None: response_time = 0 else: - response_time = time.time() - begin_time + response_time = received_time - sent_time s.close() - r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac, - one_rr_per_rrset=one_rr_per_rrset) r.time = response_time if not q.is_response(r): raise BadResponse @@ -290,6 +345,52 @@ def _net_write(sock, data, expiration): current += sock.send(data[current:]) +def send_tcp(sock, what, expiration=None): + """Send a DNS message to the specified TCP socket. + + @param sock: the socket + @type sock: socket.socket + @param what: the message to send + @type what: wire format bytes or a dns.message.Message + @param expiration: The absolute time at which a timeout exception should + be raised. + @type expiration: float + @rtype: (int, double) tuple of bytes sent and the sent time. + """ + if isinstance(what, dns.message.Message): + what = what.to_wire() + l = len(what) + # copying the wire into tcpmsg is inefficient, but lets us + # avoid writev() or doing a short write that would get pushed + # onto the net + tcpmsg = struct.pack("!H", l) + what + _wait_for_writable(sock, expiration) + sent_time = time.time() + _net_write(sock, tcpmsg, expiration) + return (len(tcpmsg), sent_time) + +def receive_tcp(sock, expiration=None, one_rr_per_rrset=False, + keyring=None, request_mac=b''): + """Read a DNS message from a TCP socket. + + @param sock: the socket + @type sock: socket.socket + @param one_rr_per_rrset: Put each RR into its own RRset + @type one_rr_per_rrset: bool + @param keyring: the keyring to use for TSIG + @type keyring: keyring dict + @param request_mac: the MAC of the request (for TSIG) + @type request_mac: bytes + @rtype: dns.message.Message object + """ + ldata = _net_read(sock, 2, expiration) + (l,) = struct.unpack("!H", ldata) + wire = _net_read(sock, l, expiration) + received_time = time.time() + r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset) + return (r, received_time) + def _connect(s, address): try: s.connect(address) @@ -343,25 +444,15 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, if source is not None: s.bind(source) _connect(s, destination) - - l = len(wire) - - # copying the wire into tcpmsg is inefficient, but lets us - # avoid writev() or doing a short write that would get pushed - # onto the net - tcpmsg = struct.pack("!H", l) + wire - _net_write(s, tcpmsg, expiration) - ldata = _net_read(s, 2, expiration) - (l,) = struct.unpack("!H", ldata) - wire = _net_read(s, l, expiration) + send_tcp(s, wire, expiration) + (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset, + q.keyring, q.request_mac) finally: - if begin_time is None: + if begin_time is None or received_time is None: response_time = 0 else: - response_time = time.time() - begin_time + response_time = received_time - begin_time s.close() - r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac, - one_rr_per_rrset=one_rr_per_rrset) r.time = response_time if not q.is_response(r): raise BadResponse -- 2.47.3