-import SocketServer, socket, threading, struct
+import select, socket, threading, struct, sys, os
import dns.message
+def recv_message(stream):
+ """ Receive DNS/TCP message. """
+ wire_len = stream.recv(2)
+ if len(wire_len) != 2:
+ return None
+ wire_len = struct.unpack("!H", wire_len)[0]
+ return dns.message.from_wire(stream.recv(wire_len))
-class DNSHandler(SocketServer.BaseRequestHandler):
- """ This handler returns prescripted or mirror DNS responses. """
-
- def handle(self):
- """ Handle incoming queries. """
- wire_len = self.request.recv(2)
- if len(wire_len) != 2:
- return
- wire_len = struct.unpack("!H", wire_len)[0]
- query = dns.message.from_wire(self.request.recv(wire_len))
-
- # Echo service if no scenario
- response = dns.message.make_response(query)
- if self.server.scenario is not None:
- response = self.server.scenario.reply(query)
- if response:
- response = response.to_wire()
- self.request.send(struct.pack('!H', len(response)) + response)
-
+def send_message(stream, message):
+ """ Send DNS/TCP message. """
+ message = message.to_wire()
+ stream.send(struct.pack('!H', len(message)) + message)
class TestServer:
""" This simulates TCP DNS server returning prescripted or mirror DNS responses. """
- def __init__(self, scenario, host='127.0.0.1', port=0):
- self.server = SocketServer.TCPServer((host, port), DNSHandler)
- self.server.allow_reuse_address = True
- self.server.scenario = scenario
+ def __init__(self, scenario, type = socket.AF_UNIX, address = '.test_server.sock', port = 0):
+ """ Initialize server instance. """
+ self.is_active = False
+ self.thread = None
+ self.sock = socket.socket(type, socket.SOCK_STREAM)
+ self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ if type == socket.AF_UNIX:
+ if os.path.exists(address):
+ os.unlink(address)
+ self.sock.bind(address)
+ else:
+ self.sock.bind((address, port))
+ self.sock.listen(5)
+ self.sock_type = type
+ self.scenario = scenario
+
+ def __del__(self):
+ """ Cleanup after deletion. """
+ if self.is_active:
+ self.stop()
+
+ def handle(self, client):
+ """ Handle incoming queries. """
+ query = recv_message(client)
+ if query is None:
+ return False
+ response = dns.message.make_response(query)
+ if self.scenario is not None:
+ response = self.scenario.reply(query)
+ if response:
+ send_message(client, response)
+ return True
def start(self):
""" Asynchronous start, returns immediately. """
- self.thread = threading.Thread(target=self.run)
+ if self.is_active is True:
+ raise Exception('server already started')
+ self.is_active = True
+ self.thread = threading.Thread(target = self.run)
self.thread.start()
def run(self):
""" Synchronous start, waits until server closes or for an interrupt. """
- self.server.serve_forever()
+ self.is_active = True
+ clients = [self.sock]
+ while self.is_active and len(clients):
+ to_read, _, to_error = select.select(clients, [], clients, 0.5)
+ for sock in to_read:
+ if sock == self.sock:
+ clients.append(sock.accept()[0])
+ else:
+ if not self.handle(sock):
+ to_error.append(sock)
+ for sock in to_error:
+ clients.remove(sock)
+ sock.close()
def stop(self):
""" Stop socket server operation. """
- self.server.shutdown()
+ self.is_active = False
+ if self.thread is not None:
+ print 'waiting to stop'
+ self.thread.join()
+ self.thread = None
+ if self.sock_type == socket.AF_UNIX:
+ address = self.sock.getsockname()
+ if os.path.exists(address):
+ os.remove(address)
def client(self):
""" Return connected client. """
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- sock.connect(self.server.server_address)
+ sock = socket.socket(self.sock_type, socket.SOCK_STREAM)
+ sock.connect(self.address())
return sock
+ def address(self):
+ """ Return bound address. """
+ return self.sock.getsockname()
+
if __name__ == '__main__':
- server = TestServer(None)
- print('mirror server running at %s' % str(server.server.server_address))
- try:
- server.run()
- except KeyboardInterrupt:
- pass
- server.stop()
\ No newline at end of file
+ if '--test' in sys.argv:
+ server = TestServer(None)
+ client = server.client()
+ server.start()
+ try:
+ query = dns.message.make_query('.', dns.rdatatype.NS)
+ send_message(client, query)
+ answer = recv_message(client)
+ if answer is None:
+ raise Exception('no answer received')
+ if not query.is_response(answer):
+ raise Exception('not a mirror response')
+ print('[ OK ] testserver')
+ except Exception as e:
+ print('[FAIL] testserver %s' % str(e))
+ finally:
+ client.close()
+ server.stop()
+
+ else:
+ server = TestServer(None, socket.AF_INET, '127.0.0.1')
+ print('mirror server running at %s' % str(server.address()))
+ try:
+ server.run()
+ except KeyboardInterrupt:
+ pass
+ server.stop()
/*
* Globals
*/
-mm_ctx_t global_mm; /* Test memory context */
-struct kr_context global_context; /* Resolution context */
-const char *global_tmpdir = NULL; /* Temporary directory */
-struct timeval _mock_time; /* Mocked system time */
-int _mock_fd; /* Mocked endpoint for recursive queries */
+static mm_ctx_t global_mm; /* Test memory context */
+static struct kr_context global_context; /* Resolution context */
+static const char *global_tmpdir = NULL; /* Temporary directory */
+static struct timeval _mock_time; /* Mocked system time */
+static PyObject *mock_server = NULL; /* Mocked endpoint for recursive queries */
/*
* PyModule implementation.
{
/* Initialize mock variables */
memset(&_mock_time, 0, sizeof(struct timeval));
- _mock_fd = -1;
+ mock_server = NULL;
/* Initialize resolution context */
#define CACHE_SIZE 100*1024
kr_context_deinit(&global_context);
test_tmpdir_remove(global_tmpdir);
global_tmpdir = NULL;
- _mock_fd = -1;
+ if (mock_server) {
+ Py_XDECREF(mock_server);
+ mock_server = NULL;
+ }
return Py_BuildValue("");
}
return Py_BuildValue("");
}
-static PyObject* set_endpoint(PyObject *self, PyObject *args)
+static PyObject* set_server(PyObject *self, PyObject *args)
{
- PyObject *arg_socket = NULL;
- if (!PyArg_ParseTuple(args, "O", &arg_socket)) {
+ /* Get client socket getter method. */
+ PyObject *arg_client = NULL;
+ if (!PyArg_ParseTuple(args, "O", &arg_client)) {
return NULL;
}
- int fd = PyObject_AsFileDescriptor(arg_socket);
- if (fd < 0) {
- return NULL;
- }
+ /* Swap the server implementation. */
+ Py_XINCREF(arg_client);
+ Py_XDECREF(mock_server);
+ mock_server = arg_client;
- _mock_fd = fd;
return Py_BuildValue("");
}
{"deinit", deinit, METH_VARARGS, "Clean up resolution context."},
{"resolve", resolve, METH_VARARGS, "Resolve query."},
{"set_time", set_time, METH_VARARGS, "Set mock system time."},
- {"set_endpoint", set_endpoint, METH_VARARGS, "Set endpoint for recursive queries."},
+ {"set_server", set_server, METH_VARARGS, "Set fake server object."},
{NULL, NULL, 0, NULL}
};
return 0;
}
-int net_unbound_socket(int type, const struct sockaddr_storage *ss)
+int udp_recv_msg(int fd, uint8_t *buf, size_t len, struct sockaddr *addr)
{
- char addr_str[SOCKADDR_STRLEN];
- sockaddr_tostr(addr_str, sizeof(addr_str), ss);
- fprintf(stderr, "%s (%d, %s)\n", __func__, type, addr_str);
- return _mock_fd;
+ /* Force TCP, as we're tunelling. */
+ return tcp_recv_msg(fd, buf, len, NULL);
}
-int net_bound_socket(int type, const struct sockaddr_storage *ss)
+
+int udp_send_msg(int fd, const uint8_t *msg, size_t msglen,
+ const struct sockaddr *addr)
{
- char addr_str[SOCKADDR_STRLEN];
- sockaddr_tostr(addr_str, sizeof(addr_str), ss);
- fprintf(stderr, "%s (%d, %s)\n", __func__, type, addr_str);
- return _mock_fd;
+ /* Force TCP, as we're tunelling. */
+ return tcp_send_msg(fd, msg, msglen);
}
+
int net_connected_socket(int type, const struct sockaddr_storage *dst_addr,
const struct sockaddr_storage *src_addr, unsigned flags)
{
sockaddr_tostr(dst_addr_str, sizeof(dst_addr_str), dst_addr);
sockaddr_tostr(src_addr_str, sizeof(src_addr_str), src_addr);
fprintf(stderr, "%s (%d, %s, %s, %u)\n", __func__, type, dst_addr_str, src_addr_str, flags);
- return _mock_fd;
+
+ PyObject *result = PyObject_CallMethod(mock_server, "client", "");
+ if (result == NULL) {
+ return -1;
+ }
+
+ int fd = dup(PyObject_AsFileDescriptor(result));
+ Py_DECREF(result);
+ return fd;
}
int net_is_connected(int fd)