struct keyaccess {
KeyAccT * next;
sockaddr_u addr;
- int subnetbits;
+ unsigned int subnetbits;
};
extern KeyAccT* keyacc_new_push(KeyAccT *head, const sockaddr_u *addr,
- int subnetbits);
+ unsigned int subnetbits);
extern KeyAccT* keyacc_pop_free(KeyAccT *head);
extern KeyAccT* keyacc_all_free(KeyAccT *head);
extern int keyacc_contains(const KeyAccT *head, const sockaddr_u *addr,
int res_on_empty_list);
+/* public for testability: */
+extern int keyacc_amatch(const sockaddr_u *,const sockaddr_u *,
+ unsigned int mbits);
+
#endif /* NTP_KEYACC_H */
keyacc_new_push(
KeyAccT * head,
const sockaddr_u * addr,
- int subnetbits
+ unsigned int subnetbits
)
{
KeyAccT * node = emalloc(sizeof(KeyAccT));
{
if (head) {
do {
- if (SOCK_EQ(&head->addr, addr))
+ if (keyacc_amatch(&head->addr, addr,
+ head->subnetbits))
return TRUE;
} while (NULL != (head = head->next));
return FALSE;
}
}
+#if CHAR_BIT != 8
+# error "don't know how to handle bytes with that bit size"
+#endif
+
+/* ----------------------------------------------------------------- */
+/* check two addresses for a match, taking a prefix length into account
+ * when doing the compare.
+ *
+ * The ISC lib contains a similar function with not entirely specified
+ * semantics, so it seemed somewhat cleaner to do this from scratch.
+ *
+ * Note: It *is* assumed that the addresses are stored in network byte
+ * order, that is, most significant byte first!
+ */
+int/*BOOL*/
+keyacc_amatch(
+ const sockaddr_u * a1,
+ const sockaddr_u * a2,
+ unsigned int mbits
+ )
+{
+ const uint8_t * pm1;
+ const uint8_t * pm2;
+ uint8_t msk;
+ unsigned int len;
+
+ if (AF(a1) != AF(a2))
+ return FALSE;
+
+ switch (AF(a1)) {
+ case AF_INET:
+ /* IPv4 is easy: clamp size, get byte pointers */
+ if (mbits > sizeof(NSRCADR(a1)) * 8)
+ mbits = sizeof(NSRCADR(a1)) * 8;
+ pm1 = (const void*)&NSRCADR(a1);
+ pm2 = (const void*)&NSRCADR(a2);
+ break;
+
+ case AF_INET6:
+ /* IPv6 is slightly different: Both scopes must match,
+ * too, before we even consider doing a match!
+ */
+ if ( ! SCOPE_EQ(a1, a2))
+ return FALSE;
+ if (mbits > sizeof(NSRCADR6(a1)) * 8)
+ mbits = sizeof(NSRCADR6(a1)) * 8;
+ pm1 = (const void*)&NSRCADR6(a1);
+ pm2 = (const void*)&NSRCADR6(a2);
+ break;
+
+ default:
+ /* don't know how to compare that!?! */
+ return FALSE;
+ }
+
+ /* split bit length into byte length and partial byte mask */
+ msk = 0xFFu ^ (0xFFu >> (mbits & 7));
+ len = mbits >> 3;
+
+ if (len && memcmp(pm1, pm2, len))
+ return FALSE;
+ if (msk && ((pm1[len] ^ pm2[len]) & msk))
+ return FALSE;
+
+ return TRUE;
+}
/*
* init_auth - initialize internal data
return (u_short)r;
}
+int/*BOOL*/
+ipaddr_match_masked(const sockaddr_u *,const sockaddr_u *,
+ unsigned int mbits);
+
static void
authcache_flush_id(
keyid_t id
while (tp) {
char *i;
char *snp; /* subnet text pointer */
- int snbits;
+ unsigned int snbits;
sockaddr_u addr;
i = strchr(tp, (int)',');
}
snp = strchr(tp, (int)'/');
if (snp) {
- unsigned u;
char *sp;
*snp++ = '\0';
- snbits = -1;
- u = 0;
+ snbits = 0;
sp = snp;
while (*sp != '\0') {
if (!isdigit((unsigned char)*sp))
break;
- if (u > 1000)
+ if (snbits > 1000)
break; /* overflow */
- u = (u << 3) + (u << 1);
- u += *sp++ - '0'; /* ascii dependent */
+ snbits = 10 * snbits + (*sp++ - '0'); /* ascii dependent */
}
if (*sp != '\0') {
log_maybe(&nerr,
goto nextip;
}
} else {
- snbits = -1;
+ snbits = UINT_MAX;
}
if (is_ip_address(tp, AF_UNSPEC, &addr)) {
/* Make sure that snbits is valid for addr */
- if ( snbits == -1
- || (snbits >= 0 &&
- ( (IS_IPV4(&addr) && snbits <= 32)
- || (IS_IPV6(&addr) && snbits <= 128)))) {
- next->keyacclist = keyacc_new_push(
- next->keyacclist, &addr, snbits);
- } else {
-
- log_maybe(&nerr,
- "authreadkeys: invalid IP address/subnet <%s/%s> for key %d",
+ if ((snbits < UINT_MAX) &&
+ ( (IS_IPV4(&addr) && snbits > 32) ||
+ (IS_IPV6(&addr) && snbits > 128))) {
+ log_maybe(NULL,
+ "authreadkeys: excessive subnet mask <%s/%s> for key %d",
tp, snp, keyno);
- }
+ }
+ next->keyacclist = keyacc_new_push(
+ next->keyacclist, &addr, snbits);
} else {
log_maybe(&nerr,
"authreadkeys: invalid IP address <%s> for key %d",
if (NULL != pdigest_len) {
#ifdef OPENSSL
- const EVP_MD * md = EVP_get_digestbynid(key_type);
-
+ md = EVP_get_digestbynid(key_type);
digest_len = (md) ? EVP_MD_size(md) : 0;
if (!md || digest_len <= 0) {