]> git.ipfire.org Git - thirdparty/dhcpcd.git/commitdiff
DHCP: Rework checksuming so that the packet isn't touched.
authorRoy Marples <roy@marples.name>
Wed, 31 Jul 2019 08:39:58 +0000 (09:39 +0100)
committerRoy Marples <roy@marples.name>
Wed, 31 Jul 2019 08:39:58 +0000 (09:39 +0100)
Other than setting udp->uh_sum to zero which we need to do to
calculate the checksum.
Also, the UDP checksum needs to include a pseudo IP header
without options and mostly blank. Instead of changing the packet,
just checksum a blank object we've filled in with the needed
data from the given IP object and use this to start the UDP
checksum calculation with.

While here, improve the checksum function so it more matches the
in_cksum function as noted in RFC 1071 4.1 using 16 byte words.

src/dhcp.c

index d6ff1e2dbde33b37877ec207b72bfdd5e76131a4..5965b9a97cc0f082c25b30b49ba9c3b1af51a037 100644 (file)
@@ -1584,24 +1584,24 @@ eexit:
 }
 
 static uint16_t
-checksum(const void *data, size_t len)
+in_cksum(void *data, size_t len, uint32_t *isum)
 {
-       const uint8_t *addr = data;
-       uint32_t sum = 0;
+       const uint16_t *word = data;
+       uint32_t sum = isum != NULL ? *isum : 0;
 
-       while (len > 1) {
-               sum += (uint32_t)(addr[0] * 256 + addr[1]);
-               addr += 2;
-               len -= 2;
-       }
+       for (; len > 1; len -= sizeof(*word))
+               sum += *word++;
 
        if (len == 1)
-               sum += (uint32_t)(*addr * 256);
+               sum += *(const uint8_t *)word;
+
+       if (isum != NULL)
+               *isum = sum;
 
        sum = (sum >> 16) + (sum & 0xffff);
        sum += (sum >> 16);
 
-       return (uint16_t)~htons((uint16_t)sum);
+       return (uint16_t)~sum;
 }
 
 static struct bootp_pkt *
@@ -1639,14 +1639,14 @@ dhcp_makeudppacket(size_t *sz, const uint8_t *data, size_t length,
        udp->uh_dport = htons(BOOTPS);
        udp->uh_ulen = htons((uint16_t)(sizeof(*udp) + length));
        ip->ip_len = udp->uh_ulen;
-       udp->uh_sum = checksum(udpp, sizeof(*ip) +  sizeof(*udp) + length);
+       udp->uh_sum = in_cksum(udpp, sizeof(*ip) + sizeof(*udp) + length, NULL);
 
        ip->ip_v = IPVERSION;
        ip->ip_hl = sizeof(*ip) >> 2;
        ip->ip_id = (uint16_t)arc4random_uniform(UINT16_MAX);
        ip->ip_ttl = IPDEFTTL;
        ip->ip_len = htons((uint16_t)(sizeof(*ip) + sizeof(*udp) + length));
-       ip->ip_sum = checksum(ip, sizeof(*ip));
+       ip->ip_sum = in_cksum(ip, sizeof(*ip), NULL);
 
        *sz = sizeof(*ip) + sizeof(*udp) + length;
        return udpp;
@@ -3236,10 +3236,15 @@ valid_udp_packet(void *packet, size_t plen, struct in_addr *from,
        unsigned int flags)
 {
        struct ip *ip = packet;
-       char ip_hlv = *(char *)ip;
+       struct ip pseudo_ip = {
+               .ip_p = IPPROTO_UDP,
+               .ip_src = ip->ip_src,
+               .ip_dst = ip->ip_dst
+       };
        size_t ip_hlen;
        uint16_t ip_len, uh_sum;
        struct udphdr *udp;
+       uint32_t csum;
 
        if (plen < sizeof(*ip)) {
                if (from != NULL)
@@ -3252,13 +3257,13 @@ valid_udp_packet(void *packet, size_t plen, struct in_addr *from,
                from->s_addr = ip->ip_src.s_addr;
 
        ip_hlen = (size_t)ip->ip_hl * 4;
-       if (checksum(ip, ip_hlen) != 0) {
+       if (in_cksum(ip, ip_hlen, NULL) != 0) {
                errno = EINVAL;
                return -1;
        }
 
-       ip_len = ntohs(ip->ip_len);
        /* Check we have a payload */
+       ip_len = ntohs(ip->ip_len);
        if (ip_len <= ip_hlen + sizeof(*udp)) {
                errno = ERANGE;
                return -1;
@@ -3272,28 +3277,21 @@ valid_udp_packet(void *packet, size_t plen, struct in_addr *from,
        if (flags & BPF_PARTIALCSUM)
                return 0;
 
+       /* UDP checksum is based on a pseudo IP header alongside
+        * the UDP header and payload. */
        udp = (struct udphdr *)((char *)ip + ip_hlen);
        if (udp->uh_sum == 0)
                return 0;
-       uh_sum = udp->uh_sum;
 
-       /* This does scribble on the packet, but at this point
-        * we don't care to keep it. */
+       uh_sum = udp->uh_sum;
        udp->uh_sum = 0;
-       ip->ip_hl = 0;
-       ip->ip_v = 0;
-       ip->ip_tos = 0;
-       ip->ip_len = udp->uh_ulen;
-       ip->ip_id = 0;
-       ip->ip_off = 0;
-       ip->ip_ttl = 0;
-       ip->ip_sum = 0;
-       if (checksum(packet, ip_len) != uh_sum) {
+       pseudo_ip.ip_len = udp->uh_ulen;
+       csum = 0;
+       in_cksum(&pseudo_ip, sizeof(pseudo_ip), &csum);
+       if (in_cksum(udp, ntohs(udp->uh_ulen), &csum) != uh_sum) {
                errno = EINVAL;
                return -1;
        }
-       *(char *)ip = ip_hlv;
-       ip->ip_len = htons(ip_len);
 
        return 0;
 }