]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Use the selectors module. 538/head
authorBrian Wellington <bwelling@xbill.org>
Fri, 17 Jul 2020 22:46:04 +0000 (15:46 -0700)
committerBrian Wellington <bwelling@xbill.org>
Fri, 17 Jul 2020 22:46:04 +0000 (15:46 -0700)
Previously, there was code to either use select.select or select.poll,
depending on OS.  This changes it to use the selectors module, using
either SelectSelector or PollSelector, but sharing code otherwise.

dns/query.py
tests/test_query.py
tests/test_resolver.py

index 7df565d851f2c07aeb8babdbe5195353c8d529a8..eb82771564ce0a1f564700862945e570b5b2c846 100644 (file)
@@ -20,7 +20,7 @@
 import contextlib
 import errno
 import os
-import select
+import selectors
 import socket
 import struct
 import time
@@ -94,91 +94,46 @@ def _compute_times(timeout):
     else:
         return (now, now + timeout)
 
-# This module can use either poll() or select() as the "polling backend".
-#
-# A backend function takes an fd, bools for readability, writablity, and
-# error detection, and a timeout.
-
-def _poll_for(fd, readable, writable, error, timeout):
-    """Poll polling backend."""
-
-    event_mask = 0
-    if readable:
-        event_mask |= select.POLLIN
-    if writable:
-        event_mask |= select.POLLOUT
-    if error:
-        event_mask |= select.POLLERR
-
-    pollable = select.poll()
-    pollable.register(fd, event_mask)
-
-    if timeout:
-        event_list = pollable.poll(timeout * 1000)
-    else:
-        event_list = pollable.poll()
-
-    return bool(event_list)
-
-
-def _select_for(fd, readable, writable, error, timeout):
-    """Select polling backend."""
-
-    rset, wset, xset = [], [], []
-
-    if readable:
-        rset = [fd]
-    if writable:
-        wset = [fd]
-    if error:
-        xset = [fd]
-
-    if timeout is None:
-        (rcount, wcount, xcount) = select.select(rset, wset, xset)
-    else:
-        (rcount, wcount, xcount) = select.select(rset, wset, xset, timeout)
-
-    return bool((rcount or wcount or xcount))
-
 
 def _wait_for(fd, readable, writable, error, expiration):
-    # Use the selected polling backend to wait for any of the specified
+    # Use the selected selector class to wait for any of the specified
     # events.  An "expiration" absolute time is converted into a relative
     # timeout.
 
-    done = False
-    while not done:
-        if expiration is None:
-            timeout = None
-        else:
-            timeout = expiration - time.time()
-            if timeout <= 0.0:
-                raise dns.exception.Timeout
-        try:
-            if isinstance(fd, ssl.SSLSocket) and readable and fd.pending() > 0:
-                return True
-            if not _polling_backend(fd, readable, writable, error, timeout):
-                raise dns.exception.Timeout
-        except OSError as e:  # pragma: no cover
-            if e.args[0] != errno.EINTR:
-                raise e
-        done = True
+    if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0:
+        return True
+    sel = _selector_class()
+    events = 0
+    if readable:
+        events |= selectors.EVENT_READ
+    if writable:
+        events |= selectors.EVENT_WRITE
+    if events:
+        sel.register(fd, events)
+    if expiration is None:
+        timeout = None
+    else:
+        timeout = expiration - time.time()
+        if timeout <= 0.0:
+            raise dns.exception.Timeout
+    if not sel.select(timeout):
+        raise dns.exception.Timeout
 
 
-def _set_polling_backend(fn):
+def _set_selector_class(selector_class):
     # Internal API. Do not use.
 
-    global _polling_backend
+    global _selector_class
 
-    _polling_backend = fn
+    _selector_class = selector_class
 
-if hasattr(select, 'poll'):
+if hasattr(selectors, 'PollSelector'):
     # Prefer poll() on platforms that support it because it has no
     # limits on the maximum value of a file descriptor (plus it will
     # be more efficient for high values).
-    _polling_backend = _poll_for
+    _selector_class = selectors.PollSelector
 else:
-    _polling_backend = _select_for  # pragma: no cover
+    _selector_class = selectors.SelectSelector # pragma: no cover
 
 
 def _wait_for_readable(s, expiration):
index 498128d2f9bb016a319c1b64e4ba6185a8cc73af..a13833ef12140e3143af1dde4007555dfe925db7 100644 (file)
@@ -540,17 +540,3 @@ class LowLevelWaitTests(unittest.TestCase):
         finally:
             l.close()
             r.close()
-
-    def test_select_for(self):
-        # we test this explicitly in case _wait_for didn't test it (i.e.
-        # if the default polling backing is _poll_for)
-        try:
-            (l, r) = socket.socketpair()
-            # simple timeout
-            self.assertFalse(dns.query._select_for(l, False, False, False,
-                                                   0.05))
-            # writable no timeout
-            self.assertTrue(dns.query._select_for(l, False, True, False, None))
-        finally:
-            l.close()
-            r.close()
index a6ab473762fc5f261f51a5642a893a0519298e42..cadf2245c1f0592baf5ee983641b3e3dd7769607 100644 (file)
@@ -16,7 +16,7 @@
 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
 from io import StringIO
-import select
+import selectors
 import sys
 import socket
 import time
@@ -477,26 +477,26 @@ class LiveResolverTests(unittest.TestCase):
 
 class PollingMonkeyPatchMixin(object):
     def setUp(self):
-        self.__native_polling_backend = dns.query._polling_backend
-        dns.query._set_polling_backend(self.polling_backend())
+        self.__native_selector_class = dns.query._selector_class
+        dns.query._set_selector_class(self.selector_class())
 
         unittest.TestCase.setUp(self)
 
     def tearDown(self):
-        dns.query._set_polling_backend(self.__native_polling_backend)
+        dns.query._set_selector_class(self.__native_selector_class)
 
         unittest.TestCase.tearDown(self)
 
 
 class SelectResolverTestCase(PollingMonkeyPatchMixin, LiveResolverTests, unittest.TestCase):
-    def polling_backend(self):
-        return dns.query._select_for
+    def selector_class(self):
+        return selectors.SelectSelector
 
 
-if hasattr(select, 'poll'):
+if hasattr(selectors, 'PollSelector'):
     class PollResolverTestCase(PollingMonkeyPatchMixin, LiveResolverTests, unittest.TestCase):
-        def polling_backend(self):
-            return dns.query._poll_for
+        def selector_class(self):
+            return selectors.PollSelector
 
 
 class NXDOMAINExceptionTestCase(unittest.TestCase):