From: Brian Wellington Date: Fri, 17 Jul 2020 22:46:04 +0000 (-0700) Subject: Use the selectors module. X-Git-Tag: v2.1.0rc1~177^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=147924d0a433968c639f75630009eff8a872a4d3;p=thirdparty%2Fdnspython.git Use the selectors module. Previously, there was code to either use select.select or select.poll, depending on OS. This changes it to use the selectors module, using either SelectSelector or PollSelector, but sharing code otherwise. --- diff --git a/dns/query.py b/dns/query.py index 7df565d8..eb827715 100644 --- a/dns/query.py +++ b/dns/query.py @@ -20,7 +20,7 @@ import contextlib import errno import os -import select +import selectors import socket import struct import time @@ -94,91 +94,46 @@ def _compute_times(timeout): else: return (now, now + timeout) -# This module can use either poll() or select() as the "polling backend". -# -# A backend function takes an fd, bools for readability, writablity, and -# error detection, and a timeout. - -def _poll_for(fd, readable, writable, error, timeout): - """Poll polling backend.""" - - event_mask = 0 - if readable: - event_mask |= select.POLLIN - if writable: - event_mask |= select.POLLOUT - if error: - event_mask |= select.POLLERR - - pollable = select.poll() - pollable.register(fd, event_mask) - - if timeout: - event_list = pollable.poll(timeout * 1000) - else: - event_list = pollable.poll() - - return bool(event_list) - - -def _select_for(fd, readable, writable, error, timeout): - """Select polling backend.""" - - rset, wset, xset = [], [], [] - - if readable: - rset = [fd] - if writable: - wset = [fd] - if error: - xset = [fd] - - if timeout is None: - (rcount, wcount, xcount) = select.select(rset, wset, xset) - else: - (rcount, wcount, xcount) = select.select(rset, wset, xset, timeout) - - return bool((rcount or wcount or xcount)) - def _wait_for(fd, readable, writable, error, expiration): - # Use the selected polling backend to wait for any of the specified + # Use the selected selector class to wait for any of the specified # events. An "expiration" absolute time is converted into a relative # timeout. - done = False - while not done: - if expiration is None: - timeout = None - else: - timeout = expiration - time.time() - if timeout <= 0.0: - raise dns.exception.Timeout - try: - if isinstance(fd, ssl.SSLSocket) and readable and fd.pending() > 0: - return True - if not _polling_backend(fd, readable, writable, error, timeout): - raise dns.exception.Timeout - except OSError as e: # pragma: no cover - if e.args[0] != errno.EINTR: - raise e - done = True + if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0: + return True + sel = _selector_class() + events = 0 + if readable: + events |= selectors.EVENT_READ + if writable: + events |= selectors.EVENT_WRITE + if events: + sel.register(fd, events) + if expiration is None: + timeout = None + else: + timeout = expiration - time.time() + if timeout <= 0.0: + raise dns.exception.Timeout + if not sel.select(timeout): + raise dns.exception.Timeout -def _set_polling_backend(fn): +def _set_selector_class(selector_class): # Internal API. Do not use. - global _polling_backend + global _selector_class - _polling_backend = fn + _selector_class = selector_class -if hasattr(select, 'poll'): +if hasattr(selectors, 'PollSelector'): # Prefer poll() on platforms that support it because it has no # limits on the maximum value of a file descriptor (plus it will # be more efficient for high values). - _polling_backend = _poll_for + _selector_class = selectors.PollSelector else: - _polling_backend = _select_for # pragma: no cover + _selector_class = selectors.SelectSelector # pragma: no cover def _wait_for_readable(s, expiration): diff --git a/tests/test_query.py b/tests/test_query.py index 498128d2..a13833ef 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -540,17 +540,3 @@ class LowLevelWaitTests(unittest.TestCase): finally: l.close() r.close() - - def test_select_for(self): - # we test this explicitly in case _wait_for didn't test it (i.e. - # if the default polling backing is _poll_for) - try: - (l, r) = socket.socketpair() - # simple timeout - self.assertFalse(dns.query._select_for(l, False, False, False, - 0.05)) - # writable no timeout - self.assertTrue(dns.query._select_for(l, False, True, False, None)) - finally: - l.close() - r.close() diff --git a/tests/test_resolver.py b/tests/test_resolver.py index a6ab4737..cadf2245 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -16,7 +16,7 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. from io import StringIO -import select +import selectors import sys import socket import time @@ -477,26 +477,26 @@ class LiveResolverTests(unittest.TestCase): class PollingMonkeyPatchMixin(object): def setUp(self): - self.__native_polling_backend = dns.query._polling_backend - dns.query._set_polling_backend(self.polling_backend()) + self.__native_selector_class = dns.query._selector_class + dns.query._set_selector_class(self.selector_class()) unittest.TestCase.setUp(self) def tearDown(self): - dns.query._set_polling_backend(self.__native_polling_backend) + dns.query._set_selector_class(self.__native_selector_class) unittest.TestCase.tearDown(self) class SelectResolverTestCase(PollingMonkeyPatchMixin, LiveResolverTests, unittest.TestCase): - def polling_backend(self): - return dns.query._select_for + def selector_class(self): + return selectors.SelectSelector -if hasattr(select, 'poll'): +if hasattr(selectors, 'PollSelector'): class PollResolverTestCase(PollingMonkeyPatchMixin, LiveResolverTests, unittest.TestCase): - def polling_backend(self): - return dns.query._poll_for + def selector_class(self): + return selectors.PollSelector class NXDOMAINExceptionTestCase(unittest.TestCase):