#include <err.h>
#include <errno.h>
+#include <stdlib.h>
#include <unistd.h>
+#include <netinet/in.h>
+
+static int read_exact(int, unsigned char *, size_t);
+static int read_and_waste(int, unsigned char *, size_t, u_int64_t);
+static int get_octets(rtr_char);
+static void place_null_character(rtr_char *, size_t);
static int
-read_exact(int fd, unsigned char *buffer, size_t length)
+read_exact(int fd, unsigned char *buffer, size_t buffer_len)
{
- int n, m;
+ ssize_t read_result;
+ size_t offset;
int err;
- for (n = 0; n < length;) {
- m = read(fd, &buffer[n], length - n);
- if (m < 0) {
+ for (offset = 0; offset < buffer_len; offset += read_result) {
+ read_result = read(fd, &buffer[offset], buffer_len - offset);
+ if (read_result == -1) {
err = errno;
warn("Client socket read interrupted");
return err;
}
-
- if (m == 0 && n == 0) {
- /* Stream ended gracefully. */
- return 0;
- }
-
- if (m == 0) {
- err = -EPIPE;
+ if (read_result == 0) {
warn("Stream ended mid-PDU");
- return err;
+ return -EPIPE;
}
-
- n += m;
}
return 0;
|| read_int32(fd, &result->s6_addr32[3]);
}
-int
-read_string(int fd, char **result)
+/*
+ * Consumes precisely @total_len bytes from @fd.
+ * The first @str_len bytes are stored in @str.
+ *
+ * It is required that @str_len <= @total_len.
+ */
+static int
+read_and_waste(int fd, unsigned char *str, size_t str_len, u_int64_t total_len)
{
- u_int32_t length;
+#define TLEN 1024 /* "Trash length" */
+ unsigned char trash[TLEN];
+ size_t offset;
int err;
- err = read_int32(fd, &length);
+ err = read_exact(fd, str, str_len);
if (err)
return err;
+ for (offset = str_len; (offset + TLEN) < total_len; offset += TLEN) {
+ err = read_exact(fd, trash, TLEN);
+ if (err)
+ return err;
+ }
+
+ return read_exact(fd, trash, total_len - offset);
+#undef TLEN
+}
+
+#define EINVALID_UTF8 -0xFFFF
+
+/*
+ * Returns the length (in octets) of the UTF-8 code point that starts with octet
+ * @first_octet.
+ */
+static int
+get_octets(rtr_char first_octet)
+{
+ if ((first_octet & 0xC0) == 0)
+ return 1;
+ if ((first_octet >> 5) == 6) /* 0b110 */
+ return 2;
+ if ((first_octet >> 4) == 14) /* 0b1110 */
+ return 3;
+ if ((first_octet >> 3) == 30) /* 0b11110 */
+ return 4;
+ return EINVALID_UTF8;
+}
+
+/*
+ * This also sanitizes the string, BTW.
+ * (Because it places the null chara in the first invalid character.
+ * The rest is silently ignored.)
+ *
+ * TODO test the hell out of this.
+ */
+static void
+place_null_character(rtr_char *str, size_t len)
+{
+ rtr_char *null_chara_pos;
+ rtr_char *cursor;
+ int octet;
+ int octets;
+
/*
- * TODO the RFC doesn't say if the length is in bytes, code points or
- * graphemes...
+ * This could be optimized by noticing that all byte continuations in
+ * UTF-8 start with 0b10. This means that we could start from the end
+ * of the string and move left until we find a valid character.
+ * But if we do that, we'd lose the sanitization. So this is better
+ * methinks.
*/
- *result = NULL;
+
+ null_chara_pos = str;
+ cursor = str;
+
+ while (cursor < str + len) {
+ octets = get_octets(*cursor);
+ if (octets == EINVALID_UTF8)
+ break;
+ for (octet = 1; octet < octets; octet++) {
+ if (cursor >= str + len - 1 || cursor[1] >> 6 != 0x10)
+ break;
+ cursor++;
+ }
+
+ null_chara_pos = cursor;
+ }
+
+ *null_chara_pos = '\0';
+}
+
+int
+read_string(int fd, rtr_char **result)
+{
+ /* Actual string length claimed by the PDU, in octets. */
+ u_int32_t full_length32; /* Excludes the null chara */
+ u_int64_t full_length64; /* Includes the null chara */
+ /*
+ * Actual length that we allocate. Octets.
+ * This exists because there might be value in truncating the string;
+ * full_length is a fucking 32-bit integer for some reason.
+ * Note that, because this is UTF-8 we're dealing with, this might not
+ * necessarily end up being the actual octet length of the final string;
+ * since our truncation can land in the middle of a code point, the null
+ * character might need to be shifted left slightly.
+ */
+ size_t alloc_length; /* Includes the null chara */
+ rtr_char *str;
+ int err;
+
+ err = read_int32(fd, &full_length32);
+ if (err)
+ return err;
+ full_length64 = ((u_int64_t) full_length32) + 1;
+
+ alloc_length = (full_length64 > 4096) ? 4096 : full_length64;
+ str = malloc(alloc_length);
+ if (!str)
+ return -ENOMEM;
+
+ err = read_and_waste(fd, str, alloc_length - 1, full_length64);
+ if (err) {
+ free(str);
+ return err;
+ }
+
+ place_null_character(str, alloc_length);
+
+ *result = str;
return 0;
}