]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
fix recvfrom, tls timing, and other misc things
authorBob Halley <halley@dnspython.org>
Fri, 12 Jun 2020 20:49:27 +0000 (13:49 -0700)
committerBob Halley <halley@dnspython.org>
Fri, 12 Jun 2020 20:49:27 +0000 (13:49 -0700)
dns/_asyncio_backend.py
dns/_curio_backend.py
dns/_trio_backend.py
dns/asyncquery.py

index 5f14c4a38092e3dba570e8ef6f8db122509ac836..f82eb823916363091f3d53e7f206e6a7934494ce 100644 (file)
@@ -61,7 +61,8 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
         # no timeout for asyncio sendto
         self.transport.sendto(what, destination)
 
-    async def recvfrom(self, timeout):
+    async def recvfrom(self, size, timeout):
+        # ignore size as there's no way I know to tell protocol about it
         done = _get_running_loop().create_future()
         assert self.protocol.recvfrom is None
         self.protocol.recvfrom = done
index 5f6877edcd9215278862003b1f12dc8997578966..699276d344bf0827d340b1e41e87561f73cf3e99 100644 (file)
@@ -31,9 +31,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
             return await self.socket.sendto(what, destination)
         raise dns.exception.Timeout(timeout=timeout)
 
-    async def recvfrom(self, timeout):
+    async def recvfrom(self, size, timeout):
         async with _maybe_timeout(timeout):
-            return await self.socket.recvfrom(65535)
+            return await self.socket.recvfrom(size)
         raise dns.exception.Timeout(timeout=timeout)
 
     async def close(self):
@@ -53,9 +53,9 @@ class StreamSocket(dns._asyncbackend.DatagramSocket):
             return await self.socket.sendall(what)
         raise dns.exception.Timeout(timeout=timeout)
 
-    async def recv(self, count, timeout):
+    async def recv(self, size, timeout):
         async with _maybe_timeout(timeout):
-            return await self.socket.recv(count)
+            return await self.socket.recv(size)
         raise dns.exception.Timeout(timeout=timeout)
 
     async def close(self):
index d6a9387300958b05a5ee96dc89ba5737e24251fd..049151117532923bc33713189863acf45d394b4a 100644 (file)
@@ -31,9 +31,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
             return await self.socket.sendto(what, destination)
         raise dns.exception.Timeout(timeout=timeout)
 
-    async def recvfrom(self, timeout):
+    async def recvfrom(self, size, timeout):
         with _maybe_timeout(timeout):
-            return await self.socket.recvfrom(65535)
+            return await self.socket.recvfrom(size)
         raise dns.exception.Timeout(timeout=timeout)
 
     async def close(self):
@@ -54,9 +54,9 @@ class StreamSocket(dns._asyncbackend.DatagramSocket):
             return await self.stream.send_all(what)
         raise dns.exception.Timeout(timeout=timeout)
 
-    async def recv(self, count, timeout):
+    async def recv(self, size, timeout):
         with _maybe_timeout(timeout):
-            return await self.stream.receive_some(count)
+            return await self.stream.receive_some(size)
         raise dns.exception.Timeout(timeout=timeout)
 
     async def close(self):
@@ -94,7 +94,6 @@ class Backend(dns._asyncbackend.Backend):
             s = None
             tls = False
             if ssl_context:
-                print('TLS')
                 tls = True
                 try:
                     stream = trio.SSLStream(stream, ssl_context,
@@ -102,7 +101,6 @@ class Backend(dns._asyncbackend.Backend):
                 except Exception:
                     await stream.aclose()
                     raise
-            print('SOCKET')
             return StreamSocket(af, stream, tls)
         raise NotImplementedError(f'unsupported socket type {socktype}')
 
index 47a4ff0610c7a62d1ab91d7f71c92ac933ae6d95..d1c17933c97b5a3a48d8b34cbc0eb4b9ad8cb9d4 100644 (file)
@@ -126,7 +126,7 @@ async def receive_udp(sock, destination, expiration=None,
 
     wire = b''
     while 1:
-        (wire, from_address) = await sock.recvfrom(65535)
+        (wire, from_address) = await sock.recvfrom(65535, _timeout(expiration))
         if _addresses_equal(sock.family, from_address, destination) or \
            (dns.inet.is_multicast(destination[0]) and
             from_address[1:] == destination[1:]):
@@ -179,7 +179,7 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
     *sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
     the socket to use for the query.  If ``None``, the default, a
     socket is created.  Note that if a socket is provided, the
-    *source* and *source_port* are ignored.
+    *source*, *source_port*, and *backend* are ignored.
 
     *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
     the default, then dnspython will use the default backend.
@@ -248,13 +248,13 @@ async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
 
     *udp_sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
     the socket to use for the UDP query.  If ``None``, the default, a
-    socket is created.  Note that if a socket is provided the *source*
-    and *source_port* are ignored for the UDP query.
+    socket is created.  Note that if a socket is provided the *source*,
+    *source_port*, and *backend* are ignored for the UDP query.
 
     *tcp_sock*, a ``dns.asyncbackend.StreamSocket``, or ``None``, the
     socket to use for the TCP query.  If ``None``, the default, a
     socket is created.  Note that if a socket is provided *where*,
-    *source* and *source_port* are ignored for the TCP query.
+    *source*, *source_port*, and *backend*  are ignored for the TCP query.
 
     *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
     the default, then dnspython will use the default backend.
@@ -380,7 +380,7 @@ async def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
     *sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the
     socket to use for the query.  If ``None``, the default, a socket
     is created.  Note that if a socket is provided
-    *where*, *port*, *source* and *source_port* are ignored.
+    *where*, *port*, *source*, *source_port*, and *backend* are ignored.
 
     *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
     the default, then dnspython will use the default backend.
@@ -452,7 +452,8 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
     to use for the query.  If ``None``, the default, a socket is
     created.  Note that if a socket is provided, it must be a
     connected SSL stream socket, and *where*, *port*,
-    *source*, *source_port*, and *ssl_context* are ignored.
+    *source*, *source_port*, *backend*, *ssl_context*, and *server_hostname*
+    are ignored.
 
     *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
     the default, then dnspython will use the default backend.
@@ -469,6 +470,7 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
     """
     if not backend:
         backend = dns.asyncbackend.get_default_backend()
+    (begin_time, expiration) = _compute_times(timeout)
     if not sock:
         if ssl_context is None:
             ssl_context = ssl.create_default_context()
@@ -489,8 +491,12 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
         #
         # If a socket was provided, there's no special TLS handling needed.
         #
-        return await tcp(q, where, timeout, port, source, source_port,
-                         one_rr_per_rrset, ignore_trailing, s, backend)
+        timeout = _timeout(expiration)
+        response = await tcp(q, where, timeout, port, source, source_port,
+                             one_rr_per_rrset, ignore_trailing, s, backend)
+        end_time = time.time()
+        response.time = end_time - begin_time
+        return response
     finally:
         if not sock and s:
             await s.close()