]> git.ipfire.org Git - thirdparty/bird.git/commitdiff
Babel: Refactor TLV parsing code for easier reuse
authorToke Høiland-Jørgensen <toke@toke.dk>
Thu, 15 Apr 2021 18:15:53 +0000 (20:15 +0200)
committerOndrej Zajicek (work) <santiago@crfreenet.org>
Sun, 6 Jun 2021 14:28:18 +0000 (16:28 +0200)
In preparation for adding authentication checks, refactor the TLV
walking code so it can be reused for a separate pass of the packet
for authentication checks.

proto/babel/packets.c

index 415ac3f9c50006376b67dc13439f2c47443afe8b..1d2f5f5bdf3c64a3d88a15ad8ab52276e28b16b5 100644 (file)
@@ -120,8 +120,19 @@ struct babel_subtlv_source_prefix {
 #define BABEL_UF_DEF_PREFIX    0x80
 #define BABEL_UF_ROUTER_ID     0x40
 
+struct babel_parse_state;
+struct babel_write_state;
+
+struct babel_tlv_data {
+  u8 min_length;
+  int (*read_tlv)(struct babel_tlv *hdr, union babel_msg *m, struct babel_parse_state *state);
+  uint (*write_tlv)(struct babel_tlv *hdr, union babel_msg *m, struct babel_write_state *state, uint max_len);
+  void (*handle_tlv)(union babel_msg *m, struct babel_iface *ifa);
+};
 
 struct babel_parse_state {
+  const struct babel_tlv_data* (*get_tlv_data)(u8 type);
+  const struct babel_tlv_data* (*get_subtlv_data)(u8 type);
   struct babel_proto *proto;
   struct babel_iface *ifa;
   ip_addr saddr;
@@ -167,6 +178,37 @@ struct babel_write_state {
 
 #define NET_SIZE(n) BYTES(net_pxlen(n))
 
+
+/* Helper macros to loop over a series of TLVs.
+ * @start      pointer to first TLV (void * or struct babel_tlv *)
+ * @end        byte * pointer to TLV stream end
+ * @tlv        struct babel_tlv pointer used as iterator
+ * @frame_err  boolean (u8) that will be set to 1 if a frame error occurred
+ * @saddr      source addr for use in log output
+ * @ifname     ifname for use in log output
+ */
+#define WALK_TLVS(start, end, tlv, frame_err, saddr, ifname)    \
+  for (tlv = start;                                            \
+       (byte *)tlv < end;                                               \
+       tlv = NEXT_TLV(tlv))                                            \
+  {                                                                    \
+    byte *loop_pos;                                                    \
+    /* Ugly special case */                                            \
+    if (tlv->type == BABEL_TLV_PAD1)                                   \
+      continue;                                                         \
+                                                                       \
+    /* The end of the common TLV header */                             \
+    loop_pos = (byte *)tlv + sizeof(struct babel_tlv);                 \
+    if ((loop_pos > end) || (loop_pos + tlv->length > end))             \
+    {                                                                   \
+      LOG_PKT("Bad TLV from %I via %s type %d pos %d - framing error",  \
+             saddr, ifname, tlv->type, (byte *)tlv - (byte *)start);   \
+      frame_err = 1;                                                    \
+      break;                                                            \
+    }
+
+#define WALK_TLVS_END }
+
 static inline uint
 bytes_equal(u8 *b1, u8 *b2, uint maxlen)
 {
@@ -255,13 +297,6 @@ static uint babel_write_route_request(struct babel_tlv *hdr, union babel_msg *ms
 static uint babel_write_seqno_request(struct babel_tlv *hdr, union babel_msg *msg, struct babel_write_state *state, uint max_len);
 static int babel_write_source_prefix(struct babel_tlv *hdr, net_addr *net, uint max_len);
 
-struct babel_tlv_data {
-  u8 min_length;
-  int (*read_tlv)(struct babel_tlv *hdr, union babel_msg *m, struct babel_parse_state *state);
-  uint (*write_tlv)(struct babel_tlv *hdr, union babel_msg *m, struct babel_write_state *state, uint max_len);
-  void (*handle_tlv)(union babel_msg *m, struct babel_iface *ifa);
-};
-
 static const struct babel_tlv_data tlv_data[BABEL_TLV_MAX] = {
   [BABEL_TLV_ACK_REQ] = {
     sizeof(struct babel_tlv_ack_req),
@@ -319,6 +354,30 @@ static const struct babel_tlv_data tlv_data[BABEL_TLV_MAX] = {
   },
 };
 
+static const struct babel_tlv_data *get_packet_tlv_data(u8 type)
+{
+  return type < sizeof(tlv_data) / sizeof(*tlv_data) ? &tlv_data[type] : NULL;
+}
+
+static const struct babel_tlv_data source_prefix_tlv_data = {
+  sizeof(struct babel_subtlv_source_prefix),
+  babel_read_source_prefix,
+  NULL,
+  NULL
+};
+
+static const struct babel_tlv_data *get_packet_subtlv_data(u8 type)
+{
+  switch(type)
+  {
+  case BABEL_SUBTLV_SOURCE_PREFIX:
+    return &source_prefix_tlv_data;
+
+  default:
+    return NULL;
+  }
+}
+
 static int
 babel_read_ack_req(struct babel_tlv *hdr, union babel_msg *m,
                   struct babel_parse_state *state)
@@ -1083,69 +1142,65 @@ babel_write_source_prefix(struct babel_tlv *hdr, net_addr *n, uint max_len)
   return len;
 }
 
-
 static inline int
 babel_read_subtlvs(struct babel_tlv *hdr,
                   union babel_msg *msg,
                   struct babel_parse_state *state)
 {
+  const struct babel_tlv_data *tlv_data;
+  struct babel_proto *p = state->proto;
   struct babel_tlv *tlv;
-  byte *pos, *end = (byte *) hdr + TLV_LENGTH(hdr);
+  byte *end = (byte *) hdr + TLV_LENGTH(hdr);
+  u8 frame_err = 0;
   int res;
 
-  for (tlv = (void *) hdr + state->current_tlv_endpos;
-       (byte *) tlv < end;
-       tlv = NEXT_TLV(tlv))
+  WALK_TLVS((void *)hdr + state->current_tlv_endpos, end, tlv, frame_err,
+            state->saddr, state->ifa->ifname)
   {
-    /* Ugly special case */
-    if (tlv->type == BABEL_TLV_PAD1)
+    if (tlv->type == BABEL_SUBTLV_PADN)
       continue;
 
-    /* The end of the common TLV header */
-    pos = (byte *)tlv + sizeof(struct babel_tlv);
-    if ((pos > end) || (pos + tlv->length > end))
-      return PARSE_ERROR;
-
-    /*
-     * The subtlv type space is non-contiguous (due to the mandatory bit), so
-     * use a switch for dispatch instead of the mapping array we use for TLVs
-     */
-    switch (tlv->type)
+    if (!state->get_subtlv_data ||
+        !(tlv_data = state->get_subtlv_data(tlv->type)) ||
+        !tlv_data->read_tlv)
     {
-    case BABEL_SUBTLV_SOURCE_PREFIX:
-      res = babel_read_source_prefix(tlv, msg, state);
-      if (res != PARSE_SUCCESS)
-       return res;
-      break;
-
-    case BABEL_SUBTLV_PADN:
-    default:
       /* Unknown mandatory subtlv; PARSE_IGNORE ignores the whole TLV */
       if (tlv->type >= 128)
-       return PARSE_IGNORE;
-      break;
+        return PARSE_IGNORE;
+      continue;
     }
+
+    res = tlv_data->read_tlv(tlv, msg, state);
+    if (res != PARSE_SUCCESS)
+      return res;
   }
+  WALK_TLVS_END;
 
-  return PARSE_SUCCESS;
+  return frame_err ? PARSE_ERROR : PARSE_SUCCESS;
 }
 
-static inline int
+static int
 babel_read_tlv(struct babel_tlv *hdr,
                union babel_msg *msg,
                struct babel_parse_state *state)
 {
+  const struct babel_tlv_data *tlv_data;
+
   if ((hdr->type <= BABEL_TLV_PADN) ||
-      (hdr->type >= BABEL_TLV_MAX) ||
-      !tlv_data[hdr->type].read_tlv)
+      (hdr->type >= BABEL_TLV_MAX))
     return PARSE_IGNORE;
 
-  if (TLV_LENGTH(hdr) < tlv_data[hdr->type].min_length)
+  tlv_data = state->get_tlv_data(hdr->type);
+
+  if (!tlv_data || !tlv_data->read_tlv)
+    return PARSE_IGNORE;
+
+  if (TLV_LENGTH(hdr) < tlv_data->min_length)
     return PARSE_ERROR;
 
-  state->current_tlv_endpos = tlv_data[hdr->type].min_length;
+  state->current_tlv_endpos = tlv_data->min_length;
 
-  int res = tlv_data[hdr->type].read_tlv(hdr, msg, state);
+  int res = tlv_data->read_tlv(hdr, msg, state);
   if (res != PARSE_SUCCESS)
     return res;
 
@@ -1330,6 +1385,7 @@ static void
 babel_process_packet(struct babel_pkt_header *pkt, int len,
                      ip_addr saddr, struct babel_iface *ifa)
 {
+  u8 frame_err UNUSED = 0;
   struct babel_proto *p = ifa->proto;
   struct babel_tlv *tlv;
   struct babel_msg_node *msg;
@@ -1337,15 +1393,16 @@ babel_process_packet(struct babel_pkt_header *pkt, int len,
   int res;
 
   int plen = sizeof(struct babel_pkt_header) + get_u16(&pkt->length);
-  byte *pos;
   byte *end = (byte *)pkt + plen;
 
   struct babel_parse_state state = {
-    .proto       = p,
-    .ifa         = ifa,
-    .saddr       = saddr,
-    .next_hop_ip6 = saddr,
-    .sadr_enabled = babel_sadr_enabled(p),
+    .get_tlv_data    = &get_packet_tlv_data,
+    .get_subtlv_data = &get_packet_subtlv_data,
+    .proto           = p,
+    .ifa             = ifa,
+    .saddr           = saddr,
+    .next_hop_ip6    = saddr,
+    .sadr_enabled    = babel_sadr_enabled(p),
   };
 
   if ((pkt->magic != BABEL_MAGIC) || (pkt->version != BABEL_VERSION))
@@ -1369,23 +1426,8 @@ babel_process_packet(struct babel_pkt_header *pkt, int len,
 
   /* First pass through the packet TLV by TLV, parsing each into internal data
      structures. */
-  for (tlv = FIRST_TLV(pkt);
-       (byte *)tlv < end;
-       tlv = NEXT_TLV(tlv))
+  WALK_TLVS(FIRST_TLV(pkt), end, tlv, frame_err, saddr, ifa->iface->name)
   {
-    /* Ugly special case */
-    if (tlv->type == BABEL_TLV_PAD1)
-      continue;
-
-    /* The end of the common TLV header */
-    pos = (byte *)tlv + sizeof(struct babel_tlv);
-    if ((pos > end) || (pos + tlv->length > end))
-    {
-      LOG_PKT("Bad TLV from %I via %s type %d pos %d - framing error",
-             saddr, ifa->iface->name, tlv->type, (byte *)tlv - (byte *)pkt);
-      break;
-    }
-
     msg = sl_allocz(p->msg_slab);
     res = babel_read_tlv(tlv, &msg->msg, &state);
     if (res == PARSE_SUCCESS)
@@ -1405,8 +1447,9 @@ babel_process_packet(struct babel_pkt_header *pkt, int len,
       break;
     }
   }
+  WALK_TLVS_END;
 
-  /* Parsing done, handle all parsed TLVs */
+  /* Parsing done, handle all parsed TLVs, regardless of any errors */
   WALK_LIST_FIRST(msg, msgs)
   {
     if (tlv_data[msg->msg.type].handle_tlv)