}
static uint16_t
-checksum(const void *data, size_t len)
+in_cksum(void *data, size_t len, uint32_t *isum)
{
- const uint8_t *addr = data;
- uint32_t sum = 0;
+ const uint16_t *word = data;
+ uint32_t sum = isum != NULL ? *isum : 0;
- while (len > 1) {
- sum += (uint32_t)(addr[0] * 256 + addr[1]);
- addr += 2;
- len -= 2;
- }
+ for (; len > 1; len -= sizeof(*word))
+ sum += *word++;
if (len == 1)
- sum += (uint32_t)(*addr * 256);
+ sum += *(const uint8_t *)word;
+
+ if (isum != NULL)
+ *isum = sum;
sum = (sum >> 16) + (sum & 0xffff);
sum += (sum >> 16);
- return (uint16_t)~htons((uint16_t)sum);
+ return (uint16_t)~sum;
}
static struct bootp_pkt *
udp->uh_dport = htons(BOOTPS);
udp->uh_ulen = htons((uint16_t)(sizeof(*udp) + length));
ip->ip_len = udp->uh_ulen;
- udp->uh_sum = checksum(udpp, sizeof(*ip) + sizeof(*udp) + length);
+ udp->uh_sum = in_cksum(udpp, sizeof(*ip) + sizeof(*udp) + length, NULL);
ip->ip_v = IPVERSION;
ip->ip_hl = sizeof(*ip) >> 2;
ip->ip_id = (uint16_t)arc4random_uniform(UINT16_MAX);
ip->ip_ttl = IPDEFTTL;
ip->ip_len = htons((uint16_t)(sizeof(*ip) + sizeof(*udp) + length));
- ip->ip_sum = checksum(ip, sizeof(*ip));
+ ip->ip_sum = in_cksum(ip, sizeof(*ip), NULL);
*sz = sizeof(*ip) + sizeof(*udp) + length;
return udpp;
unsigned int flags)
{
struct ip *ip = packet;
- char ip_hlv = *(char *)ip;
+ struct ip pseudo_ip = {
+ .ip_p = IPPROTO_UDP,
+ .ip_src = ip->ip_src,
+ .ip_dst = ip->ip_dst
+ };
size_t ip_hlen;
uint16_t ip_len, uh_sum;
struct udphdr *udp;
+ uint32_t csum;
if (plen < sizeof(*ip)) {
if (from != NULL)
from->s_addr = ip->ip_src.s_addr;
ip_hlen = (size_t)ip->ip_hl * 4;
- if (checksum(ip, ip_hlen) != 0) {
+ if (in_cksum(ip, ip_hlen, NULL) != 0) {
errno = EINVAL;
return -1;
}
- ip_len = ntohs(ip->ip_len);
/* Check we have a payload */
+ ip_len = ntohs(ip->ip_len);
if (ip_len <= ip_hlen + sizeof(*udp)) {
errno = ERANGE;
return -1;
if (flags & BPF_PARTIALCSUM)
return 0;
+ /* UDP checksum is based on a pseudo IP header alongside
+ * the UDP header and payload. */
udp = (struct udphdr *)((char *)ip + ip_hlen);
if (udp->uh_sum == 0)
return 0;
- uh_sum = udp->uh_sum;
- /* This does scribble on the packet, but at this point
- * we don't care to keep it. */
+ uh_sum = udp->uh_sum;
udp->uh_sum = 0;
- ip->ip_hl = 0;
- ip->ip_v = 0;
- ip->ip_tos = 0;
- ip->ip_len = udp->uh_ulen;
- ip->ip_id = 0;
- ip->ip_off = 0;
- ip->ip_ttl = 0;
- ip->ip_sum = 0;
- if (checksum(packet, ip_len) != uh_sum) {
+ pseudo_ip.ip_len = udp->uh_ulen;
+ csum = 0;
+ in_cksum(&pseudo_ip, sizeof(pseudo_ip), &csum);
+ if (in_cksum(udp, ntohs(udp->uh_ulen), &csum) != uh_sum) {
errno = EINVAL;
return -1;
}
- *(char *)ip = ip_hlv;
- ip->ip_len = htons(ip_len);
return 0;
}