]> git.ipfire.org Git - thirdparty/systemd.git/blobdiff - src/boot/efi/efi-string.c
tree-wide: use ASSERT_PTR more
[thirdparty/systemd.git] / src / boot / efi / efi-string.c
index 17372ae9d4476588f4d81346bf1a8414a6221c53..dfb2b06a75924609807c16f1c32e39f59259a6d3 100644 (file)
@@ -4,7 +4,14 @@
 #include <stdint.h>
 
 #include "efi-string.h"
-#include "macro-fundamental.h"
+
+#ifdef SD_BOOT
+#  include "util.h"
+#else
+#  include <stdlib.h>
+#  include "macro.h"
+#  define xmalloc(n) ASSERT_SE_PTR(malloc(n))
+#endif
 
 /* String functions for both char and char16_t that should behave the same way as their respective
  * counterpart in userspace. Where it makes sense, these accept NULL and do something sensible whereas
 DEFINE_STRNLEN(char, strnlen8);
 DEFINE_STRNLEN(char16_t, strnlen16);
 
-size_t strlen8(const char *s) {
-        return strnlen8(s, SIZE_MAX);
-}
-
-size_t strlen16(const char16_t *s) {
-        return strnlen16(s, SIZE_MAX);
-}
-
 #define TOLOWER(c)                                                \
         ({                                                        \
                 typeof(c) _c = (c);                               \
@@ -60,14 +59,13 @@ DEFINE_STRTOLOWER(char16_t, strtolower16);
                         return CMP(s1, s2);                  \
                                                              \
                 while (n > 0) {                              \
-                        int c1 = *s1;                        \
-                        int c2 = *s2;                        \
+                        type c1 = *s1, c2 = *s2;             \
                         if (tolower) {                       \
                                 c1 = TOLOWER(c1);            \
                                 c2 = TOLOWER(c2);            \
                         }                                    \
                         if (!c1 || c1 != c2)                 \
-                                return c1 - c2;              \
+                                return CMP(c1, c2);          \
                                                              \
                         s1++;                                \
                         s2++;                                \
@@ -79,22 +77,278 @@ DEFINE_STRTOLOWER(char16_t, strtolower16);
 
 DEFINE_STRNCASECMP(char, strncmp8, false);
 DEFINE_STRNCASECMP(char16_t, strncmp16, false);
+DEFINE_STRNCASECMP(char, strncasecmp8, true);
+DEFINE_STRNCASECMP(char16_t, strncasecmp16, true);
+
+#define DEFINE_STRCPY(type, name)                                     \
+        type *name(type * restrict dest, const type * restrict src) { \
+                type *ret = ASSERT_PTR(dest);                         \
+                                                                      \
+                if (!src) {                                           \
+                        *dest = '\0';                                 \
+                        return ret;                                   \
+                }                                                     \
+                                                                      \
+                while (*src) {                                        \
+                        *dest = *src;                                 \
+                        dest++;                                       \
+                        src++;                                        \
+                }                                                     \
+                                                                      \
+                *dest = '\0';                                         \
+                return ret;                                           \
+        }
+
+DEFINE_STRCPY(char, strcpy8);
+DEFINE_STRCPY(char16_t, strcpy16);
+
+#define DEFINE_STRCHR(type, name)                  \
+        type *name(const type *s, type c) {        \
+                if (!s)                            \
+                        return NULL;               \
+                                                   \
+                while (*s) {                       \
+                        if (*s == c)               \
+                                return (type *) s; \
+                        s++;                       \
+                }                                  \
+                                                   \
+                return NULL;                       \
+        }
+
+DEFINE_STRCHR(char, strchr8);
+DEFINE_STRCHR(char16_t, strchr16);
+
+#define DEFINE_STRNDUP(type, name, len_func)              \
+        type *name(const type *s, size_t n) {             \
+                if (!s)                                   \
+                        return NULL;                      \
+                                                          \
+                size_t len = len_func(s, n);              \
+                size_t size = len * sizeof(type);         \
+                                                          \
+                type *dup = xmalloc(size + sizeof(type)); \
+                efi_memcpy(dup, s, size);                 \
+                dup[len] = '\0';                          \
+                                                          \
+                return dup;                               \
+        }
+
+DEFINE_STRNDUP(char, xstrndup8, strnlen8);
+DEFINE_STRNDUP(char16_t, xstrndup16, strnlen16);
+
+/* Patterns are fnmatch-compatible (with reduced feature support). */
+static bool efi_fnmatch_internal(const char16_t *p, const char16_t *h, int max_depth) {
+        assert(p);
+        assert(h);
+
+        if (max_depth == 0)
+                return false;
+
+        for (;; p++, h++)
+                switch (*p) {
+                case '\0':
+                        /* End of pattern. Check that haystack is now empty. */
+                        return *h == '\0';
+
+                case '\\':
+                        p++;
+                        if (*p == '\0' || *p != *h)
+                                /* Trailing escape or no match. */
+                                return false;
+                        break;
+
+                case '?':
+                        if (*h == '\0')
+                                /* Early end of haystack. */
+                                return false;
+                        break;
+
+                case '*':
+                        /* No need to recurse for consecutive '*'. */
+                        while (*p == '*')
+                                p++;
+
+                        for (; *h != '\0'; h++)
+                                /* Try matching haystack with remaining pattern. */
+                                if (efi_fnmatch_internal(p, h, max_depth - 1))
+                                        return true;
+
+                        /* End of haystack. Pattern needs to be empty too for a match. */
+                        return *p == '\0';
+
+                case '[':
+                        if (*h == '\0')
+                                /* Early end of haystack. */
+                                return false;
+
+                        bool first = true, can_range = true, match = false;
+                        for (;; first = false) {
+                                p++;
+                                if (*p == '\0')
+                                        return false;
+
+                                if (*p == '\\') {
+                                        p++;
+                                        if (*p == '\0')
+                                                return false;
+                                        if (*p == *h)
+                                                match = true;
+                                        can_range = true;
+                                        continue;
+                                }
+
+                                /* End of set unless it's the first char. */
+                                if (*p == ']' && !first)
+                                        break;
 
-int strcmp8(const char *s1, const char *s2) {
-        return strncmp8(s1, s2, SIZE_MAX);
+                                /* Range pattern if '-' is not first or last in set. */
+                                if (*p == '-' && can_range && !first && *(p + 1) != ']') {
+                                        char16_t low = *(p - 1);
+                                        p++;
+                                        if (*p == '\\')
+                                                p++;
+                                        if (*p == '\0')
+                                                return false;
+
+                                        if (low <= *h && *h <= *p)
+                                                match = true;
+
+                                        /* Ranges cannot be chained: [a-c-f] == [-abcf] */
+                                        can_range = false;
+                                        continue;
+                                }
+
+                                if (*p == *h)
+                                        match = true;
+                                can_range = true;
+                        }
+
+                        if (!match)
+                                return false;
+                        break;
+
+                default:
+                        if (*p != *h)
+                                /* Single char mismatch. */
+                                return false;
+                }
 }
 
-int strcmp16(const char16_t *s1, const char16_t *s2) {
-        return strncmp16(s1, s2, SIZE_MAX);
+bool efi_fnmatch(const char16_t *pattern, const char16_t *haystack) {
+        return efi_fnmatch_internal(pattern, haystack, 32);
 }
 
-DEFINE_STRNCASECMP(char, strncasecmp8, true);
-DEFINE_STRNCASECMP(char16_t, strncasecmp16, true);
+#define DEFINE_PARSE_NUMBER(type, name)                                    \
+        bool name(const type *s, uint64_t *ret_u, const type **ret_tail) { \
+                assert(ret_u);                                             \
+                                                                           \
+                if (!s)                                                    \
+                        return false;                                      \
+                                                                           \
+                /* Need at least one digit. */                             \
+                if (*s < '0' || *s > '9')                                  \
+                        return false;                                      \
+                                                                           \
+                uint64_t u = 0;                                            \
+                while (*s >= '0' && *s <= '9') {                           \
+                        if (__builtin_mul_overflow(u, 10, &u))             \
+                                return false;                              \
+                        if (__builtin_add_overflow(u, *s - '0', &u))       \
+                                return false;                              \
+                        s++;                                               \
+                }                                                          \
+                                                                           \
+                if (!ret_tail && *s != '\0')                               \
+                        return false;                                      \
+                                                                           \
+                *ret_u = u;                                                \
+                if (ret_tail)                                              \
+                        *ret_tail = s;                                     \
+                return true;                                               \
+        }
+
+DEFINE_PARSE_NUMBER(char, parse_number8);
+DEFINE_PARSE_NUMBER(char16_t, parse_number16);
+
+int efi_memcmp(const void *p1, const void *p2, size_t n) {
+        const uint8_t *up1 = p1, *up2 = p2;
+        int r;
+
+        if (!p1 || !p2)
+                return CMP(p1, p2);
+
+        while (n > 0) {
+                r = CMP(*up1, *up2);
+                if (r != 0)
+                        return r;
+
+                up1++;
+                up2++;
+                n--;
+        }
 
-int strcasecmp8(const char *s1, const char *s2) {
-        return strncasecmp8(s1, s2, SIZE_MAX);
+        return 0;
 }
 
-int strcasecmp16(const char16_t *s1, const char16_t *s2) {
-        return strncasecmp16(s1, s2, SIZE_MAX);
+void *efi_memcpy(void * restrict dest, const void * restrict src, size_t n) {
+        if (!dest || !src || n == 0)
+                return dest;
+
+#ifdef SD_BOOT
+        /* The firmware-provided memcpy is likely optimized, so use that. The function is guaranteed to be
+         * available by the UEFI spec. We still make it depend on the boot services pointer being set just in
+         * case the compiler emits a call before it is available. */
+        if (_likely_(BS)) {
+                BS->CopyMem(dest, (void *) src, n);
+                return dest;
+        }
+#endif
+
+        uint8_t *d = dest;
+        const uint8_t *s = src;
+
+        while (n > 0) {
+                *d = *s;
+                d++;
+                s++;
+                n--;
+        }
+
+        return dest;
+}
+
+void *efi_memset(void *p, int c, size_t n) {
+        if (!p || n == 0)
+                return p;
+
+#ifdef SD_BOOT
+        /* See comment in efi_memcpy. Note that the signature has c and n swapped! */
+        if (_likely_(BS)) {
+                BS->SetMem(p, n, c);
+                return p;
+        }
+#endif
+
+        uint8_t *q = p;
+        while (n > 0) {
+                *q = c;
+                q++;
+                n--;
+        }
+
+        return p;
 }
+
+#ifdef SD_BOOT
+#  undef memcmp
+#  undef memcpy
+#  undef memset
+/* Provide the actual implementation for the builtins by providing aliases. These need to be marked as used,
+ * as otherwise the compiler might remove them but still emit calls, which would break when linking.
+ * To prevent a different linker error, we mark memcpy/memset as weak, because gnu-efi is currently
+ * providing them. */
+__attribute__((used, alias("efi_memcmp"))) int memcmp(const void *p1, const void *p2, size_t n);
+__attribute__((used, weak, alias("efi_memcpy"))) void *memcpy(void * restrict dest, const void * restrict src, size_t n);
+__attribute__((used, weak, alias("efi_memset"))) void *memset(void *p, int c, size_t n);
+#endif