--- /dev/null
+From 00c0b1b3723f51d243538ace6661a31c4e279dc1 Mon Sep 17 00:00:00 2001
+From: Sasha Levin <sashal@kernel.org>
+Date: Thu, 14 Nov 2019 10:57:40 +0100
+Subject: vsock/virtio: add transport parameter to the
+ virtio_transport_reset_no_sock()
+
+From: Stefano Garzarella <sgarzare@redhat.com>
+
+[ Upstream commit 4c7246dc45e2706770d5233f7ce1597a07e069ba ]
+
+We are going to add 'struct vsock_sock *' parameter to
+virtio_transport_get_ops().
+
+In some cases, like in the virtio_transport_reset_no_sock(),
+we don't have any socket assigned to the packet received,
+so we can't use the virtio_transport_get_ops().
+
+In order to allow virtio_transport_reset_no_sock() to use the
+'.send_pkt' callback from the 'vhost_transport' or 'virtio_transport',
+we add the 'struct virtio_transport *' to it and to its caller:
+virtio_transport_recv_pkt().
+
+We moved the 'vhost_transport' and 'virtio_transport' definition,
+to pass their address to the virtio_transport_recv_pkt().
+
+Reviewed-by: Stefan Hajnoczi <stefanha@redhat.com>
+Signed-off-by: Stefano Garzarella <sgarzare@redhat.com>
+Signed-off-by: David S. Miller <davem@davemloft.net>
+Signed-off-by: Sasha Levin <sashal@kernel.org>
+---
+ drivers/vhost/vsock.c | 94 ++++++++---------
+ include/linux/virtio_vsock.h | 3 +-
+ net/vmw_vsock/virtio_transport.c | 131 ++++++++++++++----------
+ net/vmw_vsock/virtio_transport_common.c | 12 +--
+ 4 files changed, 134 insertions(+), 106 deletions(-)
+
+diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
+index 2ac966400c428..554e131d17b3b 100644
+--- a/drivers/vhost/vsock.c
++++ b/drivers/vhost/vsock.c
+@@ -349,6 +349,52 @@ static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
+ return val < vq->num;
+ }
+
++static struct virtio_transport vhost_transport = {
++ .transport = {
++ .get_local_cid = vhost_transport_get_local_cid,
++
++ .init = virtio_transport_do_socket_init,
++ .destruct = virtio_transport_destruct,
++ .release = virtio_transport_release,
++ .connect = virtio_transport_connect,
++ .shutdown = virtio_transport_shutdown,
++ .cancel_pkt = vhost_transport_cancel_pkt,
++
++ .dgram_enqueue = virtio_transport_dgram_enqueue,
++ .dgram_dequeue = virtio_transport_dgram_dequeue,
++ .dgram_bind = virtio_transport_dgram_bind,
++ .dgram_allow = virtio_transport_dgram_allow,
++
++ .stream_enqueue = virtio_transport_stream_enqueue,
++ .stream_dequeue = virtio_transport_stream_dequeue,
++ .stream_has_data = virtio_transport_stream_has_data,
++ .stream_has_space = virtio_transport_stream_has_space,
++ .stream_rcvhiwat = virtio_transport_stream_rcvhiwat,
++ .stream_is_active = virtio_transport_stream_is_active,
++ .stream_allow = virtio_transport_stream_allow,
++
++ .notify_poll_in = virtio_transport_notify_poll_in,
++ .notify_poll_out = virtio_transport_notify_poll_out,
++ .notify_recv_init = virtio_transport_notify_recv_init,
++ .notify_recv_pre_block = virtio_transport_notify_recv_pre_block,
++ .notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue,
++ .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
++ .notify_send_init = virtio_transport_notify_send_init,
++ .notify_send_pre_block = virtio_transport_notify_send_pre_block,
++ .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue,
++ .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
++
++ .set_buffer_size = virtio_transport_set_buffer_size,
++ .set_min_buffer_size = virtio_transport_set_min_buffer_size,
++ .set_max_buffer_size = virtio_transport_set_max_buffer_size,
++ .get_buffer_size = virtio_transport_get_buffer_size,
++ .get_min_buffer_size = virtio_transport_get_min_buffer_size,
++ .get_max_buffer_size = virtio_transport_get_max_buffer_size,
++ },
++
++ .send_pkt = vhost_transport_send_pkt,
++};
++
+ static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
+ {
+ struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
+@@ -402,7 +448,7 @@ static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
+ if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid &&
+ le64_to_cpu(pkt->hdr.dst_cid) ==
+ vhost_transport_get_local_cid())
+- virtio_transport_recv_pkt(pkt);
++ virtio_transport_recv_pkt(&vhost_transport, pkt);
+ else
+ virtio_transport_free_pkt(pkt);
+
+@@ -745,52 +791,6 @@ static struct miscdevice vhost_vsock_misc = {
+ .fops = &vhost_vsock_fops,
+ };
+
+-static struct virtio_transport vhost_transport = {
+- .transport = {
+- .get_local_cid = vhost_transport_get_local_cid,
+-
+- .init = virtio_transport_do_socket_init,
+- .destruct = virtio_transport_destruct,
+- .release = virtio_transport_release,
+- .connect = virtio_transport_connect,
+- .shutdown = virtio_transport_shutdown,
+- .cancel_pkt = vhost_transport_cancel_pkt,
+-
+- .dgram_enqueue = virtio_transport_dgram_enqueue,
+- .dgram_dequeue = virtio_transport_dgram_dequeue,
+- .dgram_bind = virtio_transport_dgram_bind,
+- .dgram_allow = virtio_transport_dgram_allow,
+-
+- .stream_enqueue = virtio_transport_stream_enqueue,
+- .stream_dequeue = virtio_transport_stream_dequeue,
+- .stream_has_data = virtio_transport_stream_has_data,
+- .stream_has_space = virtio_transport_stream_has_space,
+- .stream_rcvhiwat = virtio_transport_stream_rcvhiwat,
+- .stream_is_active = virtio_transport_stream_is_active,
+- .stream_allow = virtio_transport_stream_allow,
+-
+- .notify_poll_in = virtio_transport_notify_poll_in,
+- .notify_poll_out = virtio_transport_notify_poll_out,
+- .notify_recv_init = virtio_transport_notify_recv_init,
+- .notify_recv_pre_block = virtio_transport_notify_recv_pre_block,
+- .notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue,
+- .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
+- .notify_send_init = virtio_transport_notify_send_init,
+- .notify_send_pre_block = virtio_transport_notify_send_pre_block,
+- .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue,
+- .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
+-
+- .set_buffer_size = virtio_transport_set_buffer_size,
+- .set_min_buffer_size = virtio_transport_set_min_buffer_size,
+- .set_max_buffer_size = virtio_transport_set_max_buffer_size,
+- .get_buffer_size = virtio_transport_get_buffer_size,
+- .get_min_buffer_size = virtio_transport_get_min_buffer_size,
+- .get_max_buffer_size = virtio_transport_get_max_buffer_size,
+- },
+-
+- .send_pkt = vhost_transport_send_pkt,
+-};
+-
+ static int __init vhost_vsock_init(void)
+ {
+ int ret;
+diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
+index 584f9a647ad4a..0860cf4ae0461 100644
+--- a/include/linux/virtio_vsock.h
++++ b/include/linux/virtio_vsock.h
+@@ -148,7 +148,8 @@ virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
+
+ void virtio_transport_destruct(struct vsock_sock *vsk);
+
+-void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt);
++void virtio_transport_recv_pkt(struct virtio_transport *t,
++ struct virtio_vsock_pkt *pkt);
+ void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt);
+ void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt);
+ u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 wanted);
+diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
+index 67aba63b5c96d..43f6c4240b2a8 100644
+--- a/net/vmw_vsock/virtio_transport.c
++++ b/net/vmw_vsock/virtio_transport.c
+@@ -271,58 +271,6 @@ static bool virtio_transport_more_replies(struct virtio_vsock *vsock)
+ return val < virtqueue_get_vring_size(vq);
+ }
+
+-static void virtio_transport_rx_work(struct work_struct *work)
+-{
+- struct virtio_vsock *vsock =
+- container_of(work, struct virtio_vsock, rx_work);
+- struct virtqueue *vq;
+-
+- vq = vsock->vqs[VSOCK_VQ_RX];
+-
+- mutex_lock(&vsock->rx_lock);
+-
+- if (!vsock->rx_run)
+- goto out;
+-
+- do {
+- virtqueue_disable_cb(vq);
+- for (;;) {
+- struct virtio_vsock_pkt *pkt;
+- unsigned int len;
+-
+- if (!virtio_transport_more_replies(vsock)) {
+- /* Stop rx until the device processes already
+- * pending replies. Leave rx virtqueue
+- * callbacks disabled.
+- */
+- goto out;
+- }
+-
+- pkt = virtqueue_get_buf(vq, &len);
+- if (!pkt) {
+- break;
+- }
+-
+- vsock->rx_buf_nr--;
+-
+- /* Drop short/long packets */
+- if (unlikely(len < sizeof(pkt->hdr) ||
+- len > sizeof(pkt->hdr) + pkt->len)) {
+- virtio_transport_free_pkt(pkt);
+- continue;
+- }
+-
+- pkt->len = len - sizeof(pkt->hdr);
+- virtio_transport_recv_pkt(pkt);
+- }
+- } while (!virtqueue_enable_cb(vq));
+-
+-out:
+- if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2)
+- virtio_vsock_rx_fill(vsock);
+- mutex_unlock(&vsock->rx_lock);
+-}
+-
+ /* event_lock must be held */
+ static int virtio_vsock_event_fill_one(struct virtio_vsock *vsock,
+ struct virtio_vsock_event *event)
+@@ -485,6 +433,85 @@ static struct virtio_transport virtio_transport = {
+ .send_pkt = virtio_transport_send_pkt,
+ };
+
++static void virtio_transport_loopback_work(struct work_struct *work)
++{
++ struct virtio_vsock *vsock =
++ container_of(work, struct virtio_vsock, loopback_work);
++ LIST_HEAD(pkts);
++
++ spin_lock_bh(&vsock->loopback_list_lock);
++ list_splice_init(&vsock->loopback_list, &pkts);
++ spin_unlock_bh(&vsock->loopback_list_lock);
++
++ mutex_lock(&vsock->rx_lock);
++
++ if (!vsock->rx_run)
++ goto out;
++
++ while (!list_empty(&pkts)) {
++ struct virtio_vsock_pkt *pkt;
++
++ pkt = list_first_entry(&pkts, struct virtio_vsock_pkt, list);
++ list_del_init(&pkt->list);
++
++ virtio_transport_recv_pkt(&virtio_transport, pkt);
++ }
++out:
++ mutex_unlock(&vsock->rx_lock);
++}
++
++static void virtio_transport_rx_work(struct work_struct *work)
++{
++ struct virtio_vsock *vsock =
++ container_of(work, struct virtio_vsock, rx_work);
++ struct virtqueue *vq;
++
++ vq = vsock->vqs[VSOCK_VQ_RX];
++
++ mutex_lock(&vsock->rx_lock);
++
++ if (!vsock->rx_run)
++ goto out;
++
++ do {
++ virtqueue_disable_cb(vq);
++ for (;;) {
++ struct virtio_vsock_pkt *pkt;
++ unsigned int len;
++
++ if (!virtio_transport_more_replies(vsock)) {
++ /* Stop rx until the device processes already
++ * pending replies. Leave rx virtqueue
++ * callbacks disabled.
++ */
++ goto out;
++ }
++
++ pkt = virtqueue_get_buf(vq, &len);
++ if (!pkt) {
++ break;
++ }
++
++ vsock->rx_buf_nr--;
++
++ /* Drop short/long packets */
++ if (unlikely(len < sizeof(pkt->hdr) ||
++ len > sizeof(pkt->hdr) + pkt->len)) {
++ virtio_transport_free_pkt(pkt);
++ continue;
++ }
++
++ pkt->len = len - sizeof(pkt->hdr);
++ virtio_transport_recv_pkt(&virtio_transport, pkt);
++ }
++ } while (!virtqueue_enable_cb(vq));
++
++out:
++ if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2)
++ virtio_vsock_rx_fill(vsock);
++ mutex_unlock(&vsock->rx_lock);
++}
++
+ static int virtio_vsock_probe(struct virtio_device *vdev)
+ {
+ vq_callback_t *callbacks[] = {
+diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
+index aa9d1c7780c3d..d64285afe68f3 100644
+--- a/net/vmw_vsock/virtio_transport_common.c
++++ b/net/vmw_vsock/virtio_transport_common.c
+@@ -599,9 +599,9 @@ static int virtio_transport_reset(struct vsock_sock *vsk,
+ /* Normally packets are associated with a socket. There may be no socket if an
+ * attempt was made to connect to a socket that does not exist.
+ */
+-static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
++static int virtio_transport_reset_no_sock(const struct virtio_transport *t,
++ struct virtio_vsock_pkt *pkt)
+ {
+- const struct virtio_transport *t;
+ struct virtio_vsock_pkt *reply;
+ struct virtio_vsock_pkt_info info = {
+ .op = VIRTIO_VSOCK_OP_RST,
+@@ -621,7 +621,6 @@ static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
+ if (!reply)
+ return -ENOMEM;
+
+- t = virtio_transport_get_ops();
+ if (!t) {
+ virtio_transport_free_pkt(reply);
+ return -ENOTCONN;
+@@ -919,7 +918,8 @@ static bool virtio_transport_space_update(struct sock *sk,
+ /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
+ * lock.
+ */
+-void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
++void virtio_transport_recv_pkt(struct virtio_transport *t,
++ struct virtio_vsock_pkt *pkt)
+ {
+ struct sockaddr_vm src, dst;
+ struct vsock_sock *vsk;
+@@ -941,7 +941,7 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
+ le32_to_cpu(pkt->hdr.fwd_cnt));
+
+ if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
+- (void)virtio_transport_reset_no_sock(pkt);
++ (void)virtio_transport_reset_no_sock(t, pkt);
+ goto free_pkt;
+ }
+
+@@ -952,7 +952,7 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
+ if (!sk) {
+ sk = vsock_find_bound_socket(&dst);
+ if (!sk) {
+- (void)virtio_transport_reset_no_sock(pkt);
++ (void)virtio_transport_reset_no_sock(t, pkt);
+ goto free_pkt;
+ }
+ }
+--
+2.25.1
+
--- /dev/null
+From 214a8e4fc925592ecf1a45b96718834b897cb70f Mon Sep 17 00:00:00 2001
+From: Sasha Levin <sashal@kernel.org>
+Date: Fri, 5 Jul 2019 13:04:53 +0200
+Subject: vsock/virtio: stop workers during the .remove()
+
+From: Stefano Garzarella <sgarzare@redhat.com>
+
+[ Upstream commit 17dd1367389cfe7f150790c83247b68e0c19d106 ]
+
+Before to call vdev->config->reset(vdev) we need to be sure that
+no one is accessing the device, for this reason, we add new variables
+in the struct virtio_vsock to stop the workers during the .remove().
+
+This patch also add few comments before vdev->config->reset(vdev)
+and vdev->config->del_vqs(vdev).
+
+Suggested-by: Stefan Hajnoczi <stefanha@redhat.com>
+Suggested-by: Michael S. Tsirkin <mst@redhat.com>
+Signed-off-by: Stefano Garzarella <sgarzare@redhat.com>
+Signed-off-by: David S. Miller <davem@davemloft.net>
+Signed-off-by: Sasha Levin <sashal@kernel.org>
+---
+ net/vmw_vsock/virtio_transport.c | 46 +++++++++++++++++++++++++++++++-
+ 1 file changed, 45 insertions(+), 1 deletion(-)
+
+diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
+index 32ad7cfa5fa74..67aba63b5c96d 100644
+--- a/net/vmw_vsock/virtio_transport.c
++++ b/net/vmw_vsock/virtio_transport.c
+@@ -39,6 +39,7 @@ struct virtio_vsock {
+ * must be accessed with tx_lock held.
+ */
+ struct mutex tx_lock;
++ bool tx_run;
+
+ struct work_struct send_pkt_work;
+ spinlock_t send_pkt_list_lock;
+@@ -50,6 +51,7 @@ struct virtio_vsock {
+ * must be accessed with rx_lock held.
+ */
+ struct mutex rx_lock;
++ bool rx_run;
+ int rx_buf_nr;
+ int rx_buf_max_nr;
+
+@@ -57,6 +59,7 @@ struct virtio_vsock {
+ * vqs[VSOCK_VQ_EVENT] must be accessed with event_lock held.
+ */
+ struct mutex event_lock;
++ bool event_run;
+ struct virtio_vsock_event event_list[8];
+
+ u32 guest_cid;
+@@ -91,6 +94,9 @@ virtio_transport_send_pkt_work(struct work_struct *work)
+
+ mutex_lock(&vsock->tx_lock);
+
++ if (!vsock->tx_run)
++ goto out;
++
+ vq = vsock->vqs[VSOCK_VQ_TX];
+
+ for (;;) {
+@@ -147,6 +153,7 @@ virtio_transport_send_pkt_work(struct work_struct *work)
+ if (added)
+ virtqueue_kick(vq);
+
++out:
+ mutex_unlock(&vsock->tx_lock);
+
+ if (restart_rx)
+@@ -230,6 +237,10 @@ static void virtio_transport_tx_work(struct work_struct *work)
+
+ vq = vsock->vqs[VSOCK_VQ_TX];
+ mutex_lock(&vsock->tx_lock);
++
++ if (!vsock->tx_run)
++ goto out;
++
+ do {
+ struct virtio_vsock_pkt *pkt;
+ unsigned int len;
+@@ -240,6 +251,8 @@ static void virtio_transport_tx_work(struct work_struct *work)
+ added = true;
+ }
+ } while (!virtqueue_enable_cb(vq));
++
++out:
+ mutex_unlock(&vsock->tx_lock);
+
+ if (added)
+@@ -268,6 +281,9 @@ static void virtio_transport_rx_work(struct work_struct *work)
+
+ mutex_lock(&vsock->rx_lock);
+
++ if (!vsock->rx_run)
++ goto out;
++
+ do {
+ virtqueue_disable_cb(vq);
+ for (;;) {
+@@ -376,6 +392,9 @@ static void virtio_transport_event_work(struct work_struct *work)
+
+ mutex_lock(&vsock->event_lock);
+
++ if (!vsock->event_run)
++ goto out;
++
+ do {
+ struct virtio_vsock_event *event;
+ unsigned int len;
+@@ -390,7 +409,7 @@ static void virtio_transport_event_work(struct work_struct *work)
+ } while (!virtqueue_enable_cb(vq));
+
+ virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]);
+-
++out:
+ mutex_unlock(&vsock->event_lock);
+ }
+
+@@ -521,12 +540,18 @@ static int virtio_vsock_probe(struct virtio_device *vdev)
+ INIT_WORK(&vsock->event_work, virtio_transport_event_work);
+ INIT_WORK(&vsock->send_pkt_work, virtio_transport_send_pkt_work);
+
++ mutex_lock(&vsock->tx_lock);
++ vsock->tx_run = true;
++ mutex_unlock(&vsock->tx_lock);
++
+ mutex_lock(&vsock->rx_lock);
+ virtio_vsock_rx_fill(vsock);
++ vsock->rx_run = true;
+ mutex_unlock(&vsock->rx_lock);
+
+ mutex_lock(&vsock->event_lock);
+ virtio_vsock_event_fill(vsock);
++ vsock->event_run = true;
+ mutex_unlock(&vsock->event_lock);
+
+ vdev->priv = vsock;
+@@ -560,6 +585,24 @@ static void virtio_vsock_remove(struct virtio_device *vdev)
+ /* Reset all connected sockets when the device disappear */
+ vsock_for_each_connected_socket(virtio_vsock_reset_sock);
+
++ /* Stop all work handlers to make sure no one is accessing the device,
++ * so we can safely call vdev->config->reset().
++ */
++ mutex_lock(&vsock->rx_lock);
++ vsock->rx_run = false;
++ mutex_unlock(&vsock->rx_lock);
++
++ mutex_lock(&vsock->tx_lock);
++ vsock->tx_run = false;
++ mutex_unlock(&vsock->tx_lock);
++
++ mutex_lock(&vsock->event_lock);
++ vsock->event_run = false;
++ mutex_unlock(&vsock->event_lock);
++
++ /* Flush all device writes and interrupts, device will not use any
++ * more buffers.
++ */
+ vdev->config->reset(vdev);
+
+ mutex_lock(&vsock->rx_lock);
+@@ -581,6 +624,7 @@ static void virtio_vsock_remove(struct virtio_device *vdev)
+ }
+ spin_unlock_bh(&vsock->send_pkt_list_lock);
+
++ /* Delete virtqueues and flush outstanding callbacks if any */
+ vdev->config->del_vqs(vdev);
+
+ mutex_unlock(&the_virtio_vsock_mutex);
+--
+2.25.1
+
--- /dev/null
+From 67f2ce2396bd0cb27084727b4366dbd59112abca Mon Sep 17 00:00:00 2001
+From: Sasha Levin <sashal@kernel.org>
+Date: Fri, 5 Jul 2019 13:04:52 +0200
+Subject: vsock/virtio: use RCU to avoid use-after-free on the_virtio_vsock
+
+From: Stefano Garzarella <sgarzare@redhat.com>
+
+[ Upstream commit 9c7a5582f5d720dc35cfcc42ccaded69f0642e4a ]
+
+Some callbacks used by the upper layers can run while we are in the
+.remove(). A potential use-after-free can happen, because we free
+the_virtio_vsock without knowing if the callbacks are over or not.
+
+To solve this issue we move the assignment of the_virtio_vsock at the
+end of .probe(), when we finished all the initialization, and at the
+beginning of .remove(), before to release resources.
+For the same reason, we do the same also for the vdev->priv.
+
+We use RCU to be sure that all callbacks that use the_virtio_vsock
+ended before freeing it. This is not required for callbacks that
+use vdev->priv, because after the vdev->config->del_vqs() we are sure
+that they are ended and will no longer be invoked.
+
+We also take the mutex during the .remove() to avoid that .probe() can
+run while we are resetting the device.
+
+Signed-off-by: Stefano Garzarella <sgarzare@redhat.com>
+Signed-off-by: David S. Miller <davem@davemloft.net>
+Signed-off-by: Sasha Levin <sashal@kernel.org>
+---
+ net/vmw_vsock/virtio_transport.c | 50 ++++++++++++++++++++------------
+ 1 file changed, 32 insertions(+), 18 deletions(-)
+
+diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
+index 0bd5a60f3bdeb..32ad7cfa5fa74 100644
+--- a/net/vmw_vsock/virtio_transport.c
++++ b/net/vmw_vsock/virtio_transport.c
+@@ -62,19 +62,22 @@ struct virtio_vsock {
+ u32 guest_cid;
+ };
+
+-static struct virtio_vsock *virtio_vsock_get(void)
+-{
+- return the_virtio_vsock;
+-}
+-
+ static u32 virtio_transport_get_local_cid(void)
+ {
+- struct virtio_vsock *vsock = virtio_vsock_get();
++ struct virtio_vsock *vsock;
++ u32 ret;
+
+- if (!vsock)
+- return VMADDR_CID_ANY;
++ rcu_read_lock();
++ vsock = rcu_dereference(the_virtio_vsock);
++ if (!vsock) {
++ ret = VMADDR_CID_ANY;
++ goto out_rcu;
++ }
+
+- return vsock->guest_cid;
++ ret = vsock->guest_cid;
++out_rcu:
++ rcu_read_unlock();
++ return ret;
+ }
+
+ static void
+@@ -156,10 +159,12 @@ virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
+ struct virtio_vsock *vsock;
+ int len = pkt->len;
+
+- vsock = virtio_vsock_get();
++ rcu_read_lock();
++ vsock = rcu_dereference(the_virtio_vsock);
+ if (!vsock) {
+ virtio_transport_free_pkt(pkt);
+- return -ENODEV;
++ len = -ENODEV;
++ goto out_rcu;
+ }
+
+ if (pkt->reply)
+@@ -170,6 +175,9 @@ virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
+ spin_unlock_bh(&vsock->send_pkt_list_lock);
+
+ queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);
++
++out_rcu:
++ rcu_read_unlock();
+ return len;
+ }
+
+@@ -478,7 +486,8 @@ static int virtio_vsock_probe(struct virtio_device *vdev)
+ return ret;
+
+ /* Only one virtio-vsock device per guest is supported */
+- if (the_virtio_vsock) {
++ if (rcu_dereference_protected(the_virtio_vsock,
++ lockdep_is_held(&the_virtio_vsock_mutex))) {
+ ret = -EBUSY;
+ goto out;
+ }
+@@ -502,8 +511,6 @@ static int virtio_vsock_probe(struct virtio_device *vdev)
+ vsock->rx_buf_max_nr = 0;
+ atomic_set(&vsock->queued_replies, 0);
+
+- vdev->priv = vsock;
+- the_virtio_vsock = vsock;
+ mutex_init(&vsock->tx_lock);
+ mutex_init(&vsock->rx_lock);
+ mutex_init(&vsock->event_lock);
+@@ -522,6 +529,9 @@ static int virtio_vsock_probe(struct virtio_device *vdev)
+ virtio_vsock_event_fill(vsock);
+ mutex_unlock(&vsock->event_lock);
+
++ vdev->priv = vsock;
++ rcu_assign_pointer(the_virtio_vsock, vsock);
++
+ mutex_unlock(&the_virtio_vsock_mutex);
+ return 0;
+
+@@ -536,6 +546,12 @@ static void virtio_vsock_remove(struct virtio_device *vdev)
+ struct virtio_vsock *vsock = vdev->priv;
+ struct virtio_vsock_pkt *pkt;
+
++ mutex_lock(&the_virtio_vsock_mutex);
++
++ vdev->priv = NULL;
++ rcu_assign_pointer(the_virtio_vsock, NULL);
++ synchronize_rcu();
++
+ flush_work(&vsock->rx_work);
+ flush_work(&vsock->tx_work);
+ flush_work(&vsock->event_work);
+@@ -565,12 +581,10 @@ static void virtio_vsock_remove(struct virtio_device *vdev)
+ }
+ spin_unlock_bh(&vsock->send_pkt_list_lock);
+
+- mutex_lock(&the_virtio_vsock_mutex);
+- the_virtio_vsock = NULL;
+- mutex_unlock(&the_virtio_vsock_mutex);
+-
+ vdev->config->del_vqs(vdev);
+
++ mutex_unlock(&the_virtio_vsock_mutex);
++
+ kfree(vsock);
+ }
+
+--
+2.25.1
+