_checkConfigExpectedOutput = None
_verboseMode = False
_skipListeningOnCL = False
+ _backgroundThreads = {}
+ _UDPResponder = None
+ _TCPResponder = None
@classmethod
def startResponders(cls):
cls._dnsdist.kill()
cls._dnsdist.wait()
+ # tell the background threads to stop, if any
+ for backgroundThread in cls._backgroundThreads:
+ cls._backgroundThreads[backgroundThread] = False
+
@classmethod
def _ResponderIncrementCounter(cls):
if threading.currentThread().name in cls._responsesCounter:
@classmethod
def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, callback=None):
+ cls._backgroundThreads[threading.get_native_id()] = True
# trailingDataResponse=True means "ignore trailing data".
# Other values are either False (meaning "raise an exception")
# or are interpreted as a response RCODE for queries with trailing data.
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
sock.bind(("127.0.0.1", port))
+ sock.settimeout(1.0)
while True:
- data, addr = sock.recvfrom(4096)
+ try:
+ data, addr = sock.recvfrom(4096)
+ except socket.timeout:
+ if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
+ del cls._backgroundThreads[threading.get_native_id()]
+ break
+ else:
+ continue
+
forceRcode = None
try:
request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
if not wire:
continue
- sock.settimeout(2.0)
sock.sendto(wire, addr)
- sock.settimeout(None)
+
sock.close()
@classmethod
@classmethod
def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, multipleConnections=False, listeningAddr='127.0.0.1'):
+ cls._backgroundThreads[threading.get_native_id()] = True
# trailingDataResponse=True means "ignore trailing data".
# Other values are either False (meaning "raise an exception")
# or are interpreted as a response RCODE for queries with trailing data.
sys.exit(1)
sock.listen(100)
+ sock.settimeout(1.0)
if tlsContext:
sock = tlsContext.wrap_socket(sock, server_side=True)
continue
except ConnectionResetError:
continue
+ except socket.timeout:
+ if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
+ del cls._backgroundThreads[threading.get_native_id()]
+ break
+ else:
+ continue
conn.settimeout(5.0)
if multipleConnections:
@classmethod
def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, useProxyProtocol=False):
+ cls._backgroundThreads[threading.get_native_id()] = True
# trailingDataResponse=True means "ignore trailing data".
# Other values are either False (meaning "raise an exception")
# or are interpreted as a response RCODE for queries with trailing data.
sys.exit(1)
sock.listen(100)
+ sock.settimeout(1.0)
if tlsContext:
sock = tlsContext.wrap_socket(sock, server_side=True)
continue
except ConnectionResetError:
continue
+ except socket.timeout:
+ if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
+ del cls._backgroundThreads[threading.get_native_id()]
+ break
+ else:
+ continue
conn.settimeout(5.0)
thread = threading.Thread(name='DoH Connection Handler',
@classmethod
def BrokenTCPResponder(cls, port):
+ cls._backgroundThreads[threading.get_native_id()] = True
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
sys.exit(1)
sock.listen(100)
+ sock.settimeout(1.0)
while True:
- (conn, _) = sock.accept()
+ try:
+ (conn, _) = sock.accept()
+ except socket.timeout:
+ if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
+ del cls._backgroundThreads[threading.get_native_id()]
+ break
+ else:
+ continue
+
conn.settimeout(5.0)
data = conn.recv(2)
if not data: