]> git.ipfire.org Git - thirdparty/freeradius-server.git/commitdiff
Add support for custom protocol-specific flag copy functions and comparators
authorArran Cudbard-Bell <a.cudbardb@freeradius.org>
Tue, 29 Oct 2024 11:22:58 +0000 (12:22 +0100)
committerArran Cudbard-Bell <a.cudbardb@freeradius.org>
Tue, 29 Oct 2024 11:40:22 +0000 (12:40 +0100)
src/lib/util/dict.h
src/lib/util/dict_ext.c

index 0f41171f98782be509a070d0f88a12337cbfeea2..42baa527c190ea52a6fec01ed5f616ec68553f74 100644 (file)
@@ -354,6 +354,28 @@ struct fr_dict_flag_parser_rule_s {
        bool                            needs_value;                    //!< This parsing flag must have a value.  Else we error.
 };
 
+/** Copy custom flags from one attribute to another
+ *
+ * @param[out] da_to           attribute to copy to.  Use for the talloc_ctx for any heap allocated flag values.
+ * @param[out] flags_to                protocol specific flags struct to copy to.
+ * @param[in] flags_from       protocol specific flags struct to copy from.
+ * @return
+ *  - 0 on success.
+ *  - -1 on error.
+ */
+typedef int (*fr_dict_flags_copy_func_t)(fr_dict_attr_t *da_to, void *flags_to, void *flags_from);
+
+/** Compare the protocol specific flags struct from two attributes
+ *
+ * @para[in] da_a      first attribute to compare.
+ * @para[in] da_b      second attribute to compare.
+ * @return
+ *  - 0 if the flags are equal.
+ *  - < 0 if da_a < da_b.
+ *  - > 0 if da_a > da_b.
+ */
+ typedef int (*fr_dict_flags_cmp_func_t)(fr_dict_attr_t const *da_a, fr_dict_attr_t const *da_b);
+
 /** Protocol specific custom flag definitnion
  *
  */
@@ -410,7 +432,11 @@ typedef struct {
                        size_t                          len;                    //!< Length of the flags field in the protocol
                                                                                ///< specific structure.
 
+                       fr_dict_flags_copy_func_t       copy;                   //!< Copy flags from one attribute to another.
+                                                                               ///< Called when copying attributes.
 
+                       fr_dict_flags_cmp_func_t        cmp;                    //!< Compare the flags from two attributes.
+                                                                               ///< Called when comparing attribute fields.
                } flags;
 
                fr_dict_attr_valid_func_t       valid;                  //!< Validation function to ensure that
@@ -492,6 +518,13 @@ int                        fr_dict_str_to_argv(char *str, char **argv, int max_argc);
 int                    fr_dict_attr_acopy_local(fr_dict_attr_t const *dst, fr_dict_attr_t const *src) CC_HINT(nonnull);
 /** @} */
 
+/** @name Dict accessors
+ *
+ * @{
+ */
+fr_dict_protocol_t const *fr_dict_protocol(fr_dict_t const *dict);
+/** @} */
+
 /** @name Unknown ephemeral attributes
  *
  * @{
@@ -587,6 +620,15 @@ static inline CC_HINT(nonnull) int8_t fr_dict_attr_cmp(fr_dict_attr_t const *a,
 static inline CC_HINT(nonnull) int8_t fr_dict_attr_cmp_fields(const fr_dict_attr_t *a, const fr_dict_attr_t *b)
 {
        int8_t ret;
+       fr_dict_protocol_t const *a_proto = fr_dict_protocol(a->dict);
+
+       /*
+        *      Technically this isn't a property of the attribute
+        *      but we need them to be the same to be able to
+        *      compare protocol specific flags successfully.
+        */
+       ret = CMP(a_proto, fr_dict_protocol(b->dict));
+       if (ret != 0) return ret;
 
        ret = CMP(a->attr, b->attr);
        if (ret != 0) return ret;
@@ -597,6 +639,11 @@ static inline CC_HINT(nonnull) int8_t fr_dict_attr_cmp_fields(const fr_dict_attr
        ret = CMP(fr_dict_vendor_num_by_da(a), fr_dict_vendor_num_by_da(b));
        if (ret != 0) return ret;
 
+       /*
+        *      Compare protocol specific flags
+        */
+       if (a_proto->attr.flags.cmp && (ret = a_proto->attr.flags.cmp(a, b))) return ret;
+
        return CMP(memcmp(&a->flags, &b->flags, sizeof(a->flags)), 0);
 }
 /** @} */
@@ -669,8 +716,6 @@ fr_dict_t const             *fr_dict_by_protocol_name(char const *name);
 
 fr_dict_t const                *fr_dict_by_protocol_num(unsigned int num);
 
-fr_dict_protocol_t const *fr_dict_protocol(fr_dict_t const *dict);
-
 fr_dict_attr_t const   *fr_dict_unlocal(fr_dict_attr_t const *da) CC_HINT(nonnull);
 
 fr_dict_t const                *fr_dict_by_da(fr_dict_attr_t const *da) CC_HINT(nonnull);
index 3e6f886c3818d962852ab08cda01f06917f39bb5..c6becb5143be9f8b8314a77e7d140081a7e0a6e7 100644 (file)
@@ -155,6 +155,58 @@ static int fr_dict_attr_ext_vendor_copy(UNUSED int ext,
        return -1;
 }
 
+static int dict_ext_protocol_specific_copy(UNUSED int ext,
+                                          TALLOC_CTX *dst_chunk,
+                                          void *dst_ext_ptr, size_t dst_ext_len,
+                                          TALLOC_CTX const *src_chunk,
+                                          void *src_ext_ptr, size_t src_ext_len)
+{
+       fr_dict_attr_t const *from = talloc_get_type_abort_const(src_chunk, fr_dict_attr_t);
+       fr_dict_protocol_t const *from_proto = fr_dict_protocol(from->dict);
+       fr_dict_attr_t *to = talloc_get_type_abort_const(dst_chunk, fr_dict_attr_t);
+       fr_dict_protocol_t const *to_proto = fr_dict_protocol(to->dict);
+
+       /*
+        *      Whilst it's not strictly disallowed, we can't do anything
+        *      sane without an N x N matrix of copy functions for different
+        *      protocols.  Maybe we should add that at some point, but for
+        *      now, just ignore the copy.
+        */
+       if (from->dict != to->dict) return 0;
+
+       /*
+        *      Sanity checks...
+        */
+       if (unlikely(from_proto->attr.flags.len != src_ext_len)) {
+               fr_strerror_printf("Protocol specific extension length mismatch in source attribute %s.  Expected %zu, got %zu",
+                                  from->name,
+                                  fr_dict_protocol(from->dict)->attr.flags.len, fr_dict_protocol(to->dict)->attr.flags.len);
+               return -1;
+       }
+
+       if (unlikely(to_proto->attr.flags.len != dst_ext_len)) {
+               fr_strerror_printf("Protocol specific extension length mismatch in destintion attribute %s.  Expected %zu, got %zu",
+                                  to->name,
+                                  fr_dict_protocol(to->dict)->attr.flags.len, fr_dict_protocol(to->dict)->attr.flags.len);
+               return -1;
+       }
+
+       /*
+        *      The simple case... No custom copy function, just memcpy
+        */
+       if (!to_proto->attr.flags.copy) {
+               memcpy(dst_ext_ptr, src_ext_ptr, src_ext_len);
+               return 0;
+       }
+
+       /*
+        *      Call the custom copy function.  This is only needed if
+        *      there are heap allocated values, like strings, which
+        *      need copying from sources flags to the destination.
+        */
+       return to_proto->attr.flags.copy(dst_chunk, dst_ext_ptr, src_ext_ptr);
+}
+
 /** Holds additional information about extension structures
  *
  */
@@ -200,7 +252,8 @@ fr_ext_t const fr_dict_attr_ext_def = {
                [FR_DICT_ATTR_EXT_PROTOCOL_SPECIFIC] = {
                                                        .min = FR_EXT_ALIGNMENT,        /* allow for one byte of protocol stuff */
                                                        .has_hdr = true,                /* variable sized */
-                                                       .can_copy = false               /* only the protocol can copy it */
+                                                       .copy = dict_ext_protocol_specific_copy,
+                                                       .can_copy = true                /* Use the attr.flags.copy function */
                                                },
                [FR_DICT_ATTR_EXT_MAX]          = {}
        }