]> git.ipfire.org Git - thirdparty/dracut.git/commitdiff
feat(skipcpio): speed up and harden skipcpio
authorHarald Hoyer <harald@redhat.com>
Wed, 24 Feb 2021 12:55:54 +0000 (13:55 +0100)
committerHarald Hoyer <harald@hoyer.xyz>
Thu, 25 Feb 2021 13:48:22 +0000 (14:48 +0100)
Before:
```
Benchmark #1: ./skipcpio/skipcpio test-5.10.15-200.fc33.x86_64.img >/dev/null
  Time (mean ± σ):     125.5 ms ±   0.9 ms    [User: 97.4 ms, System: 27.9 ms]
  Range (min … max):   124.8 ms … 129.4 ms    23 runs
```

After:
```
Benchmark #1: ./skipcpio/skipcpio test-5.10.15-200.fc33.x86_64.img >/dev/null
  Time (mean ± σ):      12.2 ms ±   0.3 ms    [User: 2.7 ms, System: 9.5 ms]
  Range (min … max):    11.7 ms …  13.6 ms    212 runs
```

Besides the speedup, skipcpio now parses the cpio header and is not
falsely ending when the early cpio payload contains `TRAILER!!!`.

Fixes: https://github.com/dracutdevs/dracut/issues/1123
skipcpio/skipcpio.c

index dea216c57e3839b40d816957bf3c57ee88c38a09..fd9b4e46ad77ec63f7d4e6ce50993a31ac840774 100644 (file)
 #define _GNU_SOURCE
 #endif
 
-#include <stdbool.h>
 #include <stdio.h>
 #include <stdlib.h>
-#include <unistd.h>
 #include <string.h>
 
+#define CPIO_MAGIC "070701"
 #define CPIO_END "TRAILER!!!"
-#define CPIO_ENDLEN (sizeof(CPIO_END)-1)
-
-static char buf[CPIO_ENDLEN * 2 + 1];
+#define CPIO_ENDLEN (sizeof(CPIO_END) - 1)
+
+#define CPIO_ALIGNMENT 4
+
+struct cpio_header {
+        char c_magic[6];
+        char c_ino[8];
+        char c_mode[8];
+        char c_uid[8];
+        char c_gid[8];
+        char c_nlink[8];
+        char c_mtime[8];
+        char c_filesize[8];
+        char c_dev_maj[8];
+        char c_dev_min[8];
+        char c_rdev_maj[8];
+        char c_rdev_min[8];
+        char c_namesize[8];
+        char c_chksum[8];
+} __attribute__((packed));
+
+struct buf_struct {
+        struct cpio_header h;
+        char filename[CPIO_ENDLEN];
+} __attribute__((packed));
+
+union buf_union {
+        struct buf_struct cpio;
+        char copy_buffer[2048];
+};
+
+static union buf_union buf;
+
+#define ALIGN_UP(n, a) (((n) + (a) - 1) & (~((a) - 1)))
 
 int main(int argc, char **argv)
 {
@@ -51,7 +81,7 @@ int main(int argc, char **argv)
                 exit(1);
         }
 
-        s = fread(buf, 6, 1, f);
+        s = fread(&buf.cpio, sizeof(buf.cpio), 1, f);
         if (s <= 0) {
                 fprintf(stderr, "Read error from file '%s'\n", argv[1]);
                 fclose(f);
@@ -60,25 +90,44 @@ int main(int argc, char **argv)
         fseek(f, 0, SEEK_SET);
 
         /* check, if this is a cpio archive */
-        if (buf[0] == '0' && buf[1] == '7' && buf[2] == '0' && buf[3] == '7' && buf[4] == '0' && buf[5] == '1') {
+        if (memcmp(buf.cpio.h.c_magic, CPIO_MAGIC, 6) == 0) {
+
                 long pos = 0;
 
-                /* Search for CPIO_END */
-                do {
-                        char *h;
-                        fseek(f, pos, SEEK_SET);
-                        buf[sizeof(buf) - 1] = 0;
-                        s = fread(buf, CPIO_ENDLEN, 2, f);
-                        if (s <= 0)
-                                break;
+                unsigned long filesize;
+                unsigned long filename_length;
 
-                        h = memmem(buf, sizeof(buf), CPIO_END, sizeof(CPIO_END));
-                        if (h) {
-                                pos = (h - buf) + pos + CPIO_ENDLEN;
+                do {
+                        // zero string, spilling into next unused field, to use strtol
+                        buf.cpio.h.c_chksum[0] = 0;
+                        filename_length = strtoul(buf.cpio.h.c_namesize, NULL, 16);
+                        pos = ALIGN_UP(pos + sizeof(struct cpio_header) + filename_length, CPIO_ALIGNMENT);
+
+                        // zero string, spilling into next unused field, to use strtol
+                        buf.cpio.h.c_dev_maj[0] = 0;
+                        filesize = strtoul(buf.cpio.h.c_filesize, NULL, 16);
+                        pos = ALIGN_UP(pos + filesize, CPIO_ALIGNMENT);
+
+                        if (filename_length == (CPIO_ENDLEN + 1)
+                            && strncmp(buf.cpio.filename, CPIO_END, CPIO_ENDLEN) == 0) {
                                 fseek(f, pos, SEEK_SET);
                                 break;
                         }
-                        pos += CPIO_ENDLEN;
+
+                        if (fseek(f, pos, SEEK_SET) != 0) {
+                                perror("fseek");
+                                exit(1);
+                        }
+
+                        if (fread(&buf.cpio, sizeof(buf.cpio), 1, f) != 1) {
+                                perror("fread");
+                                exit(1);
+                        }
+
+                        if (memcmp(buf.cpio.h.c_magic, CPIO_MAGIC, 6) != 0) {
+                                fprintf(stderr, "Corrupt CPIO archive!\n");
+                                exit(1);
+                        }
                 } while (!feof(f));
 
                 if (feof(f)) {
@@ -86,33 +135,33 @@ int main(int argc, char **argv)
                         fseek(f, 0, SEEK_SET);
                 } else {
                         /* skip zeros */
-                        while (!feof(f)) {
+                        do {
                                 size_t i;
 
-                                buf[sizeof(buf) - 1] = 0;
-                                s = fread(buf, 1, sizeof(buf) - 1, f);
+                                s = fread(buf.copy_buffer, 1, sizeof(buf.copy_buffer) - 1, f);
                                 if (s <= 0)
                                         break;
 
-                                for (i = 0; (i < s) && (buf[i] == 0); i++) ;
+                                for (i = 0; (i < s) && (buf.copy_buffer[i] == 0); i++) ;
 
-                                if (buf[i] != 0) {
+                                if (buf.copy_buffer[i] != 0) {
                                         pos += i;
+
                                         fseek(f, pos, SEEK_SET);
                                         break;
                                 }
 
                                 pos += s;
-                        }
+                        } while (!feof(f));
                 }
         }
         /* cat out the rest */
         while (!feof(f)) {
-                s = fread(buf, 1, sizeof(buf), f);
+                s = fread(buf.copy_buffer, 1, sizeof(buf.copy_buffer), f);
                 if (s <= 0)
                         break;
 
-                s = fwrite(buf, 1, s, stdout);
+                s = fwrite(buf.copy_buffer, 1, s, stdout);
                 if (s <= 0)
                         break;
         }