]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Factor out core send and receive functionalty from
authorBob Halley <halley@dnspython.org>
Fri, 30 Sep 2016 22:20:08 +0000 (15:20 -0700)
committerBob Halley <halley@dnspython.org>
Fri, 30 Sep 2016 22:20:08 +0000 (15:20 -0700)
dns.query.udp() and dns.query.tcp(), helping
applications that want more control over the
socket.

ChangeLog
dns/query.py

index 91351887a514f0bc00c7267024f1abd793e2aae3..16f44d058a55f137e521e9e48e3e672b30bf5669 100644 (file)
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,7 +1,19 @@
-2016-09-29  Bob Halley  <halley@dnspython.org>
+2016-09-30  Bob Halley  <halley@dnspython.org>
+
+       * Much of the internals of dns.query.udp() and dns.query.tcp()
+         have been factored out into dns.query.send_udp(),
+         dns.query.receive_udp(), dns.query.send_tcp(), and
+         dns.query.receive_tcp().  Applications which want more control
+         over the socket may find the new routines helpful; for example
+         it would be easy to send multiple queries over a single TCP
+         connection.
+
+2016-09-30  Bob Halley  <halley@dnspython.org>
 
        * (Version 1.15.0 released)
 
+2016-09-29  Bob Halley  <halley@dnspython.org>
+
        * IDNA 2008 support is now available if the "idna" module has been
          installed and IDNA 2008 is requested.  The default IDNA behavior
          is still IDNA 2003.  The new IDNA codec mechanism is currently
index bfecd43e50eb5723329cad28b073d80688087ed9..2ae1d89c93da457064d2ddda65cea9f0078ce53d 100644 (file)
@@ -193,6 +193,72 @@ def _destination_and_source(af, where, port, source, source_port):
     return (af, destination, source)
 
 
+def send_udp(sock, what, destination, expiration=None):
+    """Send a DNS message to the specified UDP socket.
+
+    @param sock: the socket
+    @type sock: socket.socket
+    @param what: the message to send
+    @type what: wire format bytes or a dns.message.Message
+    @param destination: where to send the query
+    @type destination: tuple appropriate for the address family of the socket
+    @param expiration: The absolute time at which a timeout exception should
+    be raised.
+    @type expiration: float
+    @rtype: (int, double) tuple of bytes sent and the sent time.
+    """
+    if isinstance(what, dns.message.Message):
+        what = what.to_wire()
+    _wait_for_writable(sock, expiration)
+    sent_time = time.time()
+    n = sock.sendto(what, destination)
+    return (n, sent_time)
+
+
+def receive_udp(sock, destination, expiration=None, af=None,
+                ignore_unexpected=False, one_rr_per_rrset=False,
+                keyring=None, request_mac=b''):
+    """Read a DNS message from a UDP socket.
+
+    @param sock: the socket
+    @type sock: socket.socket
+    @param af: the address family to use.  The default is None, which
+    causes the address family to use to be inferred from the form of
+    destination.  If the inference attempt fails, AF_INET is used.
+    @type af: int
+    @param ignore_unexpected: If True, ignore responses from unexpected
+    sources.  The default is False.
+    @type ignore_unexpected: bool
+    @param one_rr_per_rrset: Put each RR into its own RRset
+    @type one_rr_per_rrset: bool
+    @param keyring: the keyring to use for TSIG
+    @type keyring: keyring dict
+    @param request_mac: the MAC of the request (for TSIG)
+    @type request_mac: bytes
+    @rtype: dns.message.Message object
+    """
+    if af is None:
+        try:
+            af = dns.inet.af_for_address(destination[0])
+        except Exception:
+            af = dns.inet.AF_INET
+    wire = b''
+    while 1:
+        _wait_for_readable(sock, expiration)
+        (wire, from_address) = sock.recvfrom(65535)
+        if _addresses_equal(af, from_address, destination) or \
+           (dns.inet.is_multicast(destination[0]) and
+            from_address[1:] == destination[1:]):
+            break
+        if not ignore_unexpected:
+            raise UnexpectedSource('got a response from '
+                                   '%s instead of %s' % (from_address,
+                                                         destination))
+    received_time = time.time()
+    r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
+                              one_rr_per_rrset=one_rr_per_rrset)
+    return (r, received_time)
+
 def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
         ignore_unexpected=False, one_rr_per_rrset=False):
     """Return the response obtained after sending a query via UDP.
@@ -210,7 +276,6 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
     causes the address family to use to be inferred from the form of where.
     If the inference attempt fails, AF_INET is used.
     @type af: int
-    @rtype: dns.message.Message object
     @param source: source address.  The default is the wildcard address.
     @type source: string
     @param source_port: The port from which to send the message.
@@ -221,40 +286,30 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
     @type ignore_unexpected: bool
     @param one_rr_per_rrset: Put each RR into its own RRset
     @type one_rr_per_rrset: bool
+    @rtype: dns.message.Message object
     """
 
     wire = q.to_wire()
     (af, destination, source) = _destination_and_source(af, where, port,
                                                         source, source_port)
     s = socket_factory(af, socket.SOCK_DGRAM, 0)
-    begin_time = None
+    received_time = None
+    sent_time = None
     try:
         expiration = _compute_expiration(timeout)
         s.setblocking(0)
         if source is not None:
             s.bind(source)
-        _wait_for_writable(s, expiration)
-        begin_time = time.time()
-        s.sendto(wire, destination)
-        while 1:
-            _wait_for_readable(s, expiration)
-            (wire, from_address) = s.recvfrom(65535)
-            if _addresses_equal(af, from_address, destination) or \
-                    (dns.inet.is_multicast(where) and
-                     from_address[1:] == destination[1:]):
-                break
-            if not ignore_unexpected:
-                raise UnexpectedSource('got a response from '
-                                       '%s instead of %s' % (from_address,
-                                                             destination))
+        (_, sent_time) = send_udp(s, wire, destination, expiration)
+        (r, received_time) = receive_udp(s, destination, expiration, af,
+                                         ignore_unexpected, one_rr_per_rrset,
+                                         q.keyring, q.request_mac)
     finally:
-        if begin_time is None:
+        if sent_time is None or received_time is None:
             response_time = 0
         else:
-            response_time = time.time() - begin_time
+            response_time = received_time - sent_time
         s.close()
-    r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac,
-                              one_rr_per_rrset=one_rr_per_rrset)
     r.time = response_time
     if not q.is_response(r):
         raise BadResponse
@@ -290,6 +345,52 @@ def _net_write(sock, data, expiration):
         current += sock.send(data[current:])
 
 
+def send_tcp(sock, what, expiration=None):
+    """Send a DNS message to the specified TCP socket.
+
+    @param sock: the socket
+    @type sock: socket.socket
+    @param what: the message to send
+    @type what: wire format bytes or a dns.message.Message
+    @param expiration: The absolute time at which a timeout exception should
+    be raised.
+    @type expiration: float
+    @rtype: (int, double) tuple of bytes sent and the sent time.
+    """
+    if isinstance(what, dns.message.Message):
+        what = what.to_wire()
+    l = len(what)
+    # copying the wire into tcpmsg is inefficient, but lets us
+    # avoid writev() or doing a short write that would get pushed
+    # onto the net
+    tcpmsg = struct.pack("!H", l) + what
+    _wait_for_writable(sock, expiration)
+    sent_time = time.time()
+    _net_write(sock, tcpmsg, expiration)
+    return (len(tcpmsg), sent_time)
+
+def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
+                keyring=None, request_mac=b''):
+    """Read a DNS message from a TCP socket.
+
+    @param sock: the socket
+    @type sock: socket.socket
+    @param one_rr_per_rrset: Put each RR into its own RRset
+    @type one_rr_per_rrset: bool
+    @param keyring: the keyring to use for TSIG
+    @type keyring: keyring dict
+    @param request_mac: the MAC of the request (for TSIG)
+    @type request_mac: bytes
+    @rtype: dns.message.Message object
+    """
+    ldata = _net_read(sock, 2, expiration)
+    (l,) = struct.unpack("!H", ldata)
+    wire = _net_read(sock, l, expiration)
+    received_time = time.time()
+    r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
+                              one_rr_per_rrset=one_rr_per_rrset)
+    return (r, received_time)
+
 def _connect(s, address):
     try:
         s.connect(address)
@@ -343,25 +444,15 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
         if source is not None:
             s.bind(source)
         _connect(s, destination)
-
-        l = len(wire)
-
-        # copying the wire into tcpmsg is inefficient, but lets us
-        # avoid writev() or doing a short write that would get pushed
-        # onto the net
-        tcpmsg = struct.pack("!H", l) + wire
-        _net_write(s, tcpmsg, expiration)
-        ldata = _net_read(s, 2, expiration)
-        (l,) = struct.unpack("!H", ldata)
-        wire = _net_read(s, l, expiration)
+        send_tcp(s, wire, expiration)
+        (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset,
+                                         q.keyring, q.request_mac)
     finally:
-        if begin_time is None:
+        if begin_time is None or received_time is None:
             response_time = 0
         else:
-            response_time = time.time() - begin_time
+            response_time = received_time - begin_time
         s.close()
-    r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac,
-                              one_rr_per_rrset=one_rr_per_rrset)
     r.time = response_time
     if not q.is_response(r):
         raise BadResponse