]> git.ipfire.org Git - thirdparty/bind9.git/commitdiff
Stream DNS: avoid memory copying/buffer resizing when reading data
authorArtem Boldariev <artem@boldariev.com>
Mon, 26 Dec 2022 15:42:49 +0000 (17:42 +0200)
committerOndřej Surý <ondrej@isc.org>
Mon, 3 Apr 2023 13:31:46 +0000 (13:31 +0000)
This commit optimises isc_dnsstream_assembler_t in such a way that
memory copying and reallocation are avoided when receiving one or more
complete DNS messages at once. We try to handle the data from the
messages directly, without storing them in an intermediate memory
buffer.

lib/isc/include/isc/dnsstream.h

index aab4834aeaf9b37700eace5e9359dca170f2a55f..1e6a26b08697d7d54fef6a921fdaa1a98a134e32 100644 (file)
@@ -92,7 +92,11 @@ message (i.e. someone attempts to send us junk data).
 struct isc_dnsstream_assembler {
        isc_buffer_t dnsbuf; /*!< Internal buffer for assembling DNS
                                   messages. */
-       uint8_t                      buf[ISC_DNSSTREAM_STATIC_BUFFER_SIZE];
+       uint8_t       buf[ISC_DNSSTREAM_STATIC_BUFFER_SIZE];
+       isc_buffer_t *current; /*!< Pointer to the currently used data buffer.
+                                 Most of the time it point to the 'dnsbuf'
+                                 except when dealing with data in place (when
+                                 it points to a temporary buffer) */
        isc_dnsstream_assembler_cb_t onmsg_cb; /*!< Data processing callback. */
        void                        *cbarg;    /*!< Callback argument. */
        bool calling_cb; /*<! Callback calling marker. Used to detect recursive
@@ -233,6 +237,8 @@ isc_dnsstream_assembler_init(isc_dnsstream_assembler_t *restrict dnsasm,
 
        isc_buffer_init(&dnsasm->dnsbuf, dnsasm->buf, sizeof(dnsasm->buf));
        isc_buffer_setmctx(&dnsasm->dnsbuf, dnsasm->mctx);
+
+       dnsasm->current = &dnsasm->dnsbuf;
 }
 
 static inline void
@@ -248,6 +254,7 @@ isc_dnsstream_assembler_uninit(isc_dnsstream_assembler_t *restrict dnsasm) {
        if (dnsasm->mctx != NULL) {
                isc_mem_detach(&dnsasm->mctx);
        }
+       dnsasm->current = NULL;
 }
 
 static inline isc_dnsstream_assembler_t *
@@ -288,6 +295,20 @@ isc_dnsstream_assembler_setcb(isc_dnsstream_assembler_t *restrict dnsasm,
        dnsasm->cbarg = cbarg;
 }
 
+static inline bool
+isc__dnsstream_assembler_callcb(isc_dnsstream_assembler_t *restrict dnsasm,
+                               const isc_result_t result,
+                               isc_region_t *restrict region, void *userarg) {
+       bool ret;
+
+       dnsasm->result = result;
+       dnsasm->calling_cb = true;
+       ret = dnsasm->onmsg_cb(dnsasm, result, region, dnsasm->cbarg, userarg);
+       dnsasm->calling_cb = false;
+
+       return (ret);
+}
+
 static inline bool
 isc__dnsstream_assembler_handle_message(
        isc_dnsstream_assembler_t *restrict dnsasm, void *userarg) {
@@ -298,25 +319,21 @@ isc__dnsstream_assembler_handle_message(
 
        INSIST(dnsasm->calling_cb == false);
 
-       result = isc_buffer_peekuint16(&dnsasm->dnsbuf, &dnslen);
+       result = isc_buffer_peekuint16(dnsasm->current, &dnslen);
 
        switch (result) {
        case ISC_R_SUCCESS:
                if (dnslen == 0) {
-                       /* This didn't make much sense to me: */
-                       /* isc_buffer_remaininglength(&dnsasm->dnsbuf) >=
-                        * sizeof(uint16_t) && */
-
                        /*
                         * Someone seems to send us binary junk or output from
                         * /dev/zero
                         */
                        result = ISC_R_RANGE;
-                       isc_buffer_clear(&dnsasm->dnsbuf);
+                       isc_dnsstream_assembler_clear(dnsasm);
                        break;
                }
 
-               if (dnslen > (isc_buffer_remaininglength(&dnsasm->dnsbuf) -
+               if (dnslen > (isc_buffer_remaininglength(dnsasm->current) -
                              sizeof(uint16_t)))
                {
                        result = ISC_R_NOMORE;
@@ -329,27 +346,144 @@ isc__dnsstream_assembler_handle_message(
                UNREACHABLE();
        }
 
-       dnsasm->result = result;
-       dnsasm->calling_cb = true;
        if (result == ISC_R_SUCCESS) {
-               (void)isc_buffer_getuint16(&dnsasm->dnsbuf);
-               isc_buffer_remainingregion(&dnsasm->dnsbuf, &region);
+               (void)isc_buffer_getuint16(dnsasm->current);
+               isc_buffer_remainingregion(dnsasm->current, &region);
                region.length = dnslen;
-               cont = dnsasm->onmsg_cb(dnsasm, ISC_R_SUCCESS, &region,
-                                       dnsasm->cbarg, userarg);
-               if (isc_buffer_remaininglength(&dnsasm->dnsbuf) >= dnslen) {
-                       isc_buffer_forward(&dnsasm->dnsbuf, dnslen);
+               cont = isc__dnsstream_assembler_callcb(dnsasm, result, &region,
+                                                      userarg);
+               if (isc_buffer_remaininglength(dnsasm->current) >= dnslen) {
+                       isc_buffer_forward(dnsasm->current, dnslen);
                }
        } else {
                cont = false;
-               (void)dnsasm->onmsg_cb(dnsasm, result, NULL, dnsasm->cbarg,
-                                      userarg);
+               (void)isc__dnsstream_assembler_callcb(dnsasm, result, NULL,
+                                                     userarg);
        }
-       dnsasm->calling_cb = false;
 
        return (cont);
 }
 
+static inline void
+isc__dnsstream_assembler_processing(isc_dnsstream_assembler_t *restrict dnsasm,
+                                   void *userarg) {
+       while (isc__dnsstream_assembler_handle_message(dnsasm, userarg)) {
+               if (isc_buffer_remaininglength(dnsasm->current) == 0) {
+                       break;
+               }
+       }
+}
+
+static inline void
+isc__dnsstream_assembler_incoming_direct(
+       isc_dnsstream_assembler_t *restrict dnsasm, void *userarg,
+       void *restrict buf, const unsigned int            buf_size) {
+       isc_buffer_t data = { 0 };
+       isc_region_t remaining = { 0 };
+       INSIST(dnsasm->current == &dnsasm->dnsbuf);
+
+       isc_buffer_init(&data, buf, buf_size);
+       isc_buffer_add(&data, buf_size);
+
+       /*
+        * Replace the internal buffer within the assembler
+        * object with a temporary buffer referring to the
+        * passed data directly.
+        */
+       dnsasm->current = &data;
+
+       /* process the data internally */
+       isc__dnsstream_assembler_processing(dnsasm, userarg);
+
+       /* set the internal buffer back */
+       dnsasm->current = &dnsasm->dnsbuf;
+
+       isc_buffer_remainingregion(&data, &remaining);
+       if (remaining.length != 0) {
+               /*
+                * Some unprocessed data left - let's put it
+                * into the internal buffer for processing
+                * later.
+                */
+               isc_buffer_putmem(dnsasm->current, remaining.base,
+                                 remaining.length);
+       }
+}
+
+static inline bool
+isc__dnsstream_assembler_incoming_direct_non_empty(
+       isc_dnsstream_assembler_t *restrict dnsasm, void *userarg,
+       void *restrict buf, unsigned int                  buf_size) {
+       size_t   remaining;
+       uint16_t dnslen = 0;
+       size_t   until_complete = 0;
+       size_t   remaining_no_len;
+
+       if (isc_buffer_peekuint16(dnsasm->current, &dnslen) != ISC_R_SUCCESS) {
+               return (false);
+       }
+
+       remaining = isc_buffer_remaininglength(dnsasm->current);
+       remaining_no_len = remaining - sizeof(uint16_t);
+
+       /*
+        * We have data for more than one DNS message - that means that on
+        * previous iteration we stopped prematurely intentionally.
+        */
+       if (remaining_no_len >= dnslen) {
+               return (false);
+       }
+
+       /*
+        * At this point we know that we have incomplete message in the
+        * internal buffer, let's find how much data do we need to
+        * complete the message and then check if we have enough data to
+        * handle it.
+        */
+       until_complete = dnslen - remaining_no_len;
+
+       if (buf_size >= until_complete) {
+               bool     cont;
+               uint8_t *unprocessed_buf = NULL;
+               size_t   unprocessed_size;
+
+               isc_buffer_putmem(dnsasm->current, buf, until_complete);
+               unprocessed_buf = ((uint8_t *)buf + until_complete);
+               unprocessed_size = buf_size - until_complete;
+
+               /* handle the message */
+               cont = isc__dnsstream_assembler_handle_message(dnsasm, userarg);
+               isc_buffer_trycompact(dnsasm->current);
+
+               INSIST(isc_buffer_remaininglength(dnsasm->current) == 0);
+               if (unprocessed_size == 0) {
+                       return (true);
+               }
+
+               if (cont) {
+                       /*
+                        * The callback logic told us to continue processing
+                        * messages, let's try to process the rest directly.
+                        */
+                       isc__dnsstream_assembler_incoming_direct(
+                               dnsasm, userarg, unprocessed_buf,
+                               unprocessed_size);
+               } else {
+                       /*
+                        * The callback logic told us to stop, let's copy the
+                        * remaining data into the internal buffer to process it
+                        * later.
+                        */
+                       isc_buffer_putmem(dnsasm->current, unprocessed_buf,
+                                         unprocessed_size);
+               }
+
+               return (true);
+       }
+
+       return (false);
+}
+
 static inline void
 isc_dnsstream_assembler_incoming(isc_dnsstream_assembler_t *restrict dnsasm,
                                 void              *userarg, void *restrict buf,
@@ -360,16 +494,74 @@ isc_dnsstream_assembler_incoming(isc_dnsstream_assembler_t *restrict dnsasm,
        if (buf_size == 0) {
                INSIST(buf == NULL);
        } else {
+               size_t remaining;
+
                INSIST(buf != NULL);
-               isc_buffer_putmem(&dnsasm->dnsbuf, buf, buf_size);
-       }
 
-       while (isc__dnsstream_assembler_handle_message(dnsasm, userarg)) {
-               if (isc_buffer_remaininglength(&dnsasm->dnsbuf) == 0) {
-                       break;
+               remaining = isc_buffer_remaininglength(&dnsasm->dnsbuf);
+
+               if (remaining == 0) {
+                       /*
+                        * We can try to handle messages in-place (without
+                        * memory copying/re-allocation) in the case we have no
+                        * other data in the internal buffer and have received
+                        * one or more complete messages at once. This way we
+                        * can avoid copying memory into the assembler's
+                        * internal buffer.
+                        */
+                       isc__dnsstream_assembler_incoming_direct(
+                               dnsasm, userarg, buf, buf_size);
+                       return;
+               } else if (isc__dnsstream_assembler_incoming_direct_non_empty(
+                                  dnsasm, userarg, buf, buf_size))
+               {
+                       /*
+                        * We had incomplete message in the buffer, but received
+                        * enough data to handle it. After that we handle the
+.                       * rest (if any) of the messages directly without
+                        * copying into the internal buffer. Any data, belonging
+                        * to incomplete messages at the end of the buffer, was
+                        * copied into the internal buffer to be processed later
+                        * when receiving the next batch of data.
+                        */
+                       return;
+               } else if (remaining == 1 && buf_size > 0) {
+                       /* Mostly the same case as above, but we have incomplete
+                        * message length in the buffer and received at least
+                        * one byte to complete it.
+                        */
+                       void  *unprocessed_buf = NULL;
+                       size_t unprocessed_size;
+
+                       isc_buffer_putmem(dnsasm->current, buf, 1);
+                       unprocessed_buf = (uint8_t *)buf + 1;
+                       unprocessed_size = buf_size - 1;
+
+                       if (isc__dnsstream_assembler_incoming_direct_non_empty(
+                                   dnsasm, userarg, unprocessed_buf,
+                                   unprocessed_size))
+                       {
+                               return;
+                       }
+
+                       if (buf_size > 0) {
+                               isc_buffer_putmem(dnsasm->current,
+                                                 unprocessed_buf,
+                                                 unprocessed_size);
+                       }
+                       /* let's continue processing via the generic path */
+               } else {
+                       /*
+                        * Put the data into the internal buffer for
+                        * processing.
+                        */
+                       isc_buffer_putmem(dnsasm->current, buf, buf_size);
                }
        }
-       isc_buffer_trycompact(&dnsasm->dnsbuf);
+
+       isc__dnsstream_assembler_processing(dnsasm, userarg);
+
+       isc_buffer_trycompact(dnsasm->current);
 }
 
 static inline isc_result_t
@@ -385,13 +577,16 @@ isc_dnsstream_assembler_remaininglength(
        const isc_dnsstream_assembler_t *restrict dnsasm) {
        REQUIRE(dnsasm != NULL);
 
-       return (isc_buffer_remaininglength(&dnsasm->dnsbuf));
+       return (isc_buffer_remaininglength(dnsasm->current));
 }
 
 static inline void
 isc_dnsstream_assembler_clear(isc_dnsstream_assembler_t *restrict dnsasm) {
        REQUIRE(dnsasm != NULL);
 
-       isc_buffer_clear(&dnsasm->dnsbuf);
+       isc_buffer_clear(dnsasm->current);
+       if (dnsasm->current != &dnsasm->dnsbuf) {
+               isc_buffer_clear(&dnsasm->dnsbuf);
+       }
        dnsasm->result = ISC_R_UNSET;
 }