From: Marek VavruĊĦa Date: Sun, 18 Jan 2015 14:26:21 +0000 (+0100) Subject: tests/integrity: pass server instance, rewritten testserver X-Git-Tag: v1.0.0-beta1~363^2~8 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=bce64274e5a2dc9ab54f70d0cd215134ccdaf9b0;p=thirdparty%2Fknot-resolver.git tests/integrity: pass server instance, rewritten testserver --- diff --git a/tests/pydnstest/testserver.py b/tests/pydnstest/testserver.py index a0c42df2f..59d52a0d9 100644 --- a/tests/pydnstest/testserver.py +++ b/tests/pydnstest/testserver.py @@ -1,61 +1,128 @@ -import SocketServer, socket, threading, struct +import select, socket, threading, struct, sys, os import dns.message +def recv_message(stream): + """ Receive DNS/TCP message. """ + wire_len = stream.recv(2) + if len(wire_len) != 2: + return None + wire_len = struct.unpack("!H", wire_len)[0] + return dns.message.from_wire(stream.recv(wire_len)) -class DNSHandler(SocketServer.BaseRequestHandler): - """ This handler returns prescripted or mirror DNS responses. """ - - def handle(self): - """ Handle incoming queries. """ - wire_len = self.request.recv(2) - if len(wire_len) != 2: - return - wire_len = struct.unpack("!H", wire_len)[0] - query = dns.message.from_wire(self.request.recv(wire_len)) - - # Echo service if no scenario - response = dns.message.make_response(query) - if self.server.scenario is not None: - response = self.server.scenario.reply(query) - if response: - response = response.to_wire() - self.request.send(struct.pack('!H', len(response)) + response) - +def send_message(stream, message): + """ Send DNS/TCP message. """ + message = message.to_wire() + stream.send(struct.pack('!H', len(message)) + message) class TestServer: """ This simulates TCP DNS server returning prescripted or mirror DNS responses. """ - def __init__(self, scenario, host='127.0.0.1', port=0): - self.server = SocketServer.TCPServer((host, port), DNSHandler) - self.server.allow_reuse_address = True - self.server.scenario = scenario + def __init__(self, scenario, type = socket.AF_UNIX, address = '.test_server.sock', port = 0): + """ Initialize server instance. """ + self.is_active = False + self.thread = None + self.sock = socket.socket(type, socket.SOCK_STREAM) + self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if type == socket.AF_UNIX: + if os.path.exists(address): + os.unlink(address) + self.sock.bind(address) + else: + self.sock.bind((address, port)) + self.sock.listen(5) + self.sock_type = type + self.scenario = scenario + + def __del__(self): + """ Cleanup after deletion. """ + if self.is_active: + self.stop() + + def handle(self, client): + """ Handle incoming queries. """ + query = recv_message(client) + if query is None: + return False + response = dns.message.make_response(query) + if self.scenario is not None: + response = self.scenario.reply(query) + if response: + send_message(client, response) + return True def start(self): """ Asynchronous start, returns immediately. """ - self.thread = threading.Thread(target=self.run) + if self.is_active is True: + raise Exception('server already started') + self.is_active = True + self.thread = threading.Thread(target = self.run) self.thread.start() def run(self): """ Synchronous start, waits until server closes or for an interrupt. """ - self.server.serve_forever() + self.is_active = True + clients = [self.sock] + while self.is_active and len(clients): + to_read, _, to_error = select.select(clients, [], clients, 0.5) + for sock in to_read: + if sock == self.sock: + clients.append(sock.accept()[0]) + else: + if not self.handle(sock): + to_error.append(sock) + for sock in to_error: + clients.remove(sock) + sock.close() def stop(self): """ Stop socket server operation. """ - self.server.shutdown() + self.is_active = False + if self.thread is not None: + print 'waiting to stop' + self.thread.join() + self.thread = None + if self.sock_type == socket.AF_UNIX: + address = self.sock.getsockname() + if os.path.exists(address): + os.remove(address) def client(self): """ Return connected client. """ - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.connect(self.server.server_address) + sock = socket.socket(self.sock_type, socket.SOCK_STREAM) + sock.connect(self.address()) return sock + def address(self): + """ Return bound address. """ + return self.sock.getsockname() + if __name__ == '__main__': - server = TestServer(None) - print('mirror server running at %s' % str(server.server.server_address)) - try: - server.run() - except KeyboardInterrupt: - pass - server.stop() \ No newline at end of file + if '--test' in sys.argv: + server = TestServer(None) + client = server.client() + server.start() + try: + query = dns.message.make_query('.', dns.rdatatype.NS) + send_message(client, query) + answer = recv_message(client) + if answer is None: + raise Exception('no answer received') + if not query.is_response(answer): + raise Exception('not a mirror response') + print('[ OK ] testserver') + except Exception as e: + print('[FAIL] testserver %s' % str(e)) + finally: + client.close() + server.stop() + + else: + server = TestServer(None, socket.AF_INET, '127.0.0.1') + print('mirror server running at %s' % str(server.address())) + try: + server.run() + except KeyboardInterrupt: + pass + server.stop() diff --git a/tests/test_integration.c b/tests/test_integration.c index 04b9225ba..946473aa2 100644 --- a/tests/test_integration.c +++ b/tests/test_integration.c @@ -25,11 +25,11 @@ /* * Globals */ -mm_ctx_t global_mm; /* Test memory context */ -struct kr_context global_context; /* Resolution context */ -const char *global_tmpdir = NULL; /* Temporary directory */ -struct timeval _mock_time; /* Mocked system time */ -int _mock_fd; /* Mocked endpoint for recursive queries */ +static mm_ctx_t global_mm; /* Test memory context */ +static struct kr_context global_context; /* Resolution context */ +static const char *global_tmpdir = NULL; /* Temporary directory */ +static struct timeval _mock_time; /* Mocked system time */ +static PyObject *mock_server = NULL; /* Mocked endpoint for recursive queries */ /* * PyModule implementation. @@ -39,7 +39,7 @@ static PyObject* init(PyObject* self, PyObject* args) { /* Initialize mock variables */ memset(&_mock_time, 0, sizeof(struct timeval)); - _mock_fd = -1; + mock_server = NULL; /* Initialize resolution context */ #define CACHE_SIZE 100*1024 @@ -62,7 +62,10 @@ static PyObject* deinit(PyObject* self, PyObject* args) kr_context_deinit(&global_context); test_tmpdir_remove(global_tmpdir); global_tmpdir = NULL; - _mock_fd = -1; + if (mock_server) { + Py_XDECREF(mock_server); + mock_server = NULL; + } return Py_BuildValue(""); } @@ -111,19 +114,19 @@ static PyObject* set_time(PyObject *self, PyObject *args) return Py_BuildValue(""); } -static PyObject* set_endpoint(PyObject *self, PyObject *args) +static PyObject* set_server(PyObject *self, PyObject *args) { - PyObject *arg_socket = NULL; - if (!PyArg_ParseTuple(args, "O", &arg_socket)) { + /* Get client socket getter method. */ + PyObject *arg_client = NULL; + if (!PyArg_ParseTuple(args, "O", &arg_client)) { return NULL; } - int fd = PyObject_AsFileDescriptor(arg_socket); - if (fd < 0) { - return NULL; - } + /* Swap the server implementation. */ + Py_XINCREF(arg_client); + Py_XDECREF(mock_server); + mock_server = arg_client; - _mock_fd = fd; return Py_BuildValue(""); } @@ -132,7 +135,7 @@ static PyMethodDef module_methods[] = { {"deinit", deinit, METH_VARARGS, "Clean up resolution context."}, {"resolve", resolve, METH_VARARGS, "Resolve query."}, {"set_time", set_time, METH_VARARGS, "Set mock system time."}, - {"set_endpoint", set_endpoint, METH_VARARGS, "Set endpoint for recursive queries."}, + {"set_server", set_server, METH_VARARGS, "Set fake server object."}, {NULL, NULL, 0, NULL} }; @@ -153,22 +156,21 @@ int __wrap_gettimeofday(struct timeval *tv, struct timezone *tz) return 0; } -int net_unbound_socket(int type, const struct sockaddr_storage *ss) +int udp_recv_msg(int fd, uint8_t *buf, size_t len, struct sockaddr *addr) { - char addr_str[SOCKADDR_STRLEN]; - sockaddr_tostr(addr_str, sizeof(addr_str), ss); - fprintf(stderr, "%s (%d, %s)\n", __func__, type, addr_str); - return _mock_fd; + /* Force TCP, as we're tunelling. */ + return tcp_recv_msg(fd, buf, len, NULL); } -int net_bound_socket(int type, const struct sockaddr_storage *ss) + +int udp_send_msg(int fd, const uint8_t *msg, size_t msglen, + const struct sockaddr *addr) { - char addr_str[SOCKADDR_STRLEN]; - sockaddr_tostr(addr_str, sizeof(addr_str), ss); - fprintf(stderr, "%s (%d, %s)\n", __func__, type, addr_str); - return _mock_fd; + /* Force TCP, as we're tunelling. */ + return tcp_send_msg(fd, msg, msglen); } + int net_connected_socket(int type, const struct sockaddr_storage *dst_addr, const struct sockaddr_storage *src_addr, unsigned flags) { @@ -176,7 +178,15 @@ int net_connected_socket(int type, const struct sockaddr_storage *dst_addr, sockaddr_tostr(dst_addr_str, sizeof(dst_addr_str), dst_addr); sockaddr_tostr(src_addr_str, sizeof(src_addr_str), src_addr); fprintf(stderr, "%s (%d, %s, %s, %u)\n", __func__, type, dst_addr_str, src_addr_str, flags); - return _mock_fd; + + PyObject *result = PyObject_CallMethod(mock_server, "client", ""); + if (result == NULL) { + return -1; + } + + int fd = dup(PyObject_AsFileDescriptor(result)); + Py_DECREF(result); + return fd; } int net_is_connected(int fd) diff --git a/tests/test_integration.py b/tests/test_integration.py index 0051a10a2..eaa64a7ba 100755 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -126,21 +126,17 @@ def play_object(path): server = testserver.TestServer(scenario) server.start() mock_ctx.init() - client = None + mock_ctx.set_server(server) try: if TEST_DEBUG > 0: print('--- server listening at %s ---' % str(server.server.server_address)) print('--- scenario parsed, any key to continue ---') sys.stdin.readline() - client = server.client() - mock_ctx.set_endpoint(client) scenario.play(mock_ctx) print('%s OK' % os.path.basename(path)) except Exception as e: print('%s %s' % (os.path.basename(path), str(e))) finally: - if client is not None: - client.close() server.stop() mock_ctx.deinit()