]> git.ipfire.org Git - thirdparty/mtr.git/commitdiff
allow setting local and remote port for UDP probing
authorrussor <russor@whatsapp.com>
Thu, 25 Feb 2016 22:31:58 +0000 (14:31 -0800)
committerrussor <russor@whatsapp.com>
Fri, 8 Apr 2016 23:59:27 +0000 (16:59 -0700)
This allows for testing over a single network path in a load balanced
network where flows are hashed on the ips and ports (relatively common), the
default behavior is to use the destination port as the sequence number when
doing UDP probes, so you see (more or less) an average of all the paths.

If the remote port is set, we need to use the checksum field to store
sequence and adjust the paylod so the checksum is valid; therefore we need
to have the source address, as it's part of the checksum calculation.  We
also need to increase the minimum packet size so we have at least 2 bytes of
payload.

mtr.8
mtr.c
net.c

diff --git a/mtr.8 b/mtr.8
index 06df650a476ae447f4e1859dbc51544f53697fb6..e225b3b427b3ac9a949a60ec92d736c946fab510 100644 (file)
--- a/mtr.8
+++ b/mtr.8
@@ -89,6 +89,9 @@ mtr \- a network diagnostic tool
 .BI \-P \ PORT\c
 ]
 [\c
+.BI \-L \ LOCALPORT\c
+]
+[\c
 .BI \-Z \ TIMEOUT\c
 ]
 [\c
@@ -353,7 +356,10 @@ Use TCP SYN packets instead of ICMP ECHO.
 is ignored, since SYN packets can not contain data.
 .TP
 .B \-P \fIPORT\fR, \fB\-\-port \fIPORT
-The target port number for TCP traces.
+The target port number for TCP/SCTP/UDP traces.
+.TP
+.B \-L \fILOCALPORT\fR, \fB\-\-localport \fILOCALPORT
+The source port number for UDP traces.
 .TP
 .B \-Z \fISECONDS\fR, \fB\-\-timeout \fISECONDS
 The number of seconds to keep the TCP socket open before giving up on
diff --git a/mtr.c b/mtr.c
index a66d4ad7beeadf23265b531b88d0e9c70bfc22c6..4fe62b3cee680526d5153bfc7d7dc2580913c65e 100644 (file)
--- a/mtr.c
+++ b/mtr.c
@@ -88,7 +88,8 @@ int  fstTTL = 1;                /* default start at first hop */
 /*int maxTTL = MaxHost-1;  */     /* max you can go is 255 hops */
 int   maxTTL = 30;              /* inline with traceroute */
                                 /* end ttl window stuff. */
-int remoteport = 80;            /* for TCP tracing */
+int remoteport = 0;            /* for TCP tracing */
+int localport = 0;             /* for UDP tracing */
 int tcp_timeout = 10 * 1000000;     /* for TCP tracing */
 
 
@@ -304,7 +305,8 @@ void parse_arg (int argc, char **argv)
     { "udp", 0, 0, 'u' },      /* UDP (default is ICMP) */
     { "tcp", 0, 0, 'T' },      /* TCP (default is ICMP) */
     { "sctp", 0, 0, 'S' },     /* SCTP (default is ICMP) */
-    { "port", 1, 0, 'P' },      /* target port number for TCP */
+    { "port", 1, 0, 'P' },      /* target port number for TCP/SCTP/UDP */
+    { "localport", 1, 0, 'L' }, /* source port number for UDP */
     { "timeout", 1, 0, 'Z' },   /* timeout for TCP sockets */
 #ifdef SO_MARK
     { "mark", 1, 0, 'M' },      /* use SO_MARK */
@@ -315,7 +317,7 @@ void parse_arg (int argc, char **argv)
   opt = 0;
   while(1) {
     opt = getopt_long(argc, argv,
-                     "hv46F:rwxtglCpnbo:y:zi:c:s:B:Q:ea:f:m:uTSP:Z:M:", long_options, NULL);
+                     "hv46F:rwxtglCpnbo:y:zi:c:s:B:Q:ea:f:m:uTSP:L:Z:M:", long_options, NULL);
     if(opt == -1)
       break;
 
@@ -446,6 +448,9 @@ void parse_arg (int argc, char **argv)
         fprintf(stderr, "-u , -T and -S are mutually exclusive.\n");
         exit(EXIT_FAILURE);
       }
+      if (!remoteport) {
+        remoteport = 80;
+      }
       mtrtype = IPPROTO_TCP;
       break;
     case 'S':
@@ -453,6 +458,9 @@ void parse_arg (int argc, char **argv)
         fprintf(stderr, "-u , -T and -S are mutually exclusive.\n");
         exit(EXIT_FAILURE);
       }
+      if (!remoteport) {
+        remoteport = 80;
+      }
       mtrtype = IPPROTO_SCTP;
     case 'b':
       show_ips = 1;
@@ -464,6 +472,13 @@ void parse_arg (int argc, char **argv)
         exit(EXIT_FAILURE);
       }
       break;
+    case 'L':
+      localport = atoi(optarg);
+      if (localport > 65535 || localport < MinPort) {
+        fprintf(stderr, "Illegal local port number.\n");
+        exit(EXIT_FAILURE);
+      }
+      break;
     case 'Z':
       tcp_timeout = atoi(optarg);
       tcp_timeout *= 1000000;
@@ -616,12 +631,17 @@ int main(int argc, char **argv)
               "\t\t[-i INTERVAL] [-c COUNT] [-s PACKETSIZE] [-B BITPATTERN]\n"
               "\t\t[-Q TOS] [--mpls]\n"
               "\t\t[-a ADDRESS] [-f FIRST-TTL] [-m MAX-TTL]\n"
-              "\t\t[--udp] [--tcp] [--sctp] [-P PORT] [-Z TIMEOUT]\n"
+              "\t\t[--udp] [--tcp] [--sctp] [-P PORT] [-L LOCALPORT] [-Z TIMEOUT]\n"
               "\t\t[-M MARK] HOSTNAME\n", argv[0]);
        printf("See the man page for details.\n");
     exit(0);
   }
 
+  if (mtrtype == IPPROTO_UDP && remoteport && !InterfaceAddress) {
+    fprintf(stderr, "mtr: -a ADDRESS required in udp mode (-u) with remote port set (-P).\n");
+    exit(EXIT_FAILURE);
+  }
+
   time_t now = time(NULL);
 
   if (!names) append_to_names (argv[0], "localhost"); // default: localhost. 
@@ -720,6 +740,7 @@ int main(int argc, char **argv)
       }
     }
 
+
     lock(argv[0], stdout);
       display_open();
       dns_open();
diff --git a/net.c b/net.c
index e5f9322c7dddb60573a29b73e301f3bfc3b01575..48068250865c9b9842a767e10a06bceacb131333 100644 (file)
--- a/net.c
+++ b/net.c
@@ -216,6 +216,7 @@ extern int tos;                     /* type of service set in ping packet*/
 extern int af;                 /* address family of remote target */
 extern int mtrtype;            /* type of query packet used */
 extern int remoteport;          /* target port for TCP tracing */
+extern int localport;  /* source port for UDP tracing */
 extern int tcp_timeout;             /* timeout for TCP connections */
 #ifdef SO_MARK
 extern int mark;               /* SO_MARK to set for ping packet*/
@@ -251,11 +252,14 @@ int checksum(void *data, int sz)
 
 
 /* Prepend pseudoheader to the udp datagram and calculate checksum */
-int udp_checksum(void *pheader, void *udata, int psize, int dsize)
+int udp_checksum(void *pheader, void *udata, int psize, int dsize, int alt_checksum)
 {
   unsigned int tsize = psize + dsize;
   char csumpacket[tsize];
   memset(csumpacket, (unsigned char) abs(bitpattern), abs(tsize));
+  if (alt_checksum && dsize >= 2) {
+    csumpacket[psize] = csumpacket[psize + 1] = 0;
+  }
 
   struct UDPv4PHeader *prepend = (struct UDPv4PHeader *) csumpacket;
   struct UDPv4PHeader *udppheader = (struct UDPv4PHeader *) pheader;
@@ -559,6 +563,7 @@ void net_send_query(int index)
   struct ICMPHeader *icmp = NULL;
   struct UDPHeader *udp = NULL;
   struct UDPv4PHeader *udpp = NULL;
+  uint16 checksum_result;
   uint16 mypid;
 
   /*ok  int packetsize = sizeof(struct IPHeader) + sizeof(struct ICMPHeader) + datasize;*/
@@ -575,6 +580,9 @@ void net_send_query(int index)
 
   if ( packetsize < MINPACKET ) packetsize = MINPACKET;
   if ( packetsize > MAXPACKET ) packetsize = MAXPACKET;
+  if ( mtrtype == IPPROTO_UDP && remoteport && packetsize < (MINPACKET + 2)) {
+    packetsize = MINPACKET + 2;
+  }
 
   memset(packet, (unsigned char) abs(bitpattern), abs(packetsize));
 
@@ -646,16 +654,27 @@ void net_send_query(int index)
   case IPPROTO_UDP:
     udp = (struct UDPHeader *)(packet + iphsize);
     udp->checksum  = 0;
-    mypid = (uint16)getpid();
-    if (mypid < MinPort)
-      mypid += MinPort;
-
+    if (!localport) {
+      mypid = (uint16)getpid();
+      if (mypid < MinPort)
+        mypid += MinPort;
+    } else {
+      mypid = (uint16)localport;
+    }
     udp->srcport = htons(mypid);
     udp->length = htons(abs(packetsize) - iphsize);
 
-    udp->dstport = new_sequence(index);
-    gettimeofday(&sequence[udp->dstport].time, NULL);
-    udp->dstport = htons(udp->dstport);
+    if (!remoteport) {
+      udp->dstport = new_sequence(index);
+      gettimeofday(&sequence[udp->dstport].time, NULL);
+      udp->dstport = htons(udp->dstport);
+    } else {
+      // keep dstport constant, stuff sequence into the checksum
+      udp->dstport = htons(remoteport);
+      udp->checksum = new_sequence(index);
+      gettimeofday(&sequence[udp->checksum].time, NULL);
+      udp->checksum = htons(udp->checksum);
+    }
     break;
   }
 
@@ -664,13 +683,22 @@ void net_send_query(int index)
     switch ( mtrtype ) {
     case IPPROTO_UDP:
       /* checksum is not mandatory. only calculate if we know ip->saddr */
-      if (ip->saddr) {
+      if (udp->checksum) {
         udpp = (struct UDPv4PHeader *)(malloc(sizeof(struct UDPv4PHeader)));
         udpp->saddr = ip->saddr;
         udpp->daddr = ip->daddr;
         udpp->protocol = ip->protocol;
         udpp->len = udp->length;
-        udp->checksum = udp_checksum(udpp, udp, sizeof(struct UDPv4PHeader), abs(packetsize) - iphsize);
+        checksum_result = udp_checksum(udpp, udp, sizeof(struct UDPv4PHeader), abs(packetsize) - iphsize, 1);
+        packet[iphsize + sizeof(struct UDPHeader)] = checksum_result & 0xff;
+        packet[iphsize + sizeof(struct UDPHeader) + 1] = checksum_result >> 8;
+      } else if (ip->saddr) {
+        udpp = (struct UDPv4PHeader *)(malloc(sizeof(struct UDPv4PHeader)));
+        udpp->saddr = ip->saddr;
+        udpp->daddr = ip->daddr;
+        udpp->protocol = ip->protocol;
+        udpp->len = udp->length;
+        udp->checksum = udp_checksum(udpp, udp, sizeof(struct UDPv4PHeader), abs(packetsize) - iphsize, 0);
       }
       break;
     }
@@ -682,6 +710,9 @@ void net_send_query(int index)
     switch ( mtrtype ) {
     case IPPROTO_UDP:
       /* kernel checksum calculation */
+      if (udp->checksum) {
+        offset = sizeof(struct UDPHeader);
+      }
       if ( setsockopt(sendsock, IPPROTO_IPV6, IPV6_CHECKSUM, &offset, sizeof(offset)) ) {
         perror( "setsockopt IPV6_CHECKSUM" );
         exit( EXIT_FAILURE);
@@ -967,7 +998,11 @@ void net_process_return(void)
         break;
 #endif
       }
-      sequence = ntohs(udpheader->dstport);
+      if (remoteport && remoteport == ntohs(udpheader->dstport)) {
+        sequence = ntohs(udpheader->checksum);
+      } else if (!remoteport) {
+        sequence = ntohs(udpheader->dstport);
+      }
     }
     break;