]> git.ipfire.org Git - thirdparty/libvirt.git/commitdiff
rpc: Switch to dynamically allocated message buffer
authorMichal Privoznik <mprivozn@redhat.com>
Thu, 26 Apr 2012 15:21:24 +0000 (17:21 +0200)
committerMichal Privoznik <mprivozn@redhat.com>
Tue, 5 Jun 2012 15:48:40 +0000 (17:48 +0200)
Currently, we are allocating buffer for RPC messages statically.
This is not such pain when RPC limits are small. However, if we want
ever to increase those limits, we need to allocate buffer dynamically,
based on RPC message len (= the first 4 bytes). Therefore we will
decrease our mem usage in most cases and still be flexible enough in
corner cases.

src/rpc/virnetclient.c
src/rpc/virnetmessage.c
src/rpc/virnetmessage.h
src/rpc/virnetserverclient.c
tests/virnetmessagetest.c

index d88288d92012e9f84876af90e43ae082d57a5169..14f806f99b2d89a470f50a4eb4ce999ff2f5dbfd 100644 (file)
@@ -801,7 +801,12 @@ virNetClientCallDispatchReply(virNetClientPtr client)
         return -1;
     }
 
-    memcpy(thecall->msg->buffer, client->msg.buffer, sizeof(client->msg.buffer));
+    if (VIR_REALLOC_N(thecall->msg->buffer, client->msg.bufferLength) < 0) {
+        virReportOOMError();
+        return -1;
+    }
+
+    memcpy(thecall->msg->buffer, client->msg.buffer, client->msg.bufferLength);
     memcpy(&thecall->msg->header, &client->msg.header, sizeof(client->msg.header));
     thecall->msg->bufferLength = client->msg.bufferLength;
     thecall->msg->bufferOffset = client->msg.bufferOffset;
@@ -987,6 +992,7 @@ virNetClientIOWriteMessage(virNetClientPtr client,
         }
         thecall->msg->donefds = 0;
         thecall->msg->bufferOffset = thecall->msg->bufferLength = 0;
+        VIR_FREE(thecall->msg->buffer);
         if (thecall->expectReply)
             thecall->mode = VIR_NET_CLIENT_MODE_WAIT_RX;
         else
@@ -1030,8 +1036,13 @@ virNetClientIOReadMessage(virNetClientPtr client)
     ssize_t ret;
 
     /* Start by reading length word */
-    if (client->msg.bufferLength == 0)
+    if (client->msg.bufferLength == 0) {
         client->msg.bufferLength = 4;
+        if (VIR_ALLOC_N(client->msg.buffer, client->msg.bufferLength) < 0) {
+            virReportOOMError();
+            return -ENOMEM;
+        }
+    }
 
     wantData = client->msg.bufferLength - client->msg.bufferOffset;
 
@@ -1108,6 +1119,7 @@ virNetClientIOHandleInput(virNetClientPtr client)
 
                 ret = virNetClientCallDispatch(client);
                 client->msg.bufferOffset = client->msg.bufferLength = 0;
+                VIR_FREE(client->msg.buffer);
                 /*
                  * We've completed one call, but we don't want to
                  * spin around the loop forever if there are many
index 17ecc90eb708a569f075f84de142a5bb6518a2d5..dc4c2127a8c38116218405c7d79de2e0be6a3236 100644 (file)
@@ -61,6 +61,7 @@ void virNetMessageClear(virNetMessagePtr msg)
     for (i = 0 ; i < msg->nfds ; i++)
         VIR_FORCE_CLOSE(msg->fds[i]);
     VIR_FREE(msg->fds);
+    VIR_FREE(msg->buffer);
     memset(msg, 0, sizeof(*msg));
     msg->tracked = tracked;
 }
@@ -79,6 +80,7 @@ void virNetMessageFree(virNetMessagePtr msg)
 
     for (i = 0 ; i < msg->nfds ; i++)
         VIR_FORCE_CLOSE(msg->fds[i]);
+    VIR_FREE(msg->buffer);
     VIR_FREE(msg->fds);
     VIR_FREE(msg);
 }
@@ -144,6 +146,10 @@ int virNetMessageDecodeLength(virNetMessagePtr msg)
     /* Extend our declared buffer length and carry
        on reading the header + payload */
     msg->bufferLength += len;
+    if (VIR_REALLOC_N(msg->buffer, msg->bufferLength) < 0) {
+        virReportOOMError();
+        goto cleanup;
+    }
 
     VIR_DEBUG("Got length, now need %zu total (%u more)",
               msg->bufferLength, len);
@@ -212,7 +218,11 @@ int virNetMessageEncodeHeader(virNetMessagePtr msg)
     int ret = -1;
     unsigned int len = 0;
 
-    msg->bufferLength = sizeof(msg->buffer);
+    msg->bufferLength = VIR_NET_MESSAGE_MAX + VIR_NET_MESSAGE_LEN_MAX;
+    if (VIR_ALLOC_N(msg->buffer, msg->bufferLength) < 0) {
+        virReportOOMError();
+        goto cleanup;
+    }
     msg->bufferOffset = 0;
 
     /* Format the header. */
index c54e7c6ea55c35f1746c3f1909dfaf488f676846..8f36a70f3908c33aed009ffe8e589de08d2f80d4 100644 (file)
@@ -31,13 +31,10 @@ typedef virNetMessage *virNetMessagePtr;
 
 typedef void (*virNetMessageFreeCallback)(virNetMessagePtr msg, void *opaque);
 
-/* Never allocate this (huge) buffer on the stack. Always
- * use virNetMessageNew() to allocate on the heap
- */
 struct _virNetMessage {
     bool tracked;
 
-    char buffer[VIR_NET_MESSAGE_MAX + VIR_NET_MESSAGE_LEN_MAX];
+    char *buffer; /* Typically VIR_NET_MESSAGE_MAX + VIR_NET_MESSAGE_LEN_MAX */
     size_t bufferLength;
     size_t bufferOffset;
 
index 67600fd00af2f323f7a1054ab944deb90b0a1293..6ae4e252096adef256026d0b18ce2a56d6613cc6 100644 (file)
@@ -313,6 +313,11 @@ virNetServerClientCheckAccess(virNetServerClientPtr client)
      * (NB. The '\1' byte is sent in an encrypted record).
      */
     confirm->bufferLength = 1;
+    if (VIR_ALLOC_N(confirm->buffer, confirm->bufferLength) < 0) {
+        virReportOOMError();
+        virNetMessageFree(confirm);
+        return -1;
+    }
     confirm->bufferOffset = 0;
     confirm->buffer[0] = '\1';
 
@@ -373,6 +378,10 @@ virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock,
     if (!(client->rx = virNetMessageNew(true)))
         goto error;
     client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
+    if (VIR_ALLOC_N(client->rx->buffer, client->rx->bufferLength) < 0) {
+        virReportOOMError();
+        goto error;
+    }
     client->nrequests = 1;
 
     PROBE(RPC_SERVER_CLIENT_NEW,
@@ -922,7 +931,13 @@ readmore:
                 client->wantClose = true;
             } else {
                 client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
-                client->nrequests++;
+                if (VIR_ALLOC_N(client->rx->buffer,
+                                client->rx->bufferLength) < 0) {
+                    virReportOOMError();
+                    client->wantClose = true;
+                } else {
+                    client->nrequests++;
+                }
             }
         }
         virNetServerClientUpdateEvent(client);
@@ -1019,8 +1034,13 @@ virNetServerClientDispatchWrite(virNetServerClientPtr client)
                     client->nrequests < client->nrequests_max) {
                     /* Ready to recv more messages */
                     virNetMessageClear(msg);
+                    msg->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
+                    if (VIR_ALLOC_N(msg->buffer, msg->bufferLength) < 0) {
+                        virReportOOMError();
+                        virNetMessageFree(msg);
+                        return;
+                    }
                     client->rx = msg;
-                    client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
                     msg = NULL;
                     client->nrequests++;
                 }
index 28dc09f9ae2c9bcc434b198bd8cec320f9b3cd97..6c294ca4452cf43763cb81c0944e647b996d2206 100644 (file)
@@ -35,7 +35,7 @@
 
 static int testMessageHeaderEncode(const void *args ATTRIBUTE_UNUSED)
 {
-    static virNetMessage msg;
+    virNetMessagePtr msg = virNetMessageNew(true);
     static const char expect[] = {
         0x00, 0x00, 0x00, 0x1c,  /* Length */
         0x11, 0x22, 0x33, 0x44,  /* Program */
@@ -45,128 +45,153 @@ static int testMessageHeaderEncode(const void *args ATTRIBUTE_UNUSED)
         0x00, 0x00, 0x00, 0x99,  /* Serial */
         0x00, 0x00, 0x00, 0x00,  /* Status */
     };
-    memset(&msg, 0, sizeof(msg));
-
-    msg.header.prog = 0x11223344;
-    msg.header.vers = 0x01;
-    msg.header.proc = 0x666;
-    msg.header.type = VIR_NET_CALL;
-    msg.header.serial = 0x99;
-    msg.header.status = VIR_NET_OK;
+    /* According to doc to virNetMessageEncodeHeader(&msg):
+     * msg->buffer will be this long */
+    unsigned long msg_buf_size = VIR_NET_MESSAGE_MAX + VIR_NET_MESSAGE_LEN_MAX;
+    int ret = -1;
 
-    if (virNetMessageEncodeHeader(&msg) < 0)
+    if (!msg) {
+        virReportOOMError();
         return -1;
+    }
+
+    msg->header.prog = 0x11223344;
+    msg->header.vers = 0x01;
+    msg->header.proc = 0x666;
+    msg->header.type = VIR_NET_CALL;
+    msg->header.serial = 0x99;
+    msg->header.status = VIR_NET_OK;
+
+    if (virNetMessageEncodeHeader(msg) < 0)
+        goto cleanup;
 
-    if (ARRAY_CARDINALITY(expect) != msg.bufferOffset) {
+    if (ARRAY_CARDINALITY(expect) != msg->bufferOffset) {
         VIR_DEBUG("Expect message offset %zu got %zu",
-                  sizeof(expect), msg.bufferOffset);
-        return -1;
+                  sizeof(expect), msg->bufferOffset);
+        goto cleanup;
     }
 
-    if (msg.bufferLength != sizeof(msg.buffer)) {
+    if (msg->bufferLength != msg_buf_size) {
         VIR_DEBUG("Expect message offset %zu got %zu",
-                  sizeof(msg.buffer), msg.bufferLength);
-        return -1;
+                  msg_buf_size, msg->bufferLength);
+        goto cleanup;
     }
 
-    if (memcmp(expect, msg.buffer, sizeof(expect)) != 0) {
-        virtTestDifferenceBin(stderr, expect, msg.buffer, sizeof(expect));
-        return -1;
+    if (memcmp(expect, msg->buffer, sizeof(expect)) != 0) {
+        virtTestDifferenceBin(stderr, expect, msg->buffer, sizeof(expect));
+        goto cleanup;
     }
 
-    return 0;
+    ret = 0;
+cleanup:
+    virNetMessageFree(msg);
+    return ret;
 }
 
 static int testMessageHeaderDecode(const void *args ATTRIBUTE_UNUSED)
 {
-    static virNetMessage msg = {
-        .bufferOffset = 0,
-        .bufferLength = 0x4,
-        .buffer = {
-            0x00, 0x00, 0x00, 0x1c,  /* Length */
-            0x11, 0x22, 0x33, 0x44,  /* Program */
-            0x00, 0x00, 0x00, 0x01,  /* Version */
-            0x00, 0x00, 0x06, 0x66,  /* Procedure */
-            0x00, 0x00, 0x00, 0x01,  /* Type */
-            0x00, 0x00, 0x00, 0x99,  /* Serial */
-            0x00, 0x00, 0x00, 0x01,  /* Status */
-        },
-        .header = { 0, 0, 0, 0, 0, 0 },
+    virNetMessagePtr msg = virNetMessageNew(true);
+    static char input_buf [] =  {
+        0x00, 0x00, 0x00, 0x1c,  /* Length */
+        0x11, 0x22, 0x33, 0x44,  /* Program */
+        0x00, 0x00, 0x00, 0x01,  /* Version */
+        0x00, 0x00, 0x06, 0x66,  /* Procedure */
+        0x00, 0x00, 0x00, 0x01,  /* Type */
+        0x00, 0x00, 0x00, 0x99,  /* Serial */
+        0x00, 0x00, 0x00, 0x01,  /* Status */
     };
+    int ret = -1;
+
+    if (!msg) {
+        virReportOOMError();
+        return -1;
+    }
+
+    msg->bufferLength = 4;
+    if (VIR_ALLOC_N(msg->buffer, msg->bufferLength) < 0) {
+        virReportOOMError();
+        goto cleanup;
+    }
+    memcpy(msg->buffer, input_buf, msg->bufferLength);
 
-    msg.header.prog = 0x11223344;
-    msg.header.vers = 0x01;
-    msg.header.proc = 0x666;
-    msg.header.type = VIR_NET_CALL;
-    msg.header.serial = 0x99;
-    msg.header.status = VIR_NET_OK;
+    msg->header.prog = 0x11223344;
+    msg->header.vers = 0x01;
+    msg->header.proc = 0x666;
+    msg->header.type = VIR_NET_CALL;
+    msg->header.serial = 0x99;
+    msg->header.status = VIR_NET_OK;
 
-    if (virNetMessageDecodeLength(&msg) < 0) {
+    if (virNetMessageDecodeLength(msg) < 0) {
         VIR_DEBUG("Failed to decode message header");
-        return -1;
+        goto cleanup;
     }
 
-    if (msg.bufferOffset != 0x4) {
+    if (msg->bufferOffset != 0x4) {
         VIR_DEBUG("Expecting offset %zu got %zu",
-                  (size_t)4, msg.bufferOffset);
-        return -1;
+                  (size_t)4, msg->bufferOffset);
+        goto cleanup;
     }
 
-    if (msg.bufferLength != 0x1c) {
+    if (msg->bufferLength != 0x1c) {
         VIR_DEBUG("Expecting length %zu got %zu",
-                  (size_t)0x1c, msg.bufferLength);
-        return -1;
+                  (size_t)0x1c, msg->bufferLength);
+        goto cleanup;
     }
 
-    if (virNetMessageDecodeHeader(&msg) < 0) {
+    memcpy(msg->buffer, input_buf, msg->bufferLength);
+
+    if (virNetMessageDecodeHeader(msg) < 0) {
         VIR_DEBUG("Failed to decode message header");
-        return -1;
+        goto cleanup;
     }
 
-    if (msg.bufferOffset != msg.bufferLength) {
+    if (msg->bufferOffset != msg->bufferLength) {
         VIR_DEBUG("Expect message offset %zu got %zu",
-                  msg.bufferOffset, msg.bufferLength);
-        return -1;
+                  msg->bufferOffset, msg->bufferLength);
+        goto cleanup;
     }
 
-    if (msg.header.prog != 0x11223344) {
+    if (msg->header.prog != 0x11223344) {
         VIR_DEBUG("Expect prog %d got %d",
-                  0x11223344, msg.header.prog);
-        return -1;
+                  0x11223344, msg->header.prog);
+        goto cleanup;
     }
-    if (msg.header.vers != 0x1) {
+    if (msg->header.vers != 0x1) {
         VIR_DEBUG("Expect vers %d got %d",
-                  0x11223344, msg.header.vers);
-        return -1;
+                  0x11223344, msg->header.vers);
+        goto cleanup;
     }
-    if (msg.header.proc != 0x666) {
+    if (msg->header.proc != 0x666) {
         VIR_DEBUG("Expect proc %d got %d",
-                  0x666, msg.header.proc);
-        return -1;
+                  0x666, msg->header.proc);
+        goto cleanup;
     }
-    if (msg.header.type != VIR_NET_REPLY) {
+    if (msg->header.type != VIR_NET_REPLY) {
         VIR_DEBUG("Expect type %d got %d",
-                  VIR_NET_REPLY, msg.header.type);
-        return -1;
+                  VIR_NET_REPLY, msg->header.type);
+        goto cleanup;
     }
-    if (msg.header.serial != 0x99) {
+    if (msg->header.serial != 0x99) {
         VIR_DEBUG("Expect serial %d got %d",
-                  0x99, msg.header.serial);
-        return -1;
+                  0x99, msg->header.serial);
+        goto cleanup;
     }
-    if (msg.header.status != VIR_NET_ERROR) {
+    if (msg->header.status != VIR_NET_ERROR) {
         VIR_DEBUG("Expect status %d got %d",
-                  VIR_NET_ERROR, msg.header.status);
-        return -1;
+                  VIR_NET_ERROR, msg->header.status);
+        goto cleanup;
     }
 
-    return 0;
+    ret = 0;
+cleanup:
+    virNetMessageFree(msg);
+    return ret;
 }
 
 static int testMessagePayloadEncode(const void *args ATTRIBUTE_UNUSED)
 {
     virNetMessageError err;
-    static virNetMessage msg;
+    virNetMessagePtr msg = virNetMessageNew(true);
     int ret = -1;
     static const char expect[] = {
         0x00, 0x00, 0x00, 0x74,  /* Length */
@@ -200,7 +225,12 @@ static int testMessagePayloadEncode(const void *args ATTRIBUTE_UNUSED)
         0x00, 0x00, 0x00, 0x02,  /* Error int2 */
         0x00, 0x00, 0x00, 0x00,  /* Error network pointer */
     };
-    memset(&msg, 0, sizeof(msg));
+
+    if (!msg) {
+        virReportOOMError();
+        return -1;
+    }
+
     memset(&err, 0, sizeof(err));
 
     err.code = VIR_ERR_INTERNAL_ERROR;
@@ -223,33 +253,33 @@ static int testMessagePayloadEncode(const void *args ATTRIBUTE_UNUSED)
     err.int1 = 1;
     err.int2 = 2;
 
-    msg.header.prog = 0x11223344;
-    msg.header.vers = 0x01;
-    msg.header.proc = 0x666;
-    msg.header.type = VIR_NET_MESSAGE;
-    msg.header.serial = 0x99;
-    msg.header.status = VIR_NET_ERROR;
+    msg->header.prog = 0x11223344;
+    msg->header.vers = 0x01;
+    msg->header.proc = 0x666;
+    msg->header.type = VIR_NET_MESSAGE;
+    msg->header.serial = 0x99;
+    msg->header.status = VIR_NET_ERROR;
 
-    if (virNetMessageEncodeHeader(&msg) < 0)
+    if (virNetMessageEncodeHeader(msg) < 0)
         goto cleanup;
 
-    if (virNetMessageEncodePayload(&msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0)
+    if (virNetMessageEncodePayload(msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0)
         goto cleanup;
 
-    if (ARRAY_CARDINALITY(expect) != msg.bufferLength) {
+    if (ARRAY_CARDINALITY(expect) != msg->bufferLength) {
         VIR_DEBUG("Expect message length %zu got %zu",
-                  sizeof(expect), msg.bufferLength);
+                  sizeof(expect), msg->bufferLength);
         goto cleanup;
     }
 
-    if (msg.bufferOffset != 0) {
+    if (msg->bufferOffset != 0) {
         VIR_DEBUG("Expect message offset 0 got %zu",
-                  msg.bufferOffset);
+                  msg->bufferOffset);
         goto cleanup;
     }
 
-    if (memcmp(expect, msg.buffer, sizeof(expect)) != 0) {
-        virtTestDifferenceBin(stderr, expect, msg.buffer, sizeof(expect));
+    if (memcmp(expect, msg->buffer, sizeof(expect)) != 0) {
+        virtTestDifferenceBin(stderr, expect, msg->buffer, sizeof(expect));
         goto cleanup;
     }
 
@@ -267,166 +297,176 @@ cleanup:
     VIR_FREE(err.str1);
     VIR_FREE(err.str2);
     VIR_FREE(err.str3);
+    virNetMessageFree(msg);
     return ret;
 }
 
 static int testMessagePayloadDecode(const void *args ATTRIBUTE_UNUSED)
 {
     virNetMessageError err;
-    static virNetMessage msg = {
-        .bufferOffset = 0,
-        .bufferLength = 0x4,
-        .buffer = {
-            0x00, 0x00, 0x00, 0x74,  /* Length */
-            0x11, 0x22, 0x33, 0x44,  /* Program */
-            0x00, 0x00, 0x00, 0x01,  /* Version */
-            0x00, 0x00, 0x06, 0x66,  /* Procedure */
-            0x00, 0x00, 0x00, 0x02,  /* Type */
-            0x00, 0x00, 0x00, 0x99,  /* Serial */
-            0x00, 0x00, 0x00, 0x01,  /* Status */
-
-            0x00, 0x00, 0x00, 0x01,  /* Error code */
-            0x00, 0x00, 0x00, 0x07,  /* Error domain */
-            0x00, 0x00, 0x00, 0x01,  /* Error message pointer */
-            0x00, 0x00, 0x00, 0x0b,  /* Error message length */
-            'H', 'e', 'l', 'l',  /* Error message string */
-            'o', ' ', 'W', 'o',
-            'r', 'l', 'd', '\0',
-            0x00, 0x00, 0x00, 0x02,  /* Error level */
-            0x00, 0x00, 0x00, 0x00,  /* Error domain pointer */
-            0x00, 0x00, 0x00, 0x01,  /* Error str1 pointer */
-            0x00, 0x00, 0x00, 0x03,  /* Error str1 length */
-            'O', 'n', 'e', '\0',  /* Error str1 message */
-            0x00, 0x00, 0x00, 0x01,  /* Error str2 pointer */
-            0x00, 0x00, 0x00, 0x03,  /* Error str2 length */
-            'T', 'w', 'o', '\0',  /* Error str2 message */
-            0x00, 0x00, 0x00, 0x01,  /* Error str3 pointer */
-            0x00, 0x00, 0x00, 0x05,  /* Error str3 length */
-            'T', 'h', 'r', 'e',  /* Error str3 message */
-            'e', '\0', '\0', '\0',
-            0x00, 0x00, 0x00, 0x01,  /* Error int1 */
-            0x00, 0x00, 0x00, 0x02,  /* Error int2 */
-            0x00, 0x00, 0x00, 0x00,  /* Error network pointer */
-        },
-        .header = { 0, 0, 0, 0, 0, 0 },
+    virNetMessagePtr msg = virNetMessageNew(true);
+    static char input_buffer[] = {
+        0x00, 0x00, 0x00, 0x74,  /* Length */
+        0x11, 0x22, 0x33, 0x44,  /* Program */
+        0x00, 0x00, 0x00, 0x01,  /* Version */
+        0x00, 0x00, 0x06, 0x66,  /* Procedure */
+        0x00, 0x00, 0x00, 0x02,  /* Type */
+        0x00, 0x00, 0x00, 0x99,  /* Serial */
+        0x00, 0x00, 0x00, 0x01,  /* Status */
+
+        0x00, 0x00, 0x00, 0x01,  /* Error code */
+        0x00, 0x00, 0x00, 0x07,  /* Error domain */
+        0x00, 0x00, 0x00, 0x01,  /* Error message pointer */
+        0x00, 0x00, 0x00, 0x0b,  /* Error message length */
+        'H', 'e', 'l', 'l',  /* Error message string */
+        'o', ' ', 'W', 'o',
+        'r', 'l', 'd', '\0',
+        0x00, 0x00, 0x00, 0x02,  /* Error level */
+        0x00, 0x00, 0x00, 0x00,  /* Error domain pointer */
+        0x00, 0x00, 0x00, 0x01,  /* Error str1 pointer */
+        0x00, 0x00, 0x00, 0x03,  /* Error str1 length */
+        'O', 'n', 'e', '\0',  /* Error str1 message */
+        0x00, 0x00, 0x00, 0x01,  /* Error str2 pointer */
+        0x00, 0x00, 0x00, 0x03,  /* Error str2 length */
+        'T', 'w', 'o', '\0',  /* Error str2 message */
+        0x00, 0x00, 0x00, 0x01,  /* Error str3 pointer */
+        0x00, 0x00, 0x00, 0x05,  /* Error str3 length */
+        'T', 'h', 'r', 'e',  /* Error str3 message */
+        'e', '\0', '\0', '\0',
+        0x00, 0x00, 0x00, 0x01,  /* Error int1 */
+        0x00, 0x00, 0x00, 0x02,  /* Error int2 */
+        0x00, 0x00, 0x00, 0x00,  /* Error network pointer */
     };
+    int ret = -1;
+
+    msg->bufferLength = 4;
+    if (VIR_ALLOC_N(msg->buffer, msg->bufferLength) < 0) {
+        virReportOOMError();
+        goto cleanup;
+    }
+    memcpy(msg->buffer, input_buffer, msg->bufferLength);
     memset(&err, 0, sizeof(err));
 
-    if (virNetMessageDecodeLength(&msg) < 0) {
+    if (virNetMessageDecodeLength(msg) < 0) {
         VIR_DEBUG("Failed to decode message header");
-        return -1;
+        goto cleanup;
     }
 
-    if (msg.bufferOffset != 0x4) {
+    if (msg->bufferOffset != 0x4) {
         VIR_DEBUG("Expecting offset %zu got %zu",
-                  (size_t)4, msg.bufferOffset);
-        return -1;
+                  (size_t)4, msg->bufferOffset);
+        goto cleanup;
     }
 
-    if (msg.bufferLength != 0x74) {
+    if (msg->bufferLength != 0x74) {
         VIR_DEBUG("Expecting length %zu got %zu",
-                  (size_t)0x74, msg.bufferLength);
-        return -1;
+                  (size_t)0x74, msg->bufferLength);
+        goto cleanup;
     }
 
-    if (virNetMessageDecodeHeader(&msg) < 0) {
+    memcpy(msg->buffer, input_buffer, msg->bufferLength);
+
+    if (virNetMessageDecodeHeader(msg) < 0) {
         VIR_DEBUG("Failed to decode message header");
-        return -1;
+        goto cleanup;
     }
 
-    if (msg.bufferOffset != 28) {
+    if (msg->bufferOffset != 28) {
         VIR_DEBUG("Expect message offset %zu got %zu",
-                  msg.bufferOffset, (size_t)28);
-        return -1;
+                  msg->bufferOffset, (size_t)28);
+        goto cleanup;
     }
 
-    if (msg.bufferLength != 0x74) {
+    if (msg->bufferLength != 0x74) {
         VIR_DEBUG("Expecting length %zu got %zu",
-                  (size_t)0x1c, msg.bufferLength);
-        return -1;
+                  (size_t)0x1c, msg->bufferLength);
+        goto cleanup;
     }
 
-    if (virNetMessageDecodePayload(&msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0) {
+    if (virNetMessageDecodePayload(msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0) {
         VIR_DEBUG("Failed to decode message payload");
-        return -1;
+        goto cleanup;
     }
 
     if (err.code != VIR_ERR_INTERNAL_ERROR) {
         VIR_DEBUG("Expect code %d got %d",
                   VIR_ERR_INTERNAL_ERROR, err.code);
-        return -1;
+        goto cleanup;
     }
 
     if (err.domain != VIR_FROM_RPC) {
         VIR_DEBUG("Expect domain %d got %d",
                   VIR_ERR_RPC, err.domain);
-        return -1;
+        goto cleanup;
     }
 
     if (err.message == NULL ||
         STRNEQ(*err.message, "Hello World")) {
         VIR_DEBUG("Expect str1 'Hello World' got %s",
                   err.message ? *err.message : "(null)");
-        return -1;
+        goto cleanup;
     }
 
     if (err.dom != NULL) {
         VIR_DEBUG("Expect NULL dom");
-        return -1;
+        goto cleanup;
     }
 
     if (err.level != VIR_ERR_ERROR) {
         VIR_DEBUG("Expect leve %d got %d",
                   VIR_ERR_ERROR, err.level);
-        return -1;
+        goto cleanup;
     }
 
     if (err.str1 == NULL ||
         STRNEQ(*err.str1, "One")) {
         VIR_DEBUG("Expect str1 'One' got %s",
                   err.str1 ? *err.str1 : "(null)");
-        return -1;
+        goto cleanup;
     }
 
     if (err.str2 == NULL ||
         STRNEQ(*err.str2, "Two")) {
         VIR_DEBUG("Expect str3 'Two' got %s",
                   err.str2 ? *err.str2 : "(null)");
-        return -1;
+        goto cleanup;
     }
 
     if (err.str3 == NULL ||
         STRNEQ(*err.str3, "Three")) {
         VIR_DEBUG("Expect str3 'Three' got %s",
                   err.str3 ? *err.str3 : "(null)");
-        return -1;
+        goto cleanup;
     }
 
     if (err.int1 != 1) {
         VIR_DEBUG("Expect int1 1 got %d",
                   err.int1);
-        return -1;
+        goto cleanup;
     }
 
     if (err.int2 != 2) {
         VIR_DEBUG("Expect int2 2 got %d",
                   err.int2);
-        return -1;
+        goto cleanup;
     }
 
     if (err.net != NULL) {
         VIR_DEBUG("Expect NULL network");
-        return -1;
+        goto cleanup;
     }
 
+    ret = 0;
+cleanup:
     xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)&err);
-    return 0;
+    virNetMessageFree(msg);
+    return ret;
 }
 
 static int testMessagePayloadStreamEncode(const void *args ATTRIBUTE_UNUSED)
 {
     char stream[] = "The quick brown fox jumps over the lazy dog";
-    static virNetMessage msg;
+    virNetMessagePtr msg = virNetMessageNew(true);
     static const char expect[] = {
         0x00, 0x00, 0x00, 0x47,  /* Length */
         0x11, 0x22, 0x33, 0x44,  /* Program */
@@ -448,39 +488,42 @@ static int testMessagePayloadStreamEncode(const void *args ATTRIBUTE_UNUSED)
         'a', 'z', 'y', ' ',
         'd', 'o', 'g',
     };
-    memset(&msg, 0, sizeof(msg));
+    int ret = -1;
 
-    msg.header.prog = 0x11223344;
-    msg.header.vers = 0x01;
-    msg.header.proc = 0x666;
-    msg.header.type = VIR_NET_STREAM;
-    msg.header.serial = 0x99;
-    msg.header.status = VIR_NET_CONTINUE;
+    msg->header.prog = 0x11223344;
+    msg->header.vers = 0x01;
+    msg->header.proc = 0x666;
+    msg->header.type = VIR_NET_STREAM;
+    msg->header.serial = 0x99;
+    msg->header.status = VIR_NET_CONTINUE;
 
-    if (virNetMessageEncodeHeader(&msg) < 0)
-        return -1;
+    if (virNetMessageEncodeHeader(msg) < 0)
+        goto cleanup;
 
-    if (virNetMessageEncodePayloadRaw(&msg, stream, strlen(stream)) < 0)
-        return -1;
+    if (virNetMessageEncodePayloadRaw(msg, stream, strlen(stream)) < 0)
+        goto cleanup;
 
-    if (ARRAY_CARDINALITY(expect) != msg.bufferLength) {
+    if (ARRAY_CARDINALITY(expect) != msg->bufferLength) {
         VIR_DEBUG("Expect message length %zu got %zu",
-                  sizeof(expect), msg.bufferLength);
-        return -1;
+                  sizeof(expect), msg->bufferLength);
+        goto cleanup;
     }
 
-    if (msg.bufferOffset != 0) {
+    if (msg->bufferOffset != 0) {
         VIR_DEBUG("Expect message offset 0 got %zu",
-                  msg.bufferOffset);
-        return -1;
+                  msg->bufferOffset);
+        goto cleanup;
     }
 
-    if (memcmp(expect, msg.buffer, sizeof(expect)) != 0) {
-        virtTestDifferenceBin(stderr, expect, msg.buffer, sizeof(expect));
-        return -1;
+    if (memcmp(expect, msg->buffer, sizeof(expect)) != 0) {
+        virtTestDifferenceBin(stderr, expect, msg->buffer, sizeof(expect));
+        goto cleanup;
     }
 
-    return 0;
+    ret = 0;
+cleanup:
+    virNetMessageFree(msg);
+    return ret;
 }