]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
tests/integrity: pass server instance, rewritten testserver
authorMarek Vavruša <marek.vavrusa@nic.cz>
Sun, 18 Jan 2015 14:26:21 +0000 (15:26 +0100)
committerMarek Vavruša <marek.vavrusa@nic.cz>
Sun, 18 Jan 2015 20:10:43 +0000 (21:10 +0100)
tests/pydnstest/testserver.py
tests/test_integration.c
tests/test_integration.py

index a0c42df2f959e5b7ff28e888bc8a62072094d0c7..59d52a0d9469ac3c545a09967ef8b7e724a12322 100644 (file)
-import SocketServer, socket, threading, struct
+import select, socket, threading, struct, sys, os
 import dns.message
 
+def recv_message(stream):
+    """ Receive DNS/TCP message. """
+    wire_len = stream.recv(2)
+    if len(wire_len) != 2:
+       return None
+    wire_len = struct.unpack("!H", wire_len)[0]
+    return dns.message.from_wire(stream.recv(wire_len))
 
-class DNSHandler(SocketServer.BaseRequestHandler):
-    """ This handler returns prescripted or mirror DNS responses. """
-
-    def handle(self):
-        """ Handle incoming queries. """
-        wire_len = self.request.recv(2)
-        if len(wire_len) != 2:
-            return
-        wire_len = struct.unpack("!H", wire_len)[0]
-        query = dns.message.from_wire(self.request.recv(wire_len))
-
-        # Echo service if no scenario
-        response = dns.message.make_response(query)
-        if self.server.scenario is not None:
-            response = self.server.scenario.reply(query)
-        if response:
-            response = response.to_wire()
-            self.request.send(struct.pack('!H', len(response)) + response)
-
+def send_message(stream, message):
+    """ Send DNS/TCP message. """
+    message = message.to_wire()
+    stream.send(struct.pack('!H', len(message)) + message)
 
 class TestServer:
     """ This simulates TCP DNS server returning prescripted or mirror DNS responses. """
 
-    def __init__(self, scenario, host='127.0.0.1', port=0):
-        self.server = SocketServer.TCPServer((host, port), DNSHandler)
-        self.server.allow_reuse_address = True
-        self.server.scenario = scenario
+    def __init__(self, scenario, type = socket.AF_UNIX, address = '.test_server.sock', port = 0):
+        """ Initialize server instance. """
+        self.is_active = False
+        self.thread = None
+        self.sock = socket.socket(type, socket.SOCK_STREAM)
+        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        if type == socket.AF_UNIX:
+            if os.path.exists(address):
+                os.unlink(address)
+            self.sock.bind(address)
+        else:
+            self.sock.bind((address, port))
+        self.sock.listen(5)
+        self.sock_type = type
+        self.scenario = scenario
+
+    def __del__(self):
+        """ Cleanup after deletion. """
+        if self.is_active:
+            self.stop()
+
+    def handle(self, client):
+        """ Handle incoming queries. """
+        query = recv_message(client)
+        if query is None:
+            return False
+        response = dns.message.make_response(query)
+        if self.scenario is not None:
+            response = self.scenario.reply(query)
+        if response:
+            send_message(client, response)
+        return True
 
     def start(self):
         """ Asynchronous start, returns immediately. """
-        self.thread = threading.Thread(target=self.run)
+        if self.is_active is True:
+            raise Exception('server already started')
+        self.is_active = True
+        self.thread = threading.Thread(target = self.run)
         self.thread.start()
 
     def run(self):
         """ Synchronous start, waits until server closes or for an interrupt. """
-        self.server.serve_forever()
+        self.is_active = True
+        clients = [self.sock]
+        while self.is_active and len(clients):
+            to_read, _, to_error = select.select(clients, [], clients, 0.5)
+            for sock in to_read:
+                if sock == self.sock:
+                    clients.append(sock.accept()[0])
+                else:
+                    if not self.handle(sock):
+                        to_error.append(sock)
+            for sock in to_error:
+                clients.remove(sock)
+                sock.close()
 
     def stop(self):
         """ Stop socket server operation. """
-        self.server.shutdown()
+        self.is_active = False
+        if self.thread is not None:
+            print 'waiting to stop'
+            self.thread.join()
+            self.thread = None
+        if self.sock_type == socket.AF_UNIX:
+            address = self.sock.getsockname()
+            if os.path.exists(address):
+                os.remove(address)
 
     def client(self):
         """ Return connected client. """
-        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        sock.connect(self.server.server_address)
+        sock = socket.socket(self.sock_type, socket.SOCK_STREAM)
+        sock.connect(self.address())
         return sock
 
+    def address(self):
+        """ Return bound address. """
+        return self.sock.getsockname()
+
 
 if __name__ == '__main__':
 
-    server = TestServer(None)
-    print('mirror server running at %s' % str(server.server.server_address))
-    try:
-        server.run()
-    except KeyboardInterrupt:
-        pass
-    server.stop()
\ No newline at end of file
+    if '--test' in sys.argv:
+        server = TestServer(None)
+        client = server.client()
+        server.start()
+        try:
+            query = dns.message.make_query('.', dns.rdatatype.NS)
+            send_message(client, query)
+            answer = recv_message(client)
+            if answer is None:
+                raise Exception('no answer received')
+            if not query.is_response(answer):
+                raise Exception('not a mirror response')
+            print('[ OK ] testserver')
+        except Exception as e:
+            print('[FAIL] testserver %s' % str(e))
+        finally:
+            client.close()
+            server.stop()
+
+    else:
+        server = TestServer(None, socket.AF_INET, '127.0.0.1')
+        print('mirror server running at %s' % str(server.address()))
+        try:
+            server.run()
+        except KeyboardInterrupt:
+            pass
+        server.stop()
index 04b9225ba500f5eb541863ec18bd24bd72b62f3e..946473aa249f87ddc48e33858450802207235109 100644 (file)
 /*
  * Globals
  */
-mm_ctx_t global_mm;               /* Test memory context */
-struct kr_context global_context; /* Resolution context */
-const char *global_tmpdir = NULL; /* Temporary directory */
-struct timeval _mock_time;        /* Mocked system time */
-int _mock_fd;                     /* Mocked endpoint for recursive queries */
+static mm_ctx_t global_mm;               /* Test memory context */
+static struct kr_context global_context; /* Resolution context */
+static const char *global_tmpdir = NULL; /* Temporary directory */
+static struct timeval _mock_time;        /* Mocked system time */
+static PyObject *mock_server  = NULL;   /* Mocked endpoint for recursive queries */
 
 /*
  * PyModule implementation.
@@ -39,7 +39,7 @@ static PyObject* init(PyObject* self, PyObject* args)
 {
        /* Initialize mock variables */
        memset(&_mock_time, 0, sizeof(struct timeval));
-       _mock_fd = -1;
+       mock_server = NULL;
 
        /* Initialize resolution context */
        #define CACHE_SIZE 100*1024
@@ -62,7 +62,10 @@ static PyObject* deinit(PyObject* self, PyObject* args)
        kr_context_deinit(&global_context);
        test_tmpdir_remove(global_tmpdir);
        global_tmpdir = NULL;
-       _mock_fd = -1;
+       if (mock_server) {
+               Py_XDECREF(mock_server);
+               mock_server = NULL;
+       }
 
        return Py_BuildValue("");
 }
@@ -111,19 +114,19 @@ static PyObject* set_time(PyObject *self, PyObject *args)
        return Py_BuildValue("");
 }
 
-static PyObject* set_endpoint(PyObject *self, PyObject *args)
+static PyObject* set_server(PyObject *self, PyObject *args)
 {
-       PyObject *arg_socket = NULL;
-       if (!PyArg_ParseTuple(args, "O", &arg_socket)) {
+       /* Get client socket getter method. */
+       PyObject *arg_client = NULL;
+       if (!PyArg_ParseTuple(args, "O", &arg_client)) {
                return NULL;
        }
 
-       int fd = PyObject_AsFileDescriptor(arg_socket);
-       if (fd < 0) {
-               return NULL;
-       }
+       /* Swap the server implementation. */
+       Py_XINCREF(arg_client);
+       Py_XDECREF(mock_server);
+       mock_server = arg_client;
 
-       _mock_fd = fd;
        return Py_BuildValue("");
 }
 
@@ -132,7 +135,7 @@ static PyMethodDef module_methods[] = {
     {"deinit", deinit, METH_VARARGS, "Clean up resolution context."},
     {"resolve", resolve, METH_VARARGS, "Resolve query."},
     {"set_time", set_time, METH_VARARGS, "Set mock system time."},
-    {"set_endpoint", set_endpoint, METH_VARARGS, "Set endpoint for recursive queries."},
+    {"set_server", set_server, METH_VARARGS, "Set fake server object."},
     {NULL, NULL, 0, NULL}
 };
 
@@ -153,22 +156,21 @@ int __wrap_gettimeofday(struct timeval *tv, struct timezone *tz)
        return 0;
 }
 
-int net_unbound_socket(int type, const struct sockaddr_storage *ss)
+int udp_recv_msg(int fd, uint8_t *buf, size_t len, struct sockaddr *addr)
 {
-       char addr_str[SOCKADDR_STRLEN];
-       sockaddr_tostr(addr_str, sizeof(addr_str), ss);
-       fprintf(stderr, "%s (%d, %s)\n", __func__, type, addr_str);
-       return _mock_fd;
+       /* Force TCP, as we're tunelling. */
+       return tcp_recv_msg(fd, buf, len, NULL);
 }
 
-int net_bound_socket(int type, const struct sockaddr_storage *ss)
+
+int udp_send_msg(int fd, const uint8_t *msg, size_t msglen,
+                 const struct sockaddr *addr)
 {
-       char addr_str[SOCKADDR_STRLEN];
-       sockaddr_tostr(addr_str, sizeof(addr_str), ss);
-       fprintf(stderr, "%s (%d, %s)\n", __func__, type, addr_str);
-       return _mock_fd;
+       /* Force TCP, as we're tunelling. */
+       return tcp_send_msg(fd, msg, msglen);
 }
 
+
 int net_connected_socket(int type, const struct sockaddr_storage *dst_addr,
                          const struct sockaddr_storage *src_addr, unsigned flags)
 {
@@ -176,7 +178,15 @@ int net_connected_socket(int type, const struct sockaddr_storage *dst_addr,
        sockaddr_tostr(dst_addr_str, sizeof(dst_addr_str), dst_addr);
        sockaddr_tostr(src_addr_str, sizeof(src_addr_str), src_addr);
        fprintf(stderr, "%s (%d, %s, %s, %u)\n", __func__, type, dst_addr_str, src_addr_str, flags);
-       return _mock_fd;
+
+       PyObject *result = PyObject_CallMethod(mock_server, "client", "");
+       if (result == NULL) {
+               return -1;
+       }
+
+       int fd = dup(PyObject_AsFileDescriptor(result));
+       Py_DECREF(result);
+       return fd;
 }
 
 int net_is_connected(int fd)
index 0051a10a23ff626ec3f1a22ba6462e058195919b..eaa64a7ba4a2b9ff637c80bf39df78d502f56228 100755 (executable)
@@ -126,21 +126,17 @@ def play_object(path):
     server = testserver.TestServer(scenario)
     server.start()
     mock_ctx.init()
-    client = None
+    mock_ctx.set_server(server)
     try:
         if TEST_DEBUG > 0:
             print('--- server listening at %s ---' % str(server.server.server_address))
             print('--- scenario parsed, any key to continue ---')
             sys.stdin.readline()
-        client = server.client()
-        mock_ctx.set_endpoint(client)
         scenario.play(mock_ctx)
         print('%s OK' % os.path.basename(path))
     except Exception as e:
         print('%s %s' % (os.path.basename(path), str(e)))
     finally:
-        if client is not None:
-            client.close()
         server.stop()
         mock_ctx.deinit()