git.ipfire.org Git - thirdparty/kernel/linux.git/commitdiff
KVM: x86/mmu: Refactor low level rmap helpers to prep for walking w/o mmu_lock
author Sean Christopherson <seanjc@google.com>
Tue, 4 Feb 2025 00:40:35 +0000 (00:40 +0000)
committer Sean Christopherson <seanjc@google.com>
Fri, 14 Feb 2025 15:17:33 +0000 (07:17 -0800)
Refactor the pte_list and rmap code to always read and write rmap_head->val
exactly once, e.g. by collecting changes in a local variable and then
propagating those changes back to rmap_head->val as appropriate.  This will
allow implementing a per-rmap rwlock (of sorts) by adding a LOCKED bit into
the rmap value alongside the MANY bit.

Signed-off-by: James Houghton <jthoughton@google.com>
Acked-by: Yu Zhao <yuzhao@google.com>
Reviewed-by: James Houghton <jthoughton@google.com>
Link: https://lore.kernel.org/r/20250204004038.1680123-9-jthoughton@google.com
Signed-off-by: Sean Christopherson <seanjc@google.com>
arch/x86/kvm/mmu/mmu.c

index 6af7eaa9feff0d2d8f09806a2b4d0ff65c5159ea..05b9fecdb25553dace139e9f7cb671ab073bfc45 100644 (file)
@@ -864,21 +864,24 @@ static struct kvm_memory_slot *gfn_to_memslot_dirty_bitmap(struct kvm_vcpu *vcpu
 static int pte_list_add(struct kvm_mmu_memory_cache *cache, u64 *spte,
                        struct kvm_rmap_head *rmap_head)
 {
+       unsigned long old_val, new_val;
        struct pte_list_desc *desc;
        int count = 0;
 
-       if (!rmap_head->val) {
-               rmap_head->val = (unsigned long)spte;
-       } else if (!(rmap_head->val & KVM_RMAP_MANY)) {
+       old_val = rmap_head->val;
+
+       if (!old_val) {
+               new_val = (unsigned long)spte;
+       } else if (!(old_val & KVM_RMAP_MANY)) {
                desc = kvm_mmu_memory_cache_alloc(cache);
-               desc->sptes[0] = (u64 *)rmap_head->val;
+               desc->sptes[0] = (u64 *)old_val;
                desc->sptes[1] = spte;
                desc->spte_count = 2;
                desc->tail_count = 0;
-               rmap_head->val = (unsigned long)desc | KVM_RMAP_MANY;
+               new_val = (unsigned long)desc | KVM_RMAP_MANY;
                ++count;
        } else {
-               desc = (struct pte_list_desc *)(rmap_head->val & ~KVM_RMAP_MANY);
+               desc = (struct pte_list_desc *)(old_val & ~KVM_RMAP_MANY);
                count = desc->tail_count + desc->spte_count;
 
                /*
@@ -887,21 +890,25 @@ static int pte_list_add(struct kvm_mmu_memory_cache *cache, u64 *spte,
                 */
                if (desc->spte_count == PTE_LIST_EXT) {
                        desc = kvm_mmu_memory_cache_alloc(cache);
-                       desc->more = (struct pte_list_desc *)(rmap_head->val & ~KVM_RMAP_MANY);
+                       desc->more = (struct pte_list_desc *)(old_val & ~KVM_RMAP_MANY);
                        desc->spte_count = 0;
                        desc->tail_count = count;
-                       rmap_head->val = (unsigned long)desc | KVM_RMAP_MANY;
+                       new_val = (unsigned long)desc | KVM_RMAP_MANY;
+               } else {
+                       new_val = old_val;
                }
                desc->sptes[desc->spte_count++] = spte;
        }
+
+       rmap_head->val = new_val;
+
        return count;
 }
 
-static void pte_list_desc_remove_entry(struct kvm *kvm,
-                                      struct kvm_rmap_head *rmap_head,
+static void pte_list_desc_remove_entry(struct kvm *kvm, unsigned long *rmap_val,
                                       struct pte_list_desc *desc, int i)
 {
-       struct pte_list_desc *head_desc = (struct pte_list_desc *)(rmap_head->val & ~KVM_RMAP_MANY);
+       struct pte_list_desc *head_desc = (struct pte_list_desc *)(*rmap_val & ~KVM_RMAP_MANY);
        int j = head_desc->spte_count - 1;
 
        /*
@@ -928,9 +935,9 @@ static void pte_list_desc_remove_entry(struct kvm *kvm,
         * head at the next descriptor, i.e. the new head.
         */
        if (!head_desc->more)
-               rmap_head->val = 0;
+               *rmap_val = 0;
        else
-               rmap_head->val = (unsigned long)head_desc->more | KVM_RMAP_MANY;
+               *rmap_val = (unsigned long)head_desc->more | KVM_RMAP_MANY;
        mmu_free_pte_list_desc(head_desc);
 }
 
@@ -938,24 +945,26 @@ static void pte_list_remove(struct kvm *kvm, u64 *spte,
                            struct kvm_rmap_head *rmap_head)
 {
        struct pte_list_desc *desc;
+       unsigned long rmap_val;
        int i;
 
-       if (KVM_BUG_ON_DATA_CORRUPTION(!rmap_head->val, kvm))
-               return;
+       rmap_val = rmap_head->val;
+       if (KVM_BUG_ON_DATA_CORRUPTION(!rmap_val, kvm))
+               goto out;
 
-       if (!(rmap_head->val & KVM_RMAP_MANY)) {
-               if (KVM_BUG_ON_DATA_CORRUPTION((u64 *)rmap_head->val != spte, kvm))
-                       return;
+       if (!(rmap_val & KVM_RMAP_MANY)) {
+               if (KVM_BUG_ON_DATA_CORRUPTION((u64 *)rmap_val != spte, kvm))
+                       goto out;
 
-               rmap_head->val = 0;
+               rmap_val = 0;
        } else {
-               desc = (struct pte_list_desc *)(rmap_head->val & ~KVM_RMAP_MANY);
+               desc = (struct pte_list_desc *)(rmap_val & ~KVM_RMAP_MANY);
                while (desc) {
                        for (i = 0; i < desc->spte_count; ++i) {
                                if (desc->sptes[i] == spte) {
-                                       pte_list_desc_remove_entry(kvm, rmap_head,
+                                       pte_list_desc_remove_entry(kvm, &rmap_val,
                                                                   desc, i);
-                                       return;
+                                       goto out;
                                }
                        }
                        desc = desc->more;
@@ -963,6 +972,9 @@ static void pte_list_remove(struct kvm *kvm, u64 *spte,
 
                KVM_BUG_ON_DATA_CORRUPTION(true, kvm);
        }
+
+out:
+       rmap_head->val = rmap_val;
 }
 
 static void kvm_zap_one_rmap_spte(struct kvm *kvm,
@@ -977,17 +989,19 @@ static bool kvm_zap_all_rmap_sptes(struct kvm *kvm,
                                   struct kvm_rmap_head *rmap_head)
 {
        struct pte_list_desc *desc, *next;
+       unsigned long rmap_val;
        int i;
 
-       if (!rmap_head->val)
+       rmap_val = rmap_head->val;
+       if (!rmap_val)
                return false;
 
-       if (!(rmap_head->val & KVM_RMAP_MANY)) {
-               mmu_spte_clear_track_bits(kvm, (u64 *)rmap_head->val);
+       if (!(rmap_val & KVM_RMAP_MANY)) {
+               mmu_spte_clear_track_bits(kvm, (u64 *)rmap_val);
                goto out;
        }
 
-       desc = (struct pte_list_desc *)(rmap_head->val & ~KVM_RMAP_MANY);
+       desc = (struct pte_list_desc *)(rmap_val & ~KVM_RMAP_MANY);
 
        for (; desc; desc = next) {
                for (i = 0; i < desc->spte_count; i++)
@@ -1003,14 +1017,15 @@ out:
 
 unsigned int pte_list_count(struct kvm_rmap_head *rmap_head)
 {
+       unsigned long rmap_val = rmap_head->val;
        struct pte_list_desc *desc;
 
-       if (!rmap_head->val)
+       if (!rmap_val)
                return 0;
-       else if (!(rmap_head->val & KVM_RMAP_MANY))
+       else if (!(rmap_val & KVM_RMAP_MANY))
                return 1;
 
-       desc = (struct pte_list_desc *)(rmap_head->val & ~KVM_RMAP_MANY);
+       desc = (struct pte_list_desc *)(rmap_val & ~KVM_RMAP_MANY);
        return desc->tail_count + desc->spte_count;
 }
 
@@ -1053,6 +1068,7 @@ static void rmap_remove(struct kvm *kvm, u64 *spte)
  */
 struct rmap_iterator {
        /* private fields */
+       struct rmap_head *head;
        struct pte_list_desc *desc;     /* holds the sptep if not NULL */
        int pos;                        /* index of the sptep */
 };
@@ -1067,18 +1083,19 @@ struct rmap_iterator {
 static u64 *rmap_get_first(struct kvm_rmap_head *rmap_head,
                           struct rmap_iterator *iter)
 {
+       unsigned long rmap_val = rmap_head->val;
        u64 *sptep;
 
-       if (!rmap_head->val)
+       if (!rmap_val)
                return NULL;
 
-       if (!(rmap_head->val & KVM_RMAP_MANY)) {
+       if (!(rmap_val & KVM_RMAP_MANY)) {
                iter->desc = NULL;
-               sptep = (u64 *)rmap_head->val;
+               sptep = (u64 *)rmap_val;
                goto out;
        }
 
-       iter->desc = (struct pte_list_desc *)(rmap_head->val & ~KVM_RMAP_MANY);
+       iter->desc = (struct pte_list_desc *)(rmap_val & ~KVM_RMAP_MANY);
        iter->pos = 0;
        sptep = iter->desc->sptes[iter->pos];
 out: