along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-#include <Python.h>
+#include <pthread.h>
+#include <dlfcn.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <netdb.h>
+#include <fcntl.h>
+
#include <libknot/descriptor.h>
#include <libknot/packet/pkt.h>
#include <libknot/internal/net.h>
+#include <Python.h>
+
+
/*
* Globals
*/
-struct timeval g_mock_time; /* Mocked system time */
-PyObject *g_mock_server = NULL; /* Mocked endpoint for recursive queries */
-
#ifdef __APPLE__
-int gettimeofday(struct timeval *tv, void *tz)
+ #define MOCK__TZ_ARG void
+ #define MOCK__SOCKADDR_ARG struct sockaddr *restrict
+ #define MOCK__CONST_SOCKADDR_ARG const struct sockaddr *
+ #define MOCK__GET_SOCKADDR(arg) arg
+ #define errno_location __error()
#else
-int gettimeofday(struct timeval *tv, struct timezone *tz)
+ #define MOCK__TZ_ARG struct timezone
+ #define MOCK__SOCKADDR_ARG __SOCKADDR_ARG
+ #define MOCK__CONST_SOCKADDR_ARG __CONST_SOCKADDR_ARG
+ #define MOCK__GET_SOCKADDR(arg) arg.__sockaddr__
+ #define errno_location __errno_location()
#endif
+
+struct timeval g_mock_time; /* Mocked system time */
+PyObject *g_mock_server = NULL; /* Mocked endpoint for recursive queries */
+
+struct sockaddr_storage original_dst = { 0 };
+int original_dst_len = 0;
+int connected_fd = -1;
+
+int (*original_connect)(int __fd, MOCK__CONST_SOCKADDR_ARG __addr,
+ socklen_t __len) = NULL;
+
+ssize_t (*original_recvfrom) (int __fd, void *__restrict __buf, size_t __n,
+ int __flags, MOCK__SOCKADDR_ARG __addr,
+ socklen_t *__restrict __addr_len) = NULL;
+
+ssize_t (*original_recv) (int __fd, void *__buf,
+ size_t __n, int __flags) = NULL;
+
+int (*original_select) (int __nfds, fd_set *__restrict __readfds,
+ fd_set *__restrict __writefds,
+ fd_set *__restrict __exceptfds,
+ struct timeval *__restrict __timeout) = NULL;
+
+#define FIND_ORIGINAL(fname) \
+ if (original_##fname == NULL) \
+ { \
+ original_##fname = dlsym(RTLD_NEXT,#fname);\
+ assert(original_##fname);\
+ }
+
+int gettimeofday(struct timeval *tv, MOCK__TZ_ARG *tz)
{
memcpy(tv, &g_mock_time, sizeof(struct timeval));
return 0;
}
-int tcp_recv_msg(int fd, uint8_t *buf, size_t len, struct timeval *timeout)
+ssize_t recvfrom (int __fd, void *__restrict __buf, size_t __n,
+ int __flags, MOCK__SOCKADDR_ARG __addr,
+ socklen_t *__restrict __addr_len)
{
- /* Unlock GIL and attempt to receive message. */
- uint16_t msg_len = 0;
- int rcvd = 0;
- Py_BEGIN_ALLOW_THREADS
- rcvd = read(fd, (char *)&msg_len, sizeof(msg_len));
- if (rcvd == sizeof(msg_len)) {
- msg_len = htons(msg_len);
- rcvd = read(fd, buf, msg_len);
+ ssize_t ret;
+ struct sockaddr *addr = MOCK__GET_SOCKADDR(__addr);
+ FIND_ORIGINAL(recvfrom);
+ if (__fd == connected_fd) {
+ if ((__flags & MSG_DONTWAIT) == 0) {
+ Py_BEGIN_ALLOW_THREADS
+ ret = original_recvfrom( __fd,__buf,__n,__flags,__addr,__addr_len);
+ Py_END_ALLOW_THREADS
+ }
+ else
+ ret = original_recvfrom( __fd,__buf,__n,__flags,__addr,__addr_len);
+ if (addr != NULL && *__addr_len > 0) {
+ int len = original_dst_len;
+ if (len < *__addr_len)
+ len = *__addr_len;
+ memcpy(addr, &original_dst, len);
+ }
}
- Py_END_ALLOW_THREADS
- return rcvd;
+ else
+ ret = original_recvfrom( __fd,__buf,__n,__flags,__addr,__addr_len);
+ return ret;
}
-int udp_recv_msg(int fd, uint8_t *buf, size_t len, struct timeval *timeout)
+ssize_t recv (int __fd, void *__buf, size_t __n, int __flags)
{
- /* Tunnel via TCP. */
- return tcp_recv_msg(fd, buf, len, timeout);
+ ssize_t ret;
+ FIND_ORIGINAL(recv);
+ if (__fd == connected_fd) {
+ if ((__flags & MSG_DONTWAIT) == 0) {
+ Py_BEGIN_ALLOW_THREADS
+ ret = original_recv (__fd,__buf,__n,__flags);
+ Py_END_ALLOW_THREADS
+ }
+ else
+ ret = original_recv (__fd,__buf,__n,__flags);
+ }
+ else
+ ret = original_recv (__fd,__buf,__n,__flags);
+ return ret;
}
-
-int tcp_send_msg(int fd, const uint8_t *msg, size_t len, struct timeval *timeout)
+int select (int __nfds, fd_set *__restrict __readfds,
+ fd_set *__restrict __writefds,
+ fd_set *__restrict __exceptfds,
+ struct timeval *__restrict __timeout)
{
- /* Unlock GIL and attempt to send message over. */
- uint16_t msg_len = htons(len);
- int sent = 0;
- Py_BEGIN_ALLOW_THREADS
- sent = write(fd, (char *)&msg_len, sizeof(msg_len));
- if (sent == sizeof(msg_len)) {
- sent = write(fd, msg, len);
+ int ret;
+ FIND_ORIGINAL(select);
+ if (connected_fd != -1 && __nfds > connected_fd && (
+ (__readfds != NULL && FD_ISSET(connected_fd, __readfds)) ||
+ (__writefds != NULL && FD_ISSET(connected_fd, __writefds)) ||
+ (__exceptfds != NULL && FD_ISSET(connected_fd, __exceptfds))
+ ))
+ {
+ struct timeval _timeout = {0, 200 * 1000};
+ Py_BEGIN_ALLOW_THREADS
+ ret = original_select (__nfds,
+ __readfds,__writefds,__exceptfds,&_timeout);
+ Py_END_ALLOW_THREADS
}
- Py_END_ALLOW_THREADS
- return sent;
+ else
+ ret = original_select (__nfds,
+ __readfds,__writefds,__exceptfds,__timeout);
+ return ret;
}
-int udp_send_msg(int fd, const uint8_t *msg, size_t msglen,
- const struct sockaddr *addr)
+int connect(int __fd, MOCK__CONST_SOCKADDR_ARG __addr, socklen_t __len)
{
- /* Tunnel via TCP. */
- return tcp_send_msg(fd, msg, msglen, NULL);
-}
+ Dl_info dli = {0};
+ char *python_addr;
+ struct addrinfo hints;
+ struct addrinfo *info = NULL;
+ int ret, parse_ret, python_port = 0, flowinfo, scopeid, local_socktype;
+ socklen_t local_socktypelen = sizeof(int);
+ const struct sockaddr *dst_addr = MOCK__GET_SOCKADDR(__addr);
+ char right_caller[] = "net_connected_socket";
+ PyObject *result;
+ char addr_str[SOCKADDR_STRLEN];
+ char pport[32];
+ FIND_ORIGINAL(connect);
+ dladdr (__builtin_return_address (0), &dli);
+ if (!dli.dli_sname ||
+ (strncmp(right_caller,dli.dli_sname,strlen(right_caller)) != 0))
+ return original_connect (__fd, __addr, __len);
-int net_connected_socket(int type, const struct sockaddr_storage *dst_addr,
- const struct sockaddr_storage *src_addr, unsigned flags)
-{
- char addr_str[SOCKADDR_STRLEN];
- sockaddr_tostr(addr_str, sizeof(addr_str), dst_addr);
+ sockaddr_tostr(addr_str, SOCKADDR_STRLEN,
+ (const struct sockaddr_storage *)dst_addr);
- PyObject *result = PyObject_CallMethod(g_mock_server, "client", "s", addr_str);
- if (result == NULL) {
+ if (dst_addr->sa_family != AF_INET && dst_addr->sa_family != AF_INET6) {
+ errno = EINVAL;
+ return -1;
+ }
+
+ getsockopt(__fd, SOL_SOCKET, SO_TYPE,
+ &local_socktype, &local_socktypelen);
+
+ if (local_socktype == SOCK_DGRAM) {
+ result = PyObject_CallMethod(g_mock_server, "get_server_socket",
+ "si", addr_str, dst_addr->sa_family);
+ if (result == NULL) {
+ errno = ECONNABORTED;
+ return -1;
+ }
+ }
+ else {
+ errno = EINVAL;
return -1;
}
- /* Refcount decrement is going to close the fd, dup() it */
- int fd = dup(PyObject_AsFileDescriptor(result));
+ if (dst_addr->sa_family == AF_INET)
+ parse_ret = PyArg_ParseTuple(result, "si",
+ &python_addr, &python_port);
+ else
+ parse_ret = PyArg_ParseTuple(result, "siii",
+ &python_addr, &python_port, &flowinfo, &scopeid);
+
Py_DECREF(result);
- return fd;
-}
-int net_is_connected(int fd)
-{
- return true;
+ if (!parse_ret) {
+ errno = ECONNABORTED;
+ return -1;
+ }
+
+ memset(&hints, 0, sizeof hints);
+ hints.ai_family = dst_addr->sa_family;
+ hints.ai_socktype = SOCK_DGRAM;
+ hints.ai_flags = AI_PASSIVE;
+ hints.ai_protocol = IPPROTO_UDP;
+ sprintf (pport,"%i",python_port);
+ ret = getaddrinfo(python_addr,pport,&hints,&info);
+ if (ret) {
+ errno = ECONNABORTED;
+ return -1;
+ }
+
+ connected_fd = __fd;
+ ret = original_connect (__fd, info->ai_addr, info->ai_addrlen);
+ freeaddrinfo(info);
+ memcpy(&original_dst,dst_addr,__len);
+ original_dst_len = __len;
+ return ret;
}
+
-import select, socket, threading, struct, sys, os
+import threading
+import select, socket, struct, sys, os, time
import dns.message
import test
-def recv_message(stream):
- """ Receive DNS/TCP message. """
- wire_len = stream.recv(2)
- if len(wire_len) != 2:
- return None
- wire_len = struct.unpack("!H", wire_len)[0]
- return dns.message.from_wire(stream.recv(wire_len))
+# Test debugging
+TEST_DEBUG = 0
+if 'TEST_DEBUG' in os.environ:
+ TEST_DEBUG = int(os.environ['TEST_DEBUG'])
-def send_message(stream, message):
- """ Send DNS/TCP message. """
+g_lock = threading.Lock()
+def syn_message(*args):
+ g_lock.acquire()
+ print args
+ g_lock.release()
+
+def recvfrom_message(stream):
+ """ Receive DNS/UDP message. """
+ if TEST_DEBUG > 0:
+ syn_message("incoming data")
+ data, addr = stream.recvfrom(8000)
+ if TEST_DEBUG > 0:
+ syn_message("[Python] received", len(data), "bytes from", addr)
+ return dns.message.from_wire(data), addr
+
+def sendto_message(stream, message, addr):
+ """ Send DNS/UDP message. """
+ if TEST_DEBUG > 0:
+ syn_message("outgoing data")
message = message.to_wire()
- stream.send(struct.pack('!H', len(message)) + message)
+ stream.sendto(message, addr)
+ if TEST_DEBUG > 0:
+ syn_message("[Python] sent", len(message), "bytes to", addr)
+
+class SInfo:
+ def __init__(self,type,addr,port,client_addr):
+ self.type = type
+ self.addr = addr
+ self.port = port
+ self.client_addr = client_addr
+ self.thread = None
+ self.active = False
+ self.name = ''
class TestServer:
- """ This simulates TCP DNS server returning scripted or mirror DNS responses. """
+ """ This simulates UDP DNS server returning scripted or mirror DNS responses. """
- def __init__(self, scenario, type = socket.AF_UNIX, address = '.test_server.sock', port = 0):
+ def __init__(self, scenario):
""" Initialize server instance. """
- self.is_active = False
- self.thread = None
- self.client_address = None
- self.sock = socket.socket(type, socket.SOCK_STREAM)
- self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- if type == socket.AF_UNIX:
- if os.path.exists(address):
- os.unlink(address)
- self.sock.bind(address)
- else:
- self.sock.bind((address, port))
- self.sock.listen(5)
- self.sock_type = type
+ if TEST_DEBUG > 0:
+ print "Test Server initialization"
+ self.srv_socks = []
+ self.client_socks = []
+ self.active = False
self.scenario = scenario
def __del__(self):
""" Cleanup after deletion. """
- if self.is_active:
+ if TEST_DEBUG > 0:
+ print "Test Server cleanup"
+ if self.active is True:
self.stop()
- def handle(self, client):
+ def start(self):
+ """ Asynchronous start, returns immediately. """
+ if TEST_DEBUG > 0:
+ print "Test Server start"
+ if self.active is True:
+ raise Exception('server already started')
+ self.active = True
+ self.get_server_socket(None, socket.AF_INET)
+ self.get_server_socket(None, socket.AF_INET6)
+
+ def stop(self):
+ """ Stop socket server operation. """
+ if TEST_DEBUG > 0:
+ syn_message("Test Server stop")
+ self.active = False
+ for srv_sock in self.srv_socks:
+ if TEST_DEBUG > 0:
+ syn_message("closing socket", srv_sock.name)
+ srv_sock.active = False
+ srv_sock.thread.join()
+ for client_sock in self.client_socks:
+ if TEST_DEBUG > 0:
+ syn_message("closing client socket")
+ client_sock.close()
+ self.client_socks = []
+ self.srv_socks = []
+ self.scenario = None
+ if TEST_DEBUG > 0:
+ syn_message("server stopped")
+
+ def address(self):
+ addrlist = [];
+ for s in self.srv_socks:
+ addrlist.append(s.name);
+ return addrlist;
+
+ def handle_query(self, client, client_address):
""" Handle incoming queries. """
- query = recv_message(client)
+ query, addr = recvfrom_message(client)
+ if TEST_DEBUG > 0:
+ syn_message("incoming query from", addr, "client", client_address)
+ if TEST_DEBUG > 1:
+ syn_message("=========\n",query,"=========")
if query is None:
+ if TEST_DEBUG > 0:
+ syn_message("Empty query")
return False
response = dns.message.make_response(query)
if self.scenario is not None:
- response = self.scenario.reply(query, self.client_address)
+ if TEST_DEBUG > 0:
+ syn_message("get scenario reply")
+ response = self.scenario.reply(query, client_address)
if response:
- send_message(client, response)
+ if TEST_DEBUG > 0:
+ syn_message("sending answer")
+ if TEST_DEBUG > 1:
+ syn_message("=========\n",response,"=========")
+ sendto_message(client, response, addr)
return True
else:
+ if TEST_DEBUG > 0:
+ syn_message("response is NULL")
return False
- def start(self):
- """ Asynchronous start, returns immediately. """
- if self.is_active is True:
- raise Exception('server already started')
- self.is_active = True
- self.thread = threading.Thread(target = self.run)
- self.thread.start()
-
- def run(self):
- """ Synchronous start, waits until server closes or for an interrupt. """
- self.is_active = True
- clients = [self.sock]
- while self.is_active and len(clients):
- to_read, _, to_error = select.select(clients, [], clients, 0.1)
- for sock in to_read:
- if sock == self.sock:
- client_sock, _ = sock.accept()
- clients.append(client_sock)
- else:
- if not self.handle(sock):
- to_error.append(sock)
- for sock in to_error:
- clients.remove(sock)
- sock.close()
+ def query_io(self,srv_sock):
+ """ Main server process """
+ if TEST_DEBUG > 0:
+ syn_message("query_io starts")
+ if self.active is False:
+ raise Exception('Test server not active')
+ res = socket.getaddrinfo(srv_sock.addr,srv_sock.port,srv_sock.type,0,socket.IPPROTO_UDP)
+ serv_sock = socket.socket(srv_sock.type, socket.SOCK_DGRAM,socket.IPPROTO_UDP)
+ entry0 = res[0]
+ sockaddr = entry0[4]
+ serv_sock.bind(sockaddr)
+ serv_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ address = serv_sock.getsockname()
+ srv_sock.name = address
+ clients = [serv_sock]
+ srv_sock.active = True
+ if TEST_DEBUG > 0:
+ syn_message("UDP query handler type", srv_sock.type, "started at", address)
+ while srv_sock.active is True:
+ to_read, _, to_error = select.select(clients, [], clients, 0.1)
+ for sock in to_read:
+ self.handle_query(sock,srv_sock.client_addr)
+ for sock in to_error:
+ raise Exception('Socket IO error, exit')
+ serv_sock.close()
+ if TEST_DEBUG > 0:
+ syn_message("UDP query handler exit")
- def stop(self):
- """ Stop socket server operation. """
- self.is_active = False
- if self.thread is not None:
- self.thread.join()
- self.thread = None
- if self.sock_type == socket.AF_UNIX:
- address = self.sock.getsockname()
- if os.path.exists(address):
- os.remove(address)
-
- def client(self, dst_address = None):
- """ Return connected client. """
- if dst_address is not None:
- self.client_address = dst_address.split('@')[0]
- sock = socket.socket(self.sock_type, socket.SOCK_STREAM)
- sock.connect(self.sock.getsockname())
- return sock
- def address(self):
- """ Return bound address. """
- address = self.sock.getsockname()
- if self.sock_type == socket.AF_UNIX:
- address = (address, 0)
- return address
+ def get_server_socket(self, client_addr, type = socket.AF_INET, address = None, port = 0):
+ if TEST_DEBUG > 0:
+ syn_message("getting server socket type",type,client_addr)
+ if client_addr is not None:
+ client_addr = client_addr.split('@')[0]
+ if type == socket.AF_INET:
+ if address is None:
+ address = '127.0.0.1'
+ elif type == socket.AF_INET6:
+ if socket.has_ipv6 is not True:
+ raise Exception('IPV6 is no supported')
+ if address is None:
+ address = "::1"
+ else:
+ print "unsupported socket type", self.sock_type
+ raise Exception('unsupported socket type')
+ for srv_sock in self.srv_socks:
+ if srv_sock.type == type:
+ srv_sock.client_addr = client_addr
+ return srv_sock.name
+ srv_sock = SInfo(type,address,port,client_addr)
+ srv_sock.thread = threading.Thread(target=self.query_io, args=(srv_sock,))
+ srv_sock.thread.start()
+ while srv_sock.active is False:
+ continue
+ self.srv_socks.append(srv_sock)
+ if TEST_DEBUG > 0:
+ syn_message("socket started")
+ return srv_sock.name
+
+ def client(self, dst_addr = None):
+ """ Return connected client. """
+ if dst_addr is not None:
+ dst_addr = dst_addr.split('@')[0]
+ sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ sockname = self.get_server_socket(dst_addr,socket.AF_INET)
+ sock.connect(sockname)
+ self.client_socks.append(sock)
+ return sock, sockname
def test_sendrecv():
""" Module self-test code. """
server = TestServer(None)
- client = server.client()
server.start()
+ client, peer = server.client()
try:
query = dns.message.make_query('.', dns.rdatatype.NS)
- send_message(client, query)
- answer = recv_message(client)
+ client.send(query.to_wire())
+ answer, _ = recvfrom_message(client)
if answer is None:
raise Exception('no answer received')
if not query.is_response(answer):
raise Exception('not a mirror response')
finally:
- client.close()
server.stop()
+ client.close()
if __name__ == '__main__':
sys.exit(1)
# Mirror server
- server = TestServer(None, socket.AF_INET, '127.0.0.1')
- print('[==========] Mirror server running at %s' % str(server.address()))
+ server = TestServer(None)
+ server.start()
+ server.get_server_socket(None, socket.AF_INET)
+ print('[==========] Mirror server running at', server.address())
try:
- server.run()
+ while True:
+ time.sleep(0.5)
except KeyboardInterrupt:
print('[==========] Shutdown.')
pass