static int smb_direct_post_send_data(struct smbdirect_socket *sc,
struct smbdirect_send_batch *send_ctx,
struct iov_iter *iter,
- size_t *remaining_data_length);
+ u32 remaining_data_length);
static void smb_direct_send_immediate_work(struct work_struct *work)
{
if (sc->status != SMBDIRECT_SOCKET_CONNECTED)
return;
- smb_direct_post_send_data(sc, NULL, NULL, NULL);
+ smb_direct_post_send_data(sc, NULL, NULL, 0);
}
static struct smb_direct_transport *alloc_transport(struct rdma_cm_id *cm_id)
static int smb_direct_post_send_data(struct smbdirect_socket *sc,
struct smbdirect_send_batch *send_ctx,
struct iov_iter *iter,
- size_t *_remaining_data_length)
+ u32 remaining_data_length)
{
const struct smbdirect_socket_parameters *sp = &sc->parameters;
int ret;
struct smbdirect_send_io *msg;
struct smbdirect_data_transfer *packet;
size_t header_length;
- u32 remaining_data_length = 0;
u32 data_length = 0;
struct smbdirect_send_batch _send_ctx;
u16 new_credits;
if (iter) {
header_length = sizeof(struct smbdirect_data_transfer);
+ if (WARN_ON_ONCE(remaining_data_length == 0 ||
+ iov_iter_count(iter) > remaining_data_length))
+ return -EINVAL;
} else {
/* If this is a packet without payload, don't send padding */
header_length = offsetof(struct smbdirect_data_transfer, padding);
+ if (WARN_ON_ONCE(remaining_data_length))
+ return -EINVAL;
}
if (!send_ctx) {
new_credits = smbdirect_connection_grant_recv_credits(sc);
}
- if (iter)
- data_length = iov_iter_count(iter);
-
- if (_remaining_data_length) {
- *_remaining_data_length -= data_length;
- remaining_data_length = *_remaining_data_length;
- }
-
msg = smbdirect_connection_alloc_send_io(sc);
if (IS_ERR(msg)) {
ret = PTR_ERR(msg);
.local_dma_lkey = sc->ib.pd->local_dma_lkey,
.direction = DMA_TO_DEVICE,
};
+ size_t payload_len = umin(iov_iter_count(iter),
+ sp->max_send_size - sizeof(*packet));
- ret = smbdirect_map_sges_from_iter(iter, data_length, &extract);
+ ret = smbdirect_map_sges_from_iter(iter, payload_len, &extract);
if (ret < 0)
goto err;
- if (WARN_ON_ONCE(ret != data_length)) {
- ret = -EIO;
- goto err;
- }
+ data_length = ret;
+ remaining_data_length -= data_length;
msg->num_sge = extract.num_sge;
}
struct smb_direct_transport *st = SMBD_TRANS(t);
struct smbdirect_socket *sc = &st->socket;
struct smbdirect_socket_parameters *sp = &sc->parameters;
- size_t remaining_data_length;
- size_t iov_idx;
- size_t iov_ofs;
- size_t max_iov_size = sp->max_send_size -
- sizeof(struct smbdirect_data_transfer);
int ret;
struct smbdirect_send_batch send_ctx;
+ struct iov_iter iter;
int error = 0;
if (sc->status != SMBDIRECT_SOCKET_CONNECTED)
//FIXME: skip RFC1002 header..
if (WARN_ON_ONCE(niovs <= 1 || iov[0].iov_len != 4))
return -EINVAL;
- buflen -= 4;
- iov_idx = 1;
- iov_ofs = 0;
-
- remaining_data_length = buflen;
- ksmbd_debug(RDMA, "Sending smb (RDMA): smb_len=%u\n", buflen);
-
- smb_direct_send_ctx_init(&send_ctx, need_invalidate, remote_key);
- while (remaining_data_length) {
- struct kvec vecs[SMBDIRECT_SEND_IO_MAX_SGE - 1]; /* minus smbdirect hdr */
- size_t possible_bytes = max_iov_size;
- size_t possible_vecs;
- size_t bytes = 0;
- size_t nvecs = 0;
- struct iov_iter iter;
-
- /*
- * For the last message remaining_data_length should be
- * have been 0 already!
- */
- if (WARN_ON_ONCE(iov_idx >= niovs)) {
- error = -EINVAL;
- goto done;
- }
+ iov_iter_kvec(&iter, ITER_SOURCE, iov, niovs, buflen);
+ iov_iter_advance(&iter, 4);
- /*
- * We have 2 factors which limit the arguments we pass
- * to smb_direct_post_send_data():
- *
- * 1. The number of supported sges for the send,
- * while one is reserved for the smbdirect header.
- * And we currently need one SGE per page.
- * 2. The number of negotiated payload bytes per send.
- */
- possible_vecs = min_t(size_t, ARRAY_SIZE(vecs), niovs - iov_idx);
-
- while (iov_idx < niovs && possible_vecs && possible_bytes) {
- struct kvec *v = &vecs[nvecs];
- int page_count;
-
- v->iov_base = ((u8 *)iov[iov_idx].iov_base) + iov_ofs;
- v->iov_len = min_t(size_t,
- iov[iov_idx].iov_len - iov_ofs,
- possible_bytes);
- page_count = smbdirect_get_buf_page_count(v->iov_base, v->iov_len);
- if (page_count > possible_vecs) {
- /*
- * If the number of pages in the buffer
- * is to much (because we currently require
- * one SGE per page), we need to limit the
- * length.
- *
- * We know possible_vecs is at least 1,
- * so we always keep the first page.
- *
- * We need to calculate the number extra
- * pages (epages) we can also keep.
- *
- * We calculate the number of bytes in the
- * first page (fplen), this should never be
- * larger than v->iov_len because page_count is
- * at least 2, but adding a limitation feels
- * better.
- *
- * Then we calculate the number of bytes (elen)
- * we can keep for the extra pages.
- */
- size_t epages = possible_vecs - 1;
- size_t fpofs = offset_in_page(v->iov_base);
- size_t fplen = min_t(size_t, PAGE_SIZE - fpofs, v->iov_len);
- size_t elen = min_t(size_t, v->iov_len - fplen, epages*PAGE_SIZE);
-
- v->iov_len = fplen + elen;
- page_count = smbdirect_get_buf_page_count(v->iov_base, v->iov_len);
- if (WARN_ON_ONCE(page_count > possible_vecs)) {
- /*
- * Something went wrong in the above
- * logic...
- */
- error = -EINVAL;
- goto done;
- }
- }
- possible_vecs -= page_count;
- nvecs += 1;
- possible_bytes -= v->iov_len;
- bytes += v->iov_len;
-
- iov_ofs += v->iov_len;
- if (iov_ofs >= iov[iov_idx].iov_len) {
- iov_idx += 1;
- iov_ofs = 0;
- }
- }
+ /*
+ * The size must fit into the negotiated
+ * fragmented send size.
+ */
+ if (iov_iter_count(&iter) > sp->max_fragmented_send_size)
+ return -EMSGSIZE;
- iov_iter_kvec(&iter, ITER_SOURCE, vecs, nvecs, bytes);
+ ksmbd_debug(RDMA, "Sending smb (RDMA): smb_len=%zu\n",
+ iov_iter_count(&iter));
- ret = smb_direct_post_send_data(sc, &send_ctx,
- &iter, &remaining_data_length);
+ smb_direct_send_ctx_init(&send_ctx, need_invalidate, remote_key);
+ while (iov_iter_count(&iter)) {
+ ret = smb_direct_post_send_data(sc,
+ &send_ctx,
+ &iter,
+ iov_iter_count(&iter));
if (unlikely(ret)) {
error = ret;
- goto done;
+ break;
}
}
-done:
ret = smb_direct_flush_send_list(sc, &send_ctx, true);
if (unlikely(!ret && error))
ret = error;