req->requested_name = name;
}
+static guint16
+rspamd_bind_to_random_port (int sock)
+{
+ union sa_union su;
+ socklen_t slen = sizeof (su);
+ guint16 ret = 0;
+ const int max_retries = 10;
+ int retries = 0;
+
+ if (getsockname (sock, &su.sa, &slen) != -1) {
+
+ while (retries < max_retries) {
+ ret = g_random_int_range (1024, G_MAXUINT16 - 1);
+ if (su.sa.sa_family == AF_INET) {
+ su.s4.sin_port = htons (ret);
+ }
+ else if (su.sa.sa_family == AF_INET6) {
+ su.s6.sin6_port = htons (ret);
+ }
+ if (bind (sock, &su.sa, slen) != -1) {
+ return ret;
+ }
+ retries ++;
+ }
+ }
+
+ return 0;
+}
+
static gint
send_dns_request (struct rspamd_dns_request *req)
{
gint r;
+ req->port = rspamd_bind_to_random_port (req->sock);
+ req->key = ((guint32)req->port) << 16 + req->id;
r = send (req->sock, req->packet, req->pos, 0);
if (r == -1) {
if (errno == EAGAIN) {
struct rspamd_dns_request *req = arg;
event_del (&req->timer_event);
- g_hash_table_remove (req->resolver->requests, &req->id);
+ g_hash_table_remove (req->resolver->requests, &req->key);
}
static guint8 *
static gboolean
dns_parse_reply (guint8 *in, gint r, struct rspamd_dns_resolver *resolver,
- struct rspamd_dns_request **req_out, struct rspamd_dns_reply **_rep)
+ guint16 port, struct rspamd_dns_request **req_out,
+ struct rspamd_dns_reply **_rep)
{
struct dns_header *header = (struct dns_header *)in;
struct rspamd_dns_request *req;
union rspamd_reply_element *elt;
guint8 *pos;
guint16 id;
+ guint32 key;
gint i, t;
/* First check header fields */
/* Now try to find corresponding request */
id = header->qid;
- if ((req = g_hash_table_lookup (resolver->requests, &id)) == NULL) {
+ key = ((guint32)port) << 16 + id;
+ if ((req = g_hash_table_lookup (resolver->requests, &key)) == NULL) {
/* No such requests found */
return FALSE;
}
gint r;
struct rspamd_dns_reply *rep;
guint8 in[UDP_PACKET_SIZE];
+ union sa_union su;
+ socklen_t slen = sizeof (su);
+ guint16 port = 0;
/* This function is called each time when we have data on one of server's sockets */
/* First read packet from socket */
- r = read (fd, in, sizeof (in));
+ r = recvfrom (fd, in, sizeof (in), 0, &su.sa, &slen);
if (r > (gint)(sizeof (struct dns_header) + sizeof (struct dns_query))) {
- if (dns_parse_reply (in, r, resolver, &req, &rep)) {
+ if (su.sa.sa_family == AF_INET) {
+ port = ntohs (su.s4.sin_port);
+ }
+ else if (su.sa.sa_family == AF_INET6) {
+ port = ntohs (su.s6.sin6_port);
+ }
+ if (dns_parse_reply (in, r, resolver, port, &req, &rep)) {
/* Decrease errors count */
if (rep->request->resolver->errors > 0) {
rep->request->resolver->errors --;
}
if (req->server->sock == -1) {
- req->server->sock = make_universal_socket (req->server->name, dns_port, SOCK_DGRAM, TRUE, FALSE, FALSE);
+ req->server->sock = make_universal_socket (req->server->name,
+ dns_port, SOCK_DGRAM, TRUE, FALSE, FALSE);
}
req->sock = req->server->sock;
evtimer_add (&req->timer_event, &req->tv);
/* Add request to hash table */
- g_hash_table_insert (req->resolver->requests, &req->id, req);
+ g_hash_table_insert (req->resolver->requests, &req->key, req);
register_async_event (req->session, (event_finalizer_t)dns_fin_cb, req, g_quark_from_static_string ("dns resolver"));
}
}
}
if (req->server->sock == -1) {
- req->server->sock = make_universal_socket (req->server->name, dns_port, SOCK_DGRAM, TRUE, FALSE, FALSE);
+ req->server->sock = make_universal_socket (req->server->name,
+ dns_port, SOCK_DGRAM, TRUE, FALSE, FALSE);
}
req->sock = req->server->sock;
evtimer_add (&req->timer_event, &req->tv);
/* Add request to hash table */
- while (g_hash_table_lookup (resolver->requests, &req->id)) {
+ while (g_hash_table_lookup (resolver->requests, &req->key)) {
/* Check for unique id */
header = (struct dns_header *)req->packet;
header->qid = dns_k_permutor_step (resolver->permutor);
req->id = header->qid;
}
- g_hash_table_insert (resolver->requests, &req->id, req);
+ g_hash_table_insert (resolver->requests, &req->key, req);
register_async_event (session, (event_finalizer_t)dns_fin_cb, req, g_quark_from_static_string ("dns resolver"));
}
else if (r == -1) {
static gboolean
dns_id_equal (gconstpointer v1, gconstpointer v2)
{
- return *((const guint16*) v1) == *((const guint16*) v2);
+ return *((const guint32*) v1) == *((const guint32*) v2);
}
static guint
dns_id_hash (gconstpointer v)
{
- return *(const guint16 *) v;
+ return *(const guint32 *) v;
}
/* Now init all servers */
for (i = 0; i < new->servers_num; i ++) {
serv = &new->servers[i];
- serv->sock = make_universal_socket (serv->name, dns_port, SOCK_DGRAM, TRUE, FALSE, FALSE);
+ serv->sock = make_universal_socket (serv->name, dns_port,
+ SOCK_DGRAM, TRUE, FALSE, FALSE);
if (serv->sock == -1) {
msg_warn ("cannot create socket to server %s", serv->name);
}