}
#endif
+#define BPF_M_FHLEN 0
+#define BPF_M_IPHLEN 1
+#define BPF_M_IPLEN 2
+#define BPF_M_UDP 3
+#define BPF_M_UDPLEN 4
+
static const struct bpf_insn bpf_bootp_ether[] = {
/* Make sure this is an IP packet. */
BPF_STMT(BPF_LD + BPF_H + BPF_ABS,
/* Load frame header length into X. */
BPF_STMT(BPF_LDX + BPF_W + BPF_IMM, sizeof(struct ether_header)),
- /* Copy to M0. */
- BPF_STMT(BPF_STX, 0),
+ /* Copy frame header length to memory */
+ BPF_STMT(BPF_STX, BPF_M_FHLEN),
};
#define BPF_BOOTP_ETHER_LEN __arraycount(bpf_bootp_ether)
static const struct bpf_insn bpf_bootp_filter[] = {
- /* Make sure it's an optionless IPv4 packet. */
+ /* Make sure it's an IPv4 packet. */
+ BPF_STMT(BPF_LD + BPF_B + BPF_IND, 0),
+ BPF_STMT(BPF_ALU + BPF_AND + BPF_K, 0xf0),
+ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, 0x40, 1, 0),
+ BPF_STMT(BPF_RET + BPF_K, 0),
+
+ /* Ensure IP header length is big enough and
+ * store the IP header length in memory. */
BPF_STMT(BPF_LD + BPF_B + BPF_IND, 0),
- BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, 0x45, 1, 0),
+ BPF_STMT(BPF_ALU + BPF_AND + BPF_K, 0x0f),
+ BPF_STMT(BPF_ALU + BPF_MUL + BPF_K, 4),
+ BPF_JUMP(BPF_JMP + BPF_JGE + BPF_K, sizeof(struct ip), 1, 0),
BPF_STMT(BPF_RET + BPF_K, 0),
+ BPF_STMT(BPF_ST, BPF_M_IPHLEN),
/* Make sure it's a UDP packet. */
BPF_STMT(BPF_LD + BPF_B + BPF_IND, offsetof(struct ip, ip_p)),
BPF_JUMP(BPF_JMP + BPF_JSET + BPF_K, 0x1fff, 0, 1),
BPF_STMT(BPF_RET + BPF_K, 0),
- /* Store IP location in M1. */
- BPF_STMT(BPF_LD + BPF_H + BPF_IND, offsetof(struct ip, ip_len)),
- BPF_STMT(BPF_ST, 1),
-
- /* Store IP length in M2. */
+ /* Store IP length. */
BPF_STMT(BPF_LD + BPF_H + BPF_IND, offsetof(struct ip, ip_len)),
- BPF_STMT(BPF_ST, 2),
+ BPF_STMT(BPF_ST, BPF_M_IPLEN),
/* Advance to the UDP header. */
- BPF_STMT(BPF_MISC + BPF_TXA, 0),
- BPF_STMT(BPF_ALU + BPF_ADD + BPF_K, sizeof(struct ip)),
+ BPF_STMT(BPF_LD + BPF_MEM, BPF_M_IPHLEN),
+ BPF_STMT(BPF_ALU + BPF_ADD + BPF_X, 0),
BPF_STMT(BPF_MISC + BPF_TAX, 0),
- /* Store X in M3. */
- BPF_STMT(BPF_STX, 3),
+ /* Store UDP location */
+ BPF_STMT(BPF_STX, BPF_M_UDP),
/* Make sure it's from and to the right port. */
BPF_STMT(BPF_LD + BPF_W + BPF_IND, 0),
BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, (BOOTPS << 16) + BOOTPC, 1, 0),
BPF_STMT(BPF_RET + BPF_K, 0),
- /* Store UDP length in X. */
+ /* Store UDP length. */
BPF_STMT(BPF_LD + BPF_H + BPF_IND, offsetof(struct udphdr, uh_ulen)),
+ BPF_STMT(BPF_ST, BPF_M_UDPLEN),
+
+ /* Ensure that UDP length + IP header length == IP length */
+ /* Copy IP header length to X. */
+ BPF_STMT(BPF_LDX + BPF_MEM, BPF_M_IPHLEN),
+ /* Add UDP length (A) to IP header length (X). */
+ BPF_STMT(BPF_ALU + BPF_ADD + BPF_X, 0),
+ /* Store result in X. */
BPF_STMT(BPF_MISC + BPF_TAX, 0),
- /* Copy IP length in M2 to A. */
- BPF_STMT(BPF_LD + BPF_MEM, 2),
- /* Ensure IP length - IP header size == UDP length. */
- BPF_STMT(BPF_ALU + BPF_SUB + BPF_K, sizeof(struct ip)),
+ /* Copy IP length to A. */
+ BPF_STMT(BPF_LD + BPF_MEM, BPF_M_IPLEN),
+ /* Ensure X == A. */
BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_X, 0, 1, 0),
BPF_STMT(BPF_RET + BPF_K, 0),
- /* Advance to the BOOTP packet (UDP X is in M3). */
- BPF_STMT(BPF_LD + BPF_MEM, 3),
+ /* Advance to the BOOTP packet. */
+ BPF_STMT(BPF_LD + BPF_MEM, BPF_M_UDP),
BPF_STMT(BPF_ALU + BPF_ADD + BPF_K, sizeof(struct udphdr)),
BPF_STMT(BPF_MISC + BPF_TAX, 0),
}
#endif
- /* All passed, return the packet
- * (Frame length in M0, IP length in M2). */
- BPF_SET_STMT(bp, BPF_LD + BPF_MEM, 0);
+ /* All passed, return the packet - frame length + ip length */
+ BPF_SET_STMT(bp, BPF_LD + BPF_MEM, BPF_M_FHLEN);
bp++;
- BPF_SET_STMT(bp, BPF_LDX + BPF_MEM, 2);
+ BPF_SET_STMT(bp, BPF_LDX + BPF_MEM, BPF_M_IPLEN);
bp++;
BPF_SET_STMT(bp, BPF_ALU + BPF_ADD + BPF_X, 0);
bp++;
}
static void *
-get_udp_data(void *udp, size_t *len)
+get_udp_data(void *packet, size_t *len)
{
- struct bootp_pkt *p;
+ const struct ip *ip = packet;
+ size_t ip_hl = (size_t)ip->ip_hl * 4;
+ char *p = packet;
- p = (struct bootp_pkt *)udp;
- *len = (size_t)ntohs(p->ip.ip_len) - sizeof(p->ip) - sizeof(p->udp);
- return (char *)udp + offsetof(struct bootp_pkt, bootp);
+ p += ip_hl + sizeof(struct udphdr);
+ *len = (size_t)ntohs(ip->ip_len) - sizeof(struct udphdr) - ip_hl;
+ return p;
}
static int
-valid_udp_packet(void *data, size_t data_len, struct in_addr *from,
- int noudpcsum)
+valid_udp_packet(void *packet, size_t plen, struct in_addr *from,
+ unsigned int flags)
{
- struct bootp_pkt *p;
- uint16_t bytes;
+ struct ip *ip = packet;
+ char ip_hlv = *(char *)ip;
+ size_t ip_hlen;
+ uint16_t ip_len, uh_sum;
+ struct udphdr *udp;
- if (data_len < sizeof(p->ip)) {
- if (from)
+ if (plen < sizeof(*ip)) {
+ if (from != NULL)
from->s_addr = INADDR_ANY;
errno = ERANGE;
return -1;
}
- p = (struct bootp_pkt *)data;
- if (from)
- from->s_addr = p->ip.ip_src.s_addr;
- if (checksum(&p->ip, sizeof(p->ip)) != 0) {
+
+ if (from != NULL)
+ from->s_addr = ip->ip_src.s_addr;
+
+ ip_hlen = (size_t)ip->ip_hl * 4;
+ if (checksum(ip, ip_hlen) != 0) {
errno = EINVAL;
return -1;
}
- bytes = ntohs(p->ip.ip_len);
+ ip_len = ntohs(ip->ip_len);
/* Check we have a payload */
- if (bytes <= sizeof(p->ip) + sizeof(p->udp)) {
+ if (ip_len <= ip_hlen + sizeof(*udp)) {
errno = ERANGE;
return -1;
}
/* Check we don't go beyond the payload */
- if (bytes > data_len) {
+ if (ip_len > plen) {
errno = ENOBUFS;
return -1;
}
- if (noudpcsum == 0) {
- uint16_t udpsum, iplen;
-
- /* This does scribble on the packet, but at this point
- * we don't care to keep it. */
- iplen = p->ip.ip_len;
- udpsum = p->udp.uh_sum;
- p->udp.uh_sum = 0;
- p->ip.ip_hl = 0;
- p->ip.ip_v = 0;
- p->ip.ip_tos = 0;
- p->ip.ip_len = p->udp.uh_ulen;
- p->ip.ip_id = 0;
- p->ip.ip_off = 0;
- p->ip.ip_ttl = 0;
- p->ip.ip_sum = 0;
- if (udpsum && checksum(p, bytes) != udpsum) {
- errno = EINVAL;
- return -1;
- }
- p->ip.ip_len = iplen;
+ if (flags & BPF_PARTIALCSUM)
+ return 0;
+
+ udp = (struct udphdr *)((char *)ip + ip_hlen);
+ if (udp->uh_sum == 0)
+ return 0;
+ uh_sum = udp->uh_sum;
+
+ /* This does scribble on the packet, but at this point
+ * we don't care to keep it. */
+ udp->uh_sum = 0;
+ ip->ip_hl = 0;
+ ip->ip_v = 0;
+ ip->ip_tos = 0;
+ ip->ip_len = udp->uh_ulen;
+ ip->ip_id = 0;
+ ip->ip_off = 0;
+ ip->ip_ttl = 0;
+ ip->ip_sum = 0;
+ if (checksum(packet, ip_len) != uh_sum) {
+ errno = EINVAL;
+ return -1;
}
+ *(char *)ip = ip_hlv;
+ ip->ip_len = htons(ip_len);
return 0;
}
size_t udp_len;
const struct dhcp_state *state = D_CSTATE(ifp);
- if (valid_udp_packet(data, len, &from,
- state->bpf_flags & RAW_PARTIALCSUM) == -1)
- {
+ if (valid_udp_packet(data, len, &from, state->bpf_flags) == -1) {
if (errno == EINVAL)
logerrx("%s: checksum failure from %s",
ifp->name, inet_ntoa(from));