--- /dev/null
+From d7fb2326c5a04837cdfd7a23fd993593fe7e2aa1 Mon Sep 17 00:00:00 2001
+From: Sasha Levin <sashal@kernel.org>
+Date: Mon, 11 Nov 2019 14:54:23 -0800
+Subject: KVM: x86: introduce is_pae_paging
+
+From: Paolo Bonzini <pbonzini@redhat.com>
+
+[ Upstream commit bf03d4f9334728bf7c8ffc7de787df48abd6340e ]
+
+Checking for 32-bit PAE is quite common around code that fiddles with
+the PDPTRs. Add a function to compress all checks into a single
+invocation.
+
+Moving to the common helper also fixes a subtle bug in kvm_set_cr3()
+where it fails to check is_long_mode() and results in KVM incorrectly
+attempting to load PDPTRs for a 64-bit guest.
+
+Reviewed-by: Sean Christopherson <sean.j.christopherson@intel.com>
+Signed-off-by: Paolo Bonzini <pbonzini@redhat.com>
+[sean: backport to 4.x; handle vmx.c split in 5.x, call out the bugfix]
+Signed-off-by: Sean Christopherson <sean.j.christopherson@intel.com>
+Acked-by: Paolo Bonzini <pbonzini@redhat.com>
+Tested-by: Thomas Lamprecht <t.lamprecht@proxmox.com>
+Signed-off-by: Sasha Levin <sashal@kernel.org>
+---
+ arch/x86/kvm/vmx.c | 7 +++----
+ arch/x86/kvm/x86.c | 8 ++++----
+ arch/x86/kvm/x86.h | 5 +++++
+ 3 files changed, 12 insertions(+), 8 deletions(-)
+
+diff --git a/arch/x86/kvm/vmx.c b/arch/x86/kvm/vmx.c
+index 4eda2a9c234a6..1ab4bb3d6a040 100644
+--- a/arch/x86/kvm/vmx.c
++++ b/arch/x86/kvm/vmx.c
+@@ -5173,7 +5173,7 @@ static void ept_load_pdptrs(struct kvm_vcpu *vcpu)
+ (unsigned long *)&vcpu->arch.regs_dirty))
+ return;
+
+- if (is_paging(vcpu) && is_pae(vcpu) && !is_long_mode(vcpu)) {
++ if (is_pae_paging(vcpu)) {
+ vmcs_write64(GUEST_PDPTR0, mmu->pdptrs[0]);
+ vmcs_write64(GUEST_PDPTR1, mmu->pdptrs[1]);
+ vmcs_write64(GUEST_PDPTR2, mmu->pdptrs[2]);
+@@ -5185,7 +5185,7 @@ static void ept_save_pdptrs(struct kvm_vcpu *vcpu)
+ {
+ struct kvm_mmu *mmu = vcpu->arch.walk_mmu;
+
+- if (is_paging(vcpu) && is_pae(vcpu) && !is_long_mode(vcpu)) {
++ if (is_pae_paging(vcpu)) {
+ mmu->pdptrs[0] = vmcs_read64(GUEST_PDPTR0);
+ mmu->pdptrs[1] = vmcs_read64(GUEST_PDPTR1);
+ mmu->pdptrs[2] = vmcs_read64(GUEST_PDPTR2);
+@@ -12013,8 +12013,7 @@ static int nested_vmx_load_cr3(struct kvm_vcpu *vcpu, unsigned long cr3, bool ne
+ * If PAE paging and EPT are both on, CR3 is not used by the CPU and
+ * must not be dereferenced.
+ */
+- if (!is_long_mode(vcpu) && is_pae(vcpu) && is_paging(vcpu) &&
+- !nested_ept) {
++ if (is_pae_paging(vcpu) && !nested_ept) {
+ if (!load_pdptrs(vcpu, vcpu->arch.walk_mmu, cr3)) {
+ *entry_failure_code = ENTRY_FAIL_PDPTE;
+ return 1;
+diff --git a/arch/x86/kvm/x86.c b/arch/x86/kvm/x86.c
+index e536503ac7881..6cf8af022b21d 100644
+--- a/arch/x86/kvm/x86.c
++++ b/arch/x86/kvm/x86.c
+@@ -634,7 +634,7 @@ bool pdptrs_changed(struct kvm_vcpu *vcpu)
+ gfn_t gfn;
+ int r;
+
+- if (is_long_mode(vcpu) || !is_pae(vcpu) || !is_paging(vcpu))
++ if (!is_pae_paging(vcpu))
+ return false;
+
+ if (!test_bit(VCPU_EXREG_PDPTR,
+@@ -885,8 +885,8 @@ int kvm_set_cr3(struct kvm_vcpu *vcpu, unsigned long cr3)
+ if (is_long_mode(vcpu) &&
+ (cr3 & rsvd_bits(cpuid_maxphyaddr(vcpu), 63)))
+ return 1;
+- else if (is_pae(vcpu) && is_paging(vcpu) &&
+- !load_pdptrs(vcpu, vcpu->arch.walk_mmu, cr3))
++ else if (is_pae_paging(vcpu) &&
++ !load_pdptrs(vcpu, vcpu->arch.walk_mmu, cr3))
+ return 1;
+
+ kvm_mmu_new_cr3(vcpu, cr3, skip_tlb_flush);
+@@ -8348,7 +8348,7 @@ static int __set_sregs(struct kvm_vcpu *vcpu, struct kvm_sregs *sregs)
+ kvm_update_cpuid(vcpu);
+
+ idx = srcu_read_lock(&vcpu->kvm->srcu);
+- if (!is_long_mode(vcpu) && is_pae(vcpu) && is_paging(vcpu)) {
++ if (is_pae_paging(vcpu)) {
+ load_pdptrs(vcpu, vcpu->arch.walk_mmu, kvm_read_cr3(vcpu));
+ mmu_reset_needed = 1;
+ }
+diff --git a/arch/x86/kvm/x86.h b/arch/x86/kvm/x86.h
+index 3a91ea760f073..608e5f8c5d0a5 100644
+--- a/arch/x86/kvm/x86.h
++++ b/arch/x86/kvm/x86.h
+@@ -139,6 +139,11 @@ static inline int is_paging(struct kvm_vcpu *vcpu)
+ return likely(kvm_read_cr0_bits(vcpu, X86_CR0_PG));
+ }
+
++static inline bool is_pae_paging(struct kvm_vcpu *vcpu)
++{
++ return !is_long_mode(vcpu) && is_pae(vcpu) && is_paging(vcpu);
++}
++
+ static inline u32 bit(int bitno)
+ {
+ return 1 << (bitno & 31);
+--
+2.20.1
+