]> git.ipfire.org Git - thirdparty/linux.git/commitdiff
PCI: hv: Use vmbus_requestor to generate transaction IDs for VMbus hardening
authorAndrea Parri (Microsoft) <parri.andrea@gmail.com>
Tue, 19 Apr 2022 12:23:21 +0000 (14:23 +0200)
committerWei Liu <wei.liu@kernel.org>
Mon, 25 Apr 2022 15:51:12 +0000 (15:51 +0000)
Currently, pointers to guest memory are passed to Hyper-V as transaction
IDs in hv_pci.  In the face of errors or malicious behavior in Hyper-V,
hv_pci should not expose or trust the transaction IDs returned by
Hyper-V to be valid guest memory addresses.  Instead, use small integers
generated by vmbus_requestor as request (transaction) IDs.

Suggested-by: Michael Kelley <mikelley@microsoft.com>
Signed-off-by: Andrea Parri (Microsoft) <parri.andrea@gmail.com>
Reviewed-by: Michael Kelley <mikelley@microsoft.com>
Link: https://lore.kernel.org/r/20220419122325.10078-3-parri.andrea@gmail.com
Signed-off-by: Wei Liu <wei.liu@kernel.org>
drivers/pci/controller/pci-hyperv.c

index 1cbe24b92a385c450daf0a6f0e2165ea91a2dc23..c18e9b608bd6b5e3cd7a6b3e0f9a24a7267daf83 100644 (file)
@@ -91,6 +91,13 @@ static enum pci_protocol_version_t pci_protocol_versions[] = {
 /* space for 32bit serial number as string */
 #define SLOT_NAME_SIZE 11
 
+/*
+ * Size of requestor for VMbus; the value is based on the observation
+ * that having more than one request outstanding is 'rare', and so 64
+ * should be generous in ensuring that we don't ever run out.
+ */
+#define HV_PCI_RQSTOR_SIZE 64
+
 /*
  * Message Types
  */
@@ -1529,7 +1536,7 @@ static void hv_int_desc_free(struct hv_pci_dev *hpdev,
        int_pkt->wslot.slot = hpdev->desc.win_slot.slot;
        int_pkt->int_desc = *int_desc;
        vmbus_sendpacket(hpdev->hbus->hdev->channel, int_pkt, sizeof(*int_pkt),
-                        (unsigned long)&ctxt.pkt, VM_PKT_DATA_INBAND, 0);
+                        0, VM_PKT_DATA_INBAND, 0);
        kfree(int_desc);
 }
 
@@ -2661,7 +2668,7 @@ static void hv_eject_device_work(struct work_struct *work)
        ejct_pkt->message_type.type = PCI_EJECTION_COMPLETE;
        ejct_pkt->wslot.slot = hpdev->desc.win_slot.slot;
        vmbus_sendpacket(hbus->hdev->channel, ejct_pkt,
-                        sizeof(*ejct_pkt), (unsigned long)&ctxt.pkt,
+                        sizeof(*ejct_pkt), 0,
                         VM_PKT_DATA_INBAND, 0);
 
        /* For the get_pcichild() in hv_pci_eject_device() */
@@ -2708,8 +2715,9 @@ static void hv_pci_onchannelcallback(void *context)
        const int packet_size = 0x100;
        int ret;
        struct hv_pcibus_device *hbus = context;
+       struct vmbus_channel *chan = hbus->hdev->channel;
        u32 bytes_recvd;
-       u64 req_id;
+       u64 req_id, req_addr;
        struct vmpacket_descriptor *desc;
        unsigned char *buffer;
        int bufferlen = packet_size;
@@ -2727,8 +2735,8 @@ static void hv_pci_onchannelcallback(void *context)
                return;
 
        while (1) {
-               ret = vmbus_recvpacket_raw(hbus->hdev->channel, buffer,
-                                          bufferlen, &bytes_recvd, &req_id);
+               ret = vmbus_recvpacket_raw(chan, buffer, bufferlen,
+                                          &bytes_recvd, &req_id);
 
                if (ret == -ENOBUFS) {
                        kfree(buffer);
@@ -2755,11 +2763,14 @@ static void hv_pci_onchannelcallback(void *context)
                switch (desc->type) {
                case VM_PKT_COMP:
 
-                       /*
-                        * The host is trusted, and thus it's safe to interpret
-                        * this transaction ID as a pointer.
-                        */
-                       comp_packet = (struct pci_packet *)req_id;
+                       req_addr = chan->request_addr_callback(chan, req_id);
+                       if (req_addr == VMBUS_RQST_ERROR) {
+                               dev_err(&hbus->hdev->device,
+                                       "Invalid transaction ID %llx\n",
+                                       req_id);
+                               break;
+                       }
+                       comp_packet = (struct pci_packet *)req_addr;
                        response = (struct pci_response *)buffer;
                        comp_packet->completion_func(comp_packet->compl_ctxt,
                                                     response,
@@ -3440,6 +3451,10 @@ static int hv_pci_probe(struct hv_device *hdev,
                goto free_dom;
        }
 
+       hdev->channel->next_request_id_callback = vmbus_next_request_id;
+       hdev->channel->request_addr_callback = vmbus_request_addr;
+       hdev->channel->rqstor_size = HV_PCI_RQSTOR_SIZE;
+
        ret = vmbus_open(hdev->channel, pci_ring_size, pci_ring_size, NULL, 0,
                         hv_pci_onchannelcallback, hbus);
        if (ret)
@@ -3770,6 +3785,10 @@ static int hv_pci_resume(struct hv_device *hdev)
 
        hbus->state = hv_pcibus_init;
 
+       hdev->channel->next_request_id_callback = vmbus_next_request_id;
+       hdev->channel->request_addr_callback = vmbus_request_addr;
+       hdev->channel->rqstor_size = HV_PCI_RQSTOR_SIZE;
+
        ret = vmbus_open(hdev->channel, pci_ring_size, pci_ring_size, NULL, 0,
                         hv_pci_onchannelcallback, hbus);
        if (ret)