static int mmu_shrink(struct shrinker *shrink, struct shrink_control *sc)
 {
        struct kvm *kvm;
-       struct kvm *kvm_freed = NULL;
        int nr_to_scan = sc->nr_to_scan;
 
        if (nr_to_scan == 0)
                int idx;
                LIST_HEAD(invalid_list);
 
+               /*
+                * n_used_mmu_pages is accessed without holding kvm->mmu_lock
+                * here. We may skip a VM instance errorneosly, but we do not
+                * want to shrink a VM that only started to populate its MMU
+                * anyway.
+                */
+               if (kvm->arch.n_used_mmu_pages > 0) {
+                       if (!nr_to_scan--)
+                               break;
+                       continue;
+               }
+
                idx = srcu_read_lock(&kvm->srcu);
                spin_lock(&kvm->mmu_lock);
-               if (!kvm_freed && nr_to_scan > 0 &&
-                   kvm->arch.n_used_mmu_pages > 0) {
-                       kvm_mmu_remove_some_alloc_mmu_pages(kvm,
-                                                           &invalid_list);
-                       kvm_freed = kvm;
-               }
-               nr_to_scan--;
 
+               kvm_mmu_remove_some_alloc_mmu_pages(kvm, &invalid_list);
                kvm_mmu_commit_zap_page(kvm, &invalid_list);
+
                spin_unlock(&kvm->mmu_lock);
                srcu_read_unlock(&kvm->srcu, idx);
+
+               list_move_tail(&kvm->vm_list, &vm_list);
+               break;
        }
-       if (kvm_freed)
-               list_move_tail(&kvm_freed->vm_list, &vm_list);
 
        raw_spin_unlock(&kvm_lock);