import contextlib
import errno
import os
-import select
+import selectors
import socket
import struct
import time
else:
return (now, now + timeout)
-# This module can use either poll() or select() as the "polling backend".
-#
-# A backend function takes an fd, bools for readability, writablity, and
-# error detection, and a timeout.
-
-def _poll_for(fd, readable, writable, error, timeout):
- """Poll polling backend."""
-
- event_mask = 0
- if readable:
- event_mask |= select.POLLIN
- if writable:
- event_mask |= select.POLLOUT
- if error:
- event_mask |= select.POLLERR
-
- pollable = select.poll()
- pollable.register(fd, event_mask)
-
- if timeout:
- event_list = pollable.poll(timeout * 1000)
- else:
- event_list = pollable.poll()
-
- return bool(event_list)
-
-
-def _select_for(fd, readable, writable, error, timeout):
- """Select polling backend."""
-
- rset, wset, xset = [], [], []
-
- if readable:
- rset = [fd]
- if writable:
- wset = [fd]
- if error:
- xset = [fd]
-
- if timeout is None:
- (rcount, wcount, xcount) = select.select(rset, wset, xset)
- else:
- (rcount, wcount, xcount) = select.select(rset, wset, xset, timeout)
-
- return bool((rcount or wcount or xcount))
-
def _wait_for(fd, readable, writable, error, expiration):
- # Use the selected polling backend to wait for any of the specified
+ # Use the selected selector class to wait for any of the specified
# events. An "expiration" absolute time is converted into a relative
# timeout.
- done = False
- while not done:
- if expiration is None:
- timeout = None
- else:
- timeout = expiration - time.time()
- if timeout <= 0.0:
- raise dns.exception.Timeout
- try:
- if isinstance(fd, ssl.SSLSocket) and readable and fd.pending() > 0:
- return True
- if not _polling_backend(fd, readable, writable, error, timeout):
- raise dns.exception.Timeout
- except OSError as e: # pragma: no cover
- if e.args[0] != errno.EINTR:
- raise e
- done = True
+ if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0:
+ return True
+ sel = _selector_class()
+ events = 0
+ if readable:
+ events |= selectors.EVENT_READ
+ if writable:
+ events |= selectors.EVENT_WRITE
+ if events:
+ sel.register(fd, events)
+ if expiration is None:
+ timeout = None
+ else:
+ timeout = expiration - time.time()
+ if timeout <= 0.0:
+ raise dns.exception.Timeout
+ if not sel.select(timeout):
+ raise dns.exception.Timeout
-def _set_polling_backend(fn):
+def _set_selector_class(selector_class):
# Internal API. Do not use.
- global _polling_backend
+ global _selector_class
- _polling_backend = fn
+ _selector_class = selector_class
-if hasattr(select, 'poll'):
+if hasattr(selectors, 'PollSelector'):
# Prefer poll() on platforms that support it because it has no
# limits on the maximum value of a file descriptor (plus it will
# be more efficient for high values).
- _polling_backend = _poll_for
+ _selector_class = selectors.PollSelector
else:
- _polling_backend = _select_for # pragma: no cover
+ _selector_class = selectors.SelectSelector # pragma: no cover
def _wait_for_readable(s, expiration):
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from io import StringIO
-import select
+import selectors
import sys
import socket
import time
class PollingMonkeyPatchMixin(object):
def setUp(self):
- self.__native_polling_backend = dns.query._polling_backend
- dns.query._set_polling_backend(self.polling_backend())
+ self.__native_selector_class = dns.query._selector_class
+ dns.query._set_selector_class(self.selector_class())
unittest.TestCase.setUp(self)
def tearDown(self):
- dns.query._set_polling_backend(self.__native_polling_backend)
+ dns.query._set_selector_class(self.__native_selector_class)
unittest.TestCase.tearDown(self)
class SelectResolverTestCase(PollingMonkeyPatchMixin, LiveResolverTests, unittest.TestCase):
- def polling_backend(self):
- return dns.query._select_for
+ def selector_class(self):
+ return selectors.SelectSelector
-if hasattr(select, 'poll'):
+if hasattr(selectors, 'PollSelector'):
class PollResolverTestCase(PollingMonkeyPatchMixin, LiveResolverTests, unittest.TestCase):
- def polling_backend(self):
- return dns.query._poll_for
+ def selector_class(self):
+ return selectors.PollSelector
class NXDOMAINExceptionTestCase(unittest.TestCase):