.ip_dst = ip->ip_dst
};
size_t ip_hlen;
- uint16_t ip_len, uh_sum;
+ uint16_t ip_len, udp_len, uh_sum;
struct udphdr *udp;
uint32_t csum;
errno = ERANGE;
return -1;
}
- /* Check we don't go beyond the payload */
+ /* Check IP doesn't go beyond the payload */
if (ip_len > plen) {
errno = ENOBUFS;
return -1;
}
- if (flags & BPF_PARTIALCSUM)
+ /* Check UDP doesn't go beyond the payload */
+ udp = (struct udphdr *)(void *)((char *)ip + ip_hlen);
+ udp_len = ntohs(udp->uh_ulen);
+ if (udp_len > plen - ip_hlen) {
+ errno = ENOBUFS;
+ return -1;
+ }
+
+ if (udp->uh_sum == 0 || flags & BPF_PARTIALCSUM)
return 0;
/* UDP checksum is based on a pseudo IP header alongside
* the UDP header and payload. */
- udp = (struct udphdr *)(void *)((char *)ip + ip_hlen);
- if (udp->uh_sum == 0)
- return 0;
-
uh_sum = udp->uh_sum;
udp->uh_sum = 0;
pseudo_ip.ip_len = udp->uh_ulen;
csum = 0;
in_cksum(&pseudo_ip, sizeof(pseudo_ip), &csum);
- csum = in_cksum(udp, ntohs(udp->uh_ulen), &csum);
+ csum = in_cksum(udp, udp_len, &csum);
if (csum != uh_sum) {
errno = EINVAL;
return -1;