u16 flags;
/* initialised and used only by !msg send variants */
u16 buf_group;
+ unsigned mshot_len;
void __user *msg_control;
/* used only for send zerocopy */
struct io_kiocb *notif;
enum sr_retry_flags {
IORING_RECV_RETRY = (1U << 15),
IORING_RECV_PARTIAL_MAP = (1U << 14),
+ IORING_RECV_MSHOT_CAP = (1U << 13),
IORING_RECV_RETRY_CLEAR = IORING_RECV_RETRY | IORING_RECV_PARTIAL_MAP,
- IORING_RECV_NO_RETRY = IORING_RECV_RETRY | IORING_RECV_PARTIAL_MAP,
+ IORING_RECV_NO_RETRY = IORING_RECV_RETRY | IORING_RECV_PARTIAL_MAP |
+ IORING_RECV_MSHOT_CAP,
};
/*
req->flags &= ~REQ_F_BL_EMPTY;
sr->done_io = 0;
sr->flags &= ~IORING_RECV_RETRY_CLEAR;
- sr->len = 0; /* get from the provided buffer */
+ sr->len = sr->mshot_len;
}
static int io_net_import_vec(struct io_kiocb *req, struct io_async_msghdr *iomsg,
sr->buf_group = req->buf_index;
req->buf_list = NULL;
}
+ sr->mshot_len = 0;
if (sr->flags & IORING_RECV_MULTISHOT) {
if (!(req->flags & REQ_F_BUFFER_SELECT))
return -EINVAL;
if (sr->msg_flags & MSG_WAITALL)
return -EINVAL;
- if (req->opcode == IORING_OP_RECV && sr->len)
- return -EINVAL;
+ if (req->opcode == IORING_OP_RECV)
+ sr->mshot_len = sr->len;
req->flags |= REQ_F_APOLL_MULTISHOT;
}
if (sr->flags & IORING_RECVSEND_BUNDLE) {
issue_flags);
if (sr->flags & IORING_RECV_RETRY)
cflags = req->cqe.flags | (cflags & CQE_F_MASK);
+ if (sr->mshot_len && *ret >= sr->mshot_len)
+ sr->flags |= IORING_RECV_MSHOT_CAP;
/* bundle with no more immediate buffers, we're done */
if (req->flags & REQ_F_BL_EMPTY)
goto finish;
io_mshot_prep_retry(req, kmsg);
/* Known not-empty or unknown state, retry */
if (cflags & IORING_CQE_F_SOCK_NONEMPTY || kmsg->msg.msg_inq < 0) {
- if (sr->nr_multishot_loops++ < MULTISHOT_MAX_RETRY)
+ if (sr->nr_multishot_loops++ < MULTISHOT_MAX_RETRY &&
+ !(sr->flags & IORING_RECV_MSHOT_CAP)) {
return false;
+ }
/* mshot retries exceeded, force a requeue */
sr->nr_multishot_loops = 0;
+ sr->flags &= ~IORING_RECV_MSHOT_CAP;
if (issue_flags & IO_URING_F_MULTISHOT)
*ret = IOU_REQUEUE;
}
arg.mode |= KBUF_MODE_FREE;
}
- if (kmsg->msg.msg_inq > 1)
+ if (*len)
+ arg.max_len = *len;
+ else if (kmsg->msg.msg_inq > 1)
arg.max_len = min_not_zero(*len, kmsg->msg.msg_inq);
ret = io_buffers_peek(req, &arg);