static int
virNetClientCallDispatch(virNetClientPtr client)
{
- size_t i;
- if (virNetMessageDecodeHeader(&client->msg) < 0)
- return -1;
-
PROBE(RPC_CLIENT_MSG_RX,
"client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
client, client->msg.bufferLength,
switch (client->msg.header.type) {
case VIR_NET_REPLY: /* Normal RPC replies */
- return virNetClientCallDispatchReply(client);
-
case VIR_NET_REPLY_WITH_FDS: /* Normal RPC replies with FDs */
- if (virNetMessageDecodeNumFDs(&client->msg) < 0)
- return -1;
- for (i = 0 ; i < client->msg.nfds ; i++) {
- if ((client->msg.fds[i] = virNetSocketRecvFD(client->sock)) < 0)
- return -1;
- }
return virNetClientCallDispatchReply(client);
case VIR_NET_MESSAGE: /* Async notifications */
virNetClientIOWriteMessage(virNetClientPtr client,
virNetClientCallPtr thecall)
{
- ssize_t ret;
+ ssize_t ret = 0;
- ret = virNetSocketWrite(client->sock,
- thecall->msg->buffer + thecall->msg->bufferOffset,
- thecall->msg->bufferLength - thecall->msg->bufferOffset);
- if (ret <= 0)
- return ret;
+ if (thecall->msg->bufferOffset < thecall->msg->bufferLength) {
+ ret = virNetSocketWrite(client->sock,
+ thecall->msg->buffer + thecall->msg->bufferOffset,
+ thecall->msg->bufferLength - thecall->msg->bufferOffset);
+ if (ret <= 0)
+ return ret;
- thecall->msg->bufferOffset += ret;
+ thecall->msg->bufferOffset += ret;
+ }
if (thecall->msg->bufferOffset == thecall->msg->bufferLength) {
size_t i;
- for (i = 0 ; i < thecall->msg->nfds ; i++) {
- if (virNetSocketSendFD(client->sock, thecall->msg->fds[i]) < 0)
+ for (i = thecall->msg->donefds ; i < thecall->msg->nfds ; i++) {
+ int rv;
+ if ((rv = virNetSocketSendFD(client->sock, thecall->msg->fds[i])) < 0)
return -1;
+ if (rv == 0) /* Blocking */
+ return 0;
+ thecall->msg->donefds++;
}
+ thecall->msg->donefds = 0;
thecall->msg->bufferOffset = thecall->msg->bufferLength = 0;
if (thecall->expectReply)
thecall->mode = VIR_NET_CLIENT_MODE_WAIT_RX;
* EAGAIN
*/
for (;;) {
- ssize_t ret = virNetClientIOReadMessage(client);
+ ssize_t ret;
- if (ret < 0)
- return -1;
- if (ret == 0)
- return 0; /* Blocking on read */
+ if (client->msg.nfds == 0) {
+ ret = virNetClientIOReadMessage(client);
+
+ if (ret < 0)
+ return -1;
+ if (ret == 0)
+ return 0; /* Blocking on read */
+ }
/* Check for completion of our goal */
if (client->msg.bufferOffset == client->msg.bufferLength) {
* next iteration.
*/
} else {
+ if (virNetMessageDecodeHeader(&client->msg) < 0)
+ return -1;
+
+ if (client->msg.header.type == VIR_NET_REPLY_WITH_FDS) {
+ size_t i;
+ if (virNetMessageDecodeNumFDs(&client->msg) < 0)
+ return -1;
+
+ for (i = client->msg.donefds ; i < client->msg.nfds ; i++) {
+ int rv;
+ if ((rv = virNetSocketRecvFD(client->sock, &(client->msg.fds[i]))) < 0)
+ return -1;
+ if (rv == 0) /* Blocking */
+ break;
+ client->msg.donefds++;
+ }
+
+ if (client->msg.donefds < client->msg.nfds) {
+ /* Because DecodeHeader/NumFDs reset bufferOffset, we
+ * put it back to what it was, so everything works
+ * again next time we run this method
+ */
+ client->msg.bufferOffset = client->msg.bufferLength;
+ return 0; /* Blocking on more fds */
+ }
+ }
+
ret = virNetClientCallDispatch(client);
client->msg.bufferOffset = client->msg.bufferLength = 0;
/*
goto cleanup;
}
+ msg->donefds = 0;
if (msg->bufferLength)
call->mode = VIR_NET_CLIENT_MODE_WAIT_TX;
else
size_t nfds;
int *fds;
+ size_t donefds;
virNetMessagePtr next;
};
static void virNetServerClientDispatchRead(virNetServerClientPtr client)
{
readmore:
- if (virNetServerClientRead(client) < 0) {
- client->wantClose = true;
- return; /* Error */
+ if (client->rx->nfds == 0) {
+ if (virNetServerClientRead(client) < 0) {
+ client->wantClose = true;
+ return; /* Error */
+ }
}
if (client->rx->bufferOffset < client->rx->bufferLength)
goto readmore;
} else {
/* Grab the completed message */
- virNetMessagePtr msg = virNetMessageQueueServe(&client->rx);
+ virNetMessagePtr msg = client->rx;
virNetServerClientFilterPtr filter;
size_t i;
return;
}
+ /* Now figure out if we need to read more data to get some
+ * file descriptors */
if (msg->header.type == VIR_NET_CALL_WITH_FDS &&
virNetMessageDecodeNumFDs(msg) < 0) {
virNetMessageFree(msg);
client->wantClose = true;
- return;
+ return; /* Error */
}
- for (i = 0 ; i < msg->nfds ; i++) {
- if ((msg->fds[i] = virNetSocketRecvFD(client->sock)) < 0) {
+
+ /* Try getting the file descriptors (may fail if blocking) */
+ for (i = msg->donefds ; i < msg->nfds ; i++) {
+ int rv;
+ if ((rv = virNetSocketRecvFD(client->sock, &(msg->fds[i]))) < 0) {
virNetMessageFree(msg);
client->wantClose = true;
return;
}
+ if (rv == 0) /* Blocking */
+ break;
+ msg->donefds++;
+ }
+
+ /* Need to poll() until FDs arrive */
+ if (msg->donefds < msg->nfds) {
+ /* Because DecodeHeader/NumFDs reset bufferOffset, we
+ * put it back to what it was, so everything works
+ * again next time we run this method
+ */
+ client->rx->bufferOffset = client->rx->bufferLength;
+ return;
}
+ /* Definitely finished reading, so remove from queue */
+ virNetMessageQueueServe(&client->rx);
PROBE(RPC_SERVER_CLIENT_MSG_RX,
"client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
client, msg->bufferLength,
virNetServerClientDispatchWrite(virNetServerClientPtr client)
{
while (client->tx) {
- ssize_t ret;
-
- ret = virNetServerClientWrite(client);
- if (ret < 0) {
- client->wantClose = true;
- return;
+ if (client->tx->bufferOffset < client->tx->bufferLength) {
+ ssize_t ret;
+ ret = virNetServerClientWrite(client);
+ if (ret < 0) {
+ client->wantClose = true;
+ return;
+ }
+ if (ret == 0)
+ return; /* Would block on write EAGAIN */
}
- if (ret == 0)
- return; /* Would block on write EAGAIN */
if (client->tx->bufferOffset == client->tx->bufferLength) {
virNetMessagePtr msg;
size_t i;
- for (i = 0 ; i < client->tx->nfds ; i++) {
- if (virNetSocketSendFD(client->sock, client->tx->fds[i]) < 0) {
+ for (i = client->tx->donefds ; i < client->tx->nfds ; i++) {
+ int rv;
+ if ((rv = virNetSocketSendFD(client->sock, client->tx->fds[i])) < 0) {
client->wantClose = true;
return;
}
+ if (rv == 0) /* Blocking */
+ return;
+ client->tx->donefds++;
}
#if HAVE_SASL
msg->bufferLength, msg->bufferOffset);
virNetServerClientLock(client);
+ msg->donefds = 0;
if (client->sock && !client->wantClose) {
PROBE(RPC_SERVER_CLIENT_MSG_TX_QUEUE,
"client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
}
+/*
+ * Returns 1 if an FD was sent, 0 if it would block, -1 on error
+ */
int virNetSocketSendFD(virNetSocketPtr sock, int fd)
{
int ret = -1;
PROBE(RPC_SOCKET_SEND_FD,
"sock=%p fd=%d", sock, fd);
if (sendfd(sock->fd, fd) < 0) {
- virReportSystemError(errno,
- _("Failed to send file descriptor %d"),
- fd);
+ if (errno == EAGAIN)
+ ret = 0;
+ else
+ virReportSystemError(errno,
+ _("Failed to send file descriptor %d"),
+ fd);
goto cleanup;
}
- ret = 0;
+ ret = 1;
cleanup:
virMutexUnlock(&sock->lock);
}
-int virNetSocketRecvFD(virNetSocketPtr sock)
+/*
+ * Returns 1 if an FD was read, 0 if it would block, -1 on error
+ */
+int virNetSocketRecvFD(virNetSocketPtr sock, int *fd)
{
int ret = -1;
+
+ *fd = -1;
+
if (!virNetSocketHasPassFD(sock)) {
virNetError(VIR_ERR_INTERNAL_ERROR,
_("Receiving file descriptors is not supported on this socket"));
}
virMutexLock(&sock->lock);
- if ((ret = recvfd(sock->fd, O_CLOEXEC)) < 0) {
- virReportSystemError(errno, "%s",
- _("Failed to recv file descriptor"));
+ if ((*fd = recvfd(sock->fd, O_CLOEXEC)) < 0) {
+ if (errno == EAGAIN)
+ ret = 0;
+ else
+ virReportSystemError(errno, "%s",
+ _("Failed to recv file descriptor"));
goto cleanup;
}
PROBE(RPC_SOCKET_RECV_FD,
- "sock=%p fd=%d", sock, ret);
+ "sock=%p fd=%d", sock, *fd);
+ ret = 1;
cleanup:
virMutexUnlock(&sock->lock);
ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len);
int virNetSocketSendFD(virNetSocketPtr sock, int fd);
-int virNetSocketRecvFD(virNetSocketPtr sock);
+int virNetSocketRecvFD(virNetSocketPtr sock, int *fd);
void virNetSocketSetTLSSession(virNetSocketPtr sock,
virNetTLSSessionPtr sess);