-# Copyright (C) Internet Systems Consortium, Inc. ("ISC")
-#
-# SPDX-License-Identifier: MPL-2.0
-#
-# This Source Code Form is subject to the terms of the Mozilla Public
-# License, v. 2.0. If a copy of the MPL was not distributed with this
-# file, you can obtain one at https://mozilla.org/MPL/2.0/.
-#
-# See the COPYRIGHT file distributed with this work for additional
-# information regarding copyright ownership.
+"""
+Copyright (C) Internet Systems Consortium, Inc. ("ISC")
-############################################################################
-#
-# This tool acts as a TCP/UDP proxy and delays all incoming packets by 500
-# milliseconds.
-#
-# We use it to check pipelining - a client sents 8 questions over a
-# pipelined connection - that require asking a normal (examplea) and a
-# slow-responding (exampleb) servers:
-# a.examplea
-# a.exampleb
-# b.examplea
-# b.exampleb
-# c.examplea
-# c.exampleb
-# d.examplea
-# d.exampleb
-#
-# If pipelining works properly the answers will be returned out of order
-# with all answers from examplea returned first, and then all answers
-# from exampleb.
-#
-############################################################################
+SPDX-License-Identifier: MPL-2.0
-from __future__ import print_function
+This Source Code Form is subject to the terms of the Mozilla Public
+License, v. 2.0. If a copy of the MPL was not distributed with this
+file, you can obtain one at https://mozilla.org/MPL/2.0/.
-import datetime
-import os
-import select
-import signal
-import socket
-import sys
-import time
-import threading
-import struct
+See the COPYRIGHT file distributed with this work for additional
+information regarding copyright ownership.
+"""
-DELAY = 0.5
-THREADS = []
+from isctest.asyncserver import AsyncDnsServer, ForwarderHandler
-def log(msg):
- print(datetime.datetime.now().strftime("%d-%b-%Y %H:%M:%S.%f ") + msg)
+class ForwardToNs2(ForwarderHandler):
+ target = "10.53.0.2"
+ delay = 0.5
-def sigterm(*_):
- log("SIGTERM received, shutting down")
- for thread in THREADS:
- thread.close()
- thread.join()
- os.remove("ans.pid")
- sys.exit(0)
-
-
-class TCPDelayer(threading.Thread):
- """For a given TCP connection conn we open a connection to (ip, port),
- and then we delay each incoming packet by DELAY by putting it in a
- queue.
- In the pipelined test TCP should not be used, but it's here for
- completnes.
- """
-
- def __init__(self, conn, ip, port):
- threading.Thread.__init__(self)
- self.conn = conn
- self.cconn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- self.cconn.connect((ip, port))
- self.queue = []
- self.running = True
-
- def close(self):
- self.running = False
-
- def run(self):
- while self.running:
- curr_timeout = 0.5
- try:
- curr_timeout = self.queue[0][0] - time.monotonic()
- except StopIteration:
- pass
- if curr_timeout > 0:
- if curr_timeout == 0:
- curr_timeout = 0.5
- rfds, _, _ = select.select(
- [self.conn, self.cconn], [], [], curr_timeout
- )
- if self.conn in rfds:
- data = self.conn.recv(65535)
- if not data:
- return
- self.queue.append((time.monotonic() + DELAY, data))
- if self.cconn in rfds:
- data = self.cconn.recv(65535)
- if not data == 0:
- return
- self.conn.send(data)
- try:
- while self.queue[0][0] - time.monotonic() < 0:
- _, data = self.queue.pop(0)
- self.cconn.send(data)
- except StopIteration:
- pass
-
-
-class UDPDelayer(threading.Thread):
- """Every incoming UDP packet is put in a queue for DELAY time, then
- it's sent to (ip, port). We remember the query id to send the
- response we get to a proper source, responses are not delayed.
- """
-
- def __init__(self, usock, ip, port):
- threading.Thread.__init__(self)
- self.sock = usock
- self.csock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
- self.dst = (ip, port)
- self.queue = []
- self.qid_mapping = {}
- self.running = True
-
- def close(self):
- self.running = False
-
- def run(self):
- while self.running:
- curr_timeout = 0.5
- if self.queue:
- curr_timeout = self.queue[0][0] - time.monotonic()
- if curr_timeout >= 0:
- if curr_timeout == 0:
- curr_timeout = 0.5
- rfds, _, _ = select.select(
- [self.sock, self.csock], [], [], curr_timeout
- )
- if self.sock in rfds:
- data, addr = self.sock.recvfrom(65535)
- if not data:
- return
- self.queue.append((time.monotonic() + DELAY, data))
- qid = struct.unpack(">H", data[:2])[0]
- log("Received a query from %s, queryid %d" % (str(addr), qid))
- self.qid_mapping[qid] = addr
- if self.csock in rfds:
- data, addr = self.csock.recvfrom(65535)
- if not data:
- return
- qid = struct.unpack(">H", data[:2])[0]
- dst = self.qid_mapping.get(qid)
- if dst is not None:
- self.sock.sendto(data, dst)
- log(
- "Received a response from %s, queryid %d, sending to %s"
- % (str(addr), qid, str(dst))
- )
- while self.queue and self.queue[0][0] - time.monotonic() < 0:
- _, data = self.queue.pop(0)
- qid = struct.unpack(">H", data[:2])[0]
- log("Sending a query to %s, queryid %d" % (str(self.dst), qid))
- self.csock.sendto(data, self.dst)
-
-
-def main():
- signal.signal(signal.SIGTERM, sigterm)
- signal.signal(signal.SIGINT, sigterm)
-
- with open("ans.pid", "w") as pidfile:
- print(os.getpid(), file=pidfile)
-
- listenip = "10.53.0.5"
- serverip = "10.53.0.2"
-
- try:
- port = int(os.environ["PORT"])
- except KeyError:
- port = 5300
-
- log("Listening on %s:%d" % (listenip, port))
-
- usock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
- usock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- usock.bind((listenip, port))
- thread = UDPDelayer(usock, serverip, port)
- thread.start()
- THREADS.append(thread)
-
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- sock.bind((listenip, port))
- sock.listen(1)
- sock.settimeout(1)
-
- while True:
- try:
- clientsock, _ = sock.accept()
- log("Accepted connection from %s" % clientsock)
- thread = TCPDelayer(clientsock, serverip, port)
- thread.start()
- THREADS.append(thread)
- except socket.timeout:
- pass
+def main() -> None:
+ server = AsyncDnsServer()
+ server.install_response_handlers(ForwardToNs2())
+ server.run()
if __name__ == "__main__":