]> git.ipfire.org Git - thirdparty/qemu.git/commitdiff
rust/hpet: Borrow BqlRefCell<HPETRegisters> at top level
authorZhao Liu <zhao1.liu@intel.com>
Fri, 21 Nov 2025 17:40:19 +0000 (18:40 +0100)
committerPaolo Bonzini <pbonzini@redhat.com>
Sat, 27 Dec 2025 09:11:10 +0000 (10:11 +0100)
Lockless IO requires to lock the registers during MMIO access. So it's
necessary to get (or borrow) registers data at top level, and not to
borrow again in child function calls.

Change the argument types from BqlRefCell<HPETRegisters> to
&HPETRegisters/&mut HPETRegisters in child methods, and do borrow the
data once at top level.

This allows BqlRefCell<HPETRegisters> to be directly replaced with
Mutex<HPETRegisters> in subsequent steps without causing lock reentrancy
issues.

Note, passing reference instead of BqlRef/BqlRefMut because BqlRefMut
cannot be re-borrowed as BqlRef, though BqlRef/BqlRefMut themselves play
as the "guard". Passing reference is directly and the extra
bql::is_locked check could help to consolidate safety guarantee.

Signed-off-by: Zhao Liu <zhao1.liu@intel.com>
Link: https://lore.kernel.org/r/20251113051937.4017675-19-zhao1.liu@intel.com
Signed-off-by: Paolo Bonzini <pbonzini@redhat.com>
rust/hw/timer/hpet/src/device.rs

index f9cdced54067c5d95911a1ccb6ee4d5f8f24f8fa..916ae587b0817e491273c55803b22d60e7806131 100644 (file)
@@ -174,7 +174,7 @@ fn timer_handler(timer_cell: &BqlRefCell<HPETTimer>) {
     let mut t = timer_cell.borrow_mut();
     // SFAETY: state field is valid after timer initialization.
     let hpet_regs = &mut unsafe { t.state.as_mut() }.regs;
-    t.callback(hpet_regs)
+    t.callback(&mut hpet_regs.borrow_mut())
 }
 
 #[repr(C)]
@@ -280,20 +280,15 @@ impl HPETTimer {
         unsafe { self.state.as_ref() }
     }
 
-    fn is_int_active(&self, hpet_regs: &BqlRefCell<HPETRegisters>) -> bool {
-        hpet_regs.borrow().is_timer_int_active(self.index.into())
+    fn is_int_active(&self, regs: &HPETRegisters) -> bool {
+        regs.is_timer_int_active(self.index.into())
     }
 
     /// calculate next value of the general counter that matches the
     /// target (either entirely, or the low 32-bit only depending on
     /// the timer mode).
-    fn calculate_cmp64(
-        &self,
-        hpet_regs: &BqlRefCell<HPETRegisters>,
-        cur_tick: u64,
-        target: u64,
-    ) -> u64 {
-        let tn_regs = &hpet_regs.borrow().tn_regs[self.index as usize];
+    fn calculate_cmp64(&self, regs: &HPETRegisters, cur_tick: u64, target: u64) -> u64 {
+        let tn_regs = &regs.tn_regs[self.index as usize];
         if tn_regs.is_32bit_mod() {
             let mut result: u64 = cur_tick.deposit(0, 32, target);
             if result < cur_tick {
@@ -357,8 +352,7 @@ impl HPETTimer {
         }
     }
 
-    fn update_irq(&self, hpet_regs: &BqlRefCell<HPETRegisters>, set: bool) {
-        let mut regs = hpet_regs.borrow_mut();
+    fn update_irq(&self, regs: &mut HPETRegisters, set: bool) {
         // If Timer N Interrupt Enable bit is 0, "the timer will
         // still operate and generate appropriate status bits, but
         // will not cause an interrupt"
@@ -367,11 +361,11 @@ impl HPETTimer {
             1,
             u64::from(set && regs.tn_regs[self.index as usize].is_int_level_triggered()),
         );
-        self.set_irq(&regs, set);
+        self.set_irq(regs, set);
     }
 
-    fn arm_timer(&mut self, hpet_regs: &BqlRefCell<HPETRegisters>, tick: u64) {
-        let tn_regs = &hpet_regs.borrow().tn_regs[self.index as usize];
+    fn arm_timer(&mut self, regs: &HPETRegisters, tick: u64) {
+        let tn_regs = &regs.tn_regs[self.index as usize];
         let mut ns = self.get_state().get_ns(tick);
 
         // Clamp period to reasonable min value (1 us)
@@ -383,45 +377,45 @@ impl HPETTimer {
         self.qemu_timer.modify(self.last);
     }
 
-    fn set_timer(&mut self, hpet_regs: &BqlRefCell<HPETRegisters>) {
-        let tn_regs = &hpet_regs.borrow().tn_regs[self.index as usize];
+    fn set_timer(&mut self, regs: &HPETRegisters) {
+        let tn_regs = &regs.tn_regs[self.index as usize];
         let cur_tick: u64 = self.get_state().get_ticks();
 
         self.wrap_flag = 0;
-        self.cmp64 = self.calculate_cmp64(hpet_regs, cur_tick, tn_regs.cmp);
+        self.cmp64 = self.calculate_cmp64(regs, cur_tick, tn_regs.cmp);
         if tn_regs.is_32bit_mod() {
             // HPET spec says in one-shot 32-bit mode, generate an interrupt when
             // counter wraps in addition to an interrupt with comparator match.
             if !tn_regs.is_periodic() && self.cmp64 > hpet_next_wrap(cur_tick) {
                 self.wrap_flag = 1;
-                self.arm_timer(hpet_regs, hpet_next_wrap(cur_tick));
+                self.arm_timer(regs, hpet_next_wrap(cur_tick));
                 return;
             }
         }
-        self.arm_timer(hpet_regs, self.cmp64);
+        self.arm_timer(regs, self.cmp64);
     }
 
-    fn del_timer(&self, hpet_regs: &BqlRefCell<HPETRegisters>) {
+    fn del_timer(&self, regs: &mut HPETRegisters) {
         // Just remove the timer from the timer_list without destroying
         // this timer instance.
         self.qemu_timer.delete();
 
-        if self.is_int_active(hpet_regs) {
+        if self.is_int_active(regs) {
             // For level-triggered interrupt, this leaves interrupt status
             // register set but lowers irq.
-            self.update_irq(hpet_regs, true);
+            self.update_irq(regs, true);
         }
     }
 
     fn prepare_tn_cfg_reg_new(
         &self,
-        hpet_regs: &BqlRefCell<HPETRegisters>,
+        regs: &mut HPETRegisters,
         shift: u32,
         len: u32,
         val: u64,
     ) -> (u64, u64) {
         trace::trace_hpet_ram_write_tn_cfg((shift / 8).try_into().unwrap());
-        let tn_regs = &hpet_regs.borrow().tn_regs[self.index as usize];
+        let tn_regs = &regs.tn_regs[self.index as usize];
         let old_val: u64 = tn_regs.config;
         let mut new_val: u64 = old_val.deposit(shift, len, val);
         new_val = hpet_fixup_reg(new_val, old_val, HPET_TN_CFG_WRITE_MASK);
@@ -430,50 +424,38 @@ impl HPETTimer {
         if deactivating_bit(old_val, new_val, HPET_TN_CFG_INT_TYPE_SHIFT) {
             // Do this before changing timer.regs.config; otherwise, if
             // HPET_TN_FSB is set, update_irq will not lower the qemu_irq.
-            self.update_irq(hpet_regs, false);
+            self.update_irq(regs, false);
         }
 
         (new_val, old_val)
     }
 
     /// Configuration and Capability Register
-    fn set_tn_cfg_reg(
-        &mut self,
-        hpet_regs: &BqlRefCell<HPETRegisters>,
-        shift: u32,
-        len: u32,
-        val: u64,
-    ) {
+    fn set_tn_cfg_reg(&mut self, regs: &mut HPETRegisters, shift: u32, len: u32, val: u64) {
         // Factor out a prepare_tn_cfg_reg_new() to better handle immutable scope.
-        let (new_val, old_val) = self.prepare_tn_cfg_reg_new(hpet_regs, shift, len, val);
-        let tn_regs = &mut hpet_regs.borrow_mut().tn_regs[self.index as usize];
-        tn_regs.config = new_val;
+        let (new_val, old_val) = self.prepare_tn_cfg_reg_new(regs, shift, len, val);
+        regs.tn_regs[self.index as usize].config = new_val;
 
         if activating_bit(old_val, new_val, HPET_TN_CFG_INT_ENABLE_SHIFT)
-            && self.is_int_active(hpet_regs)
+            && self.is_int_active(regs)
         {
-            self.update_irq(hpet_regs, true);
+            self.update_irq(regs, true);
         }
 
+        let tn_regs = &mut regs.tn_regs[self.index as usize];
         if tn_regs.is_32bit_mod() {
             tn_regs.cmp = u64::from(tn_regs.cmp as u32); // truncate!
             self.period = u64::from(self.period as u32); // truncate!
         }
 
-        if hpet_regs.borrow().is_hpet_enabled() {
-            self.set_timer(hpet_regs);
+        if regs.is_hpet_enabled() {
+            self.set_timer(regs);
         }
     }
 
     /// Comparator Value Register
-    fn set_tn_cmp_reg(
-        &mut self,
-        hpet_regs: &BqlRefCell<HPETRegisters>,
-        shift: u32,
-        len: u32,
-        val: u64,
-    ) {
-        let tn_regs = &mut hpet_regs.borrow_mut().tn_regs[self.index as usize];
+    fn set_tn_cmp_reg(&mut self, regs: &mut HPETRegisters, shift: u32, len: u32, val: u64) {
+        let tn_regs = &mut regs.tn_regs[self.index as usize];
         let mut length = len;
         let mut value = val;
 
@@ -498,27 +480,21 @@ impl HPETTimer {
         }
 
         tn_regs.clear_valset();
-        if hpet_regs.borrow().is_hpet_enabled() {
-            self.set_timer(hpet_regs);
+        if regs.is_hpet_enabled() {
+            self.set_timer(regs);
         }
     }
 
     /// FSB Interrupt Route Register
-    fn set_tn_fsb_route_reg(
-        &self,
-        hpet_regs: &BqlRefCell<HPETRegisters>,
-        shift: u32,
-        len: u32,
-        val: u64,
-    ) {
-        let tn_regs = &mut hpet_regs.borrow_mut().tn_regs[self.index as usize];
+    fn set_tn_fsb_route_reg(&self, regs: &mut HPETRegisters, shift: u32, len: u32, val: u64) {
+        let tn_regs = &mut regs.tn_regs[self.index as usize];
         tn_regs.fsb = tn_regs.fsb.deposit(shift, len, val);
     }
 
-    fn reset(&mut self, hpet_regs: &BqlRefCell<HPETRegisters>) {
-        self.del_timer(hpet_regs);
+    fn reset(&mut self, regs: &mut HPETRegisters) {
+        self.del_timer(regs);
 
-        let tn_regs = &mut hpet_regs.borrow_mut().tn_regs[self.index as usize];
+        let tn_regs = &mut regs.tn_regs[self.index as usize];
         tn_regs.cmp = u64::MAX; // Comparator Match Registers reset to all 1's.
         tn_regs.config = (1 << HPET_TN_CFG_PERIODIC_CAP_SHIFT) | (1 << HPET_TN_CFG_SIZE_CAP_SHIFT);
         if self.get_state().has_msi_flag() {
@@ -532,8 +508,8 @@ impl HPETTimer {
     }
 
     /// timer expiration callback
-    fn callback(&mut self, hpet_regs: &BqlRefCell<HPETRegisters>) {
-        let tn_regs = &mut hpet_regs.borrow_mut().tn_regs[self.index as usize];
+    fn callback(&mut self, regs: &mut HPETRegisters) {
+        let tn_regs = &mut regs.tn_regs[self.index as usize];
         let period: u64 = self.period;
         let cur_tick: u64 = self.get_state().get_ticks();
 
@@ -546,16 +522,16 @@ impl HPETTimer {
             } else {
                 tn_regs.cmp = self.cmp64;
             }
-            self.arm_timer(hpet_regs, self.cmp64);
+            self.arm_timer(regs, self.cmp64);
         } else if self.wrap_flag != 0 {
             self.wrap_flag = 0;
-            self.arm_timer(hpet_regs, self.cmp64);
+            self.arm_timer(regs, self.cmp64);
         }
-        self.update_irq(hpet_regs, true);
+        self.update_irq(regs, true);
     }
 
-    fn read(&self, target: TimerRegister, hpet_regs: &BqlRefCell<HPETRegisters>) -> u64 {
-        let tn_regs = &hpet_regs.borrow().tn_regs[self.index as usize];
+    fn read(&self, target: TimerRegister, regs: &HPETRegisters) -> u64 {
+        let tn_regs = &regs.tn_regs[self.index as usize];
 
         use TimerRegister::*;
         match target {
@@ -568,7 +544,7 @@ impl HPETTimer {
     fn write(
         &mut self,
         target: TimerRegister,
-        hpet_regs: &BqlRefCell<HPETRegisters>,
+        regs: &mut HPETRegisters,
         value: u64,
         shift: u32,
         len: u32,
@@ -577,9 +553,9 @@ impl HPETTimer {
 
         trace::trace_hpet_ram_write_timer_id(self.index);
         match target {
-            CFG => self.set_tn_cfg_reg(hpet_regs, shift, len, value),
-            CMP => self.set_tn_cmp_reg(hpet_regs, shift, len, value),
-            ROUTE => self.set_tn_fsb_route_reg(hpet_regs, shift, len, value),
+            CFG => self.set_tn_cfg_reg(regs, shift, len, value),
+            CMP => self.set_tn_cmp_reg(regs, shift, len, value),
+            ROUTE => self.set_tn_fsb_route_reg(regs, shift, len, value),
         }
     }
 }
@@ -702,22 +678,22 @@ impl HPETState {
     }
 
     /// General Configuration Register
-    fn set_cfg_reg(&self, regs: &BqlRefCell<HPETRegisters>, shift: u32, len: u32, val: u64) {
-        let old_val = regs.borrow().config;
+    fn set_cfg_reg(&self, regs: &mut HPETRegisters, shift: u32, len: u32, val: u64) {
+        let old_val = regs.config;
         let mut new_val = old_val.deposit(shift, len, val);
 
         new_val = hpet_fixup_reg(new_val, old_val, HPET_CFG_WRITE_MASK);
-        regs.borrow_mut().config = new_val;
+        regs.config = new_val;
 
         if activating_bit(old_val, new_val, HPET_CFG_ENABLE_SHIFT) {
             // Enable main counter and interrupt generation.
             self.hpet_offset
-                .set(ticks_to_ns(regs.borrow().counter) - CLOCK_VIRTUAL.get_ns());
+                .set(ticks_to_ns(regs.counter) - CLOCK_VIRTUAL.get_ns());
 
             for timer in self.timers.iter().take(self.num_timers) {
                 let mut t = timer.borrow_mut();
                 let id = t.index as usize;
-                let tn_regs = &regs.borrow().tn_regs[id];
+                let tn_regs = &regs.tn_regs[id];
 
                 if tn_regs.is_int_enabled() && t.is_int_active(regs) {
                     t.update_irq(regs, true);
@@ -726,7 +702,7 @@ impl HPETState {
             }
         } else if deactivating_bit(old_val, new_val, HPET_CFG_ENABLE_SHIFT) {
             // Halt main counter and disable interrupt generation.
-            regs.borrow_mut().counter = self.get_ticks();
+            regs.counter = self.get_ticks();
 
             for timer in self.timers.iter().take(self.num_timers) {
                 timer.borrow().del_timer(regs);
@@ -747,15 +723,9 @@ impl HPETState {
     }
 
     /// General Interrupt Status Register: Read/Write Clear
-    fn set_int_status_reg(
-        &self,
-        regs: &BqlRefCell<HPETRegisters>,
-        shift: u32,
-        _len: u32,
-        val: u64,
-    ) {
+    fn set_int_status_reg(&self, regs: &mut HPETRegisters, shift: u32, _len: u32, val: u64) {
         let new_val = val << shift;
-        let cleared = new_val & regs.borrow().int_status;
+        let cleared = new_val & regs.int_status;
 
         for (index, timer) in self.timers.iter().take(self.num_timers).enumerate() {
             if cleared & (1 << index) != 0 {
@@ -765,8 +735,7 @@ impl HPETState {
     }
 
     /// Main Counter Value Register
-    fn set_counter_reg(&self, regs: &BqlRefCell<HPETRegisters>, shift: u32, len: u32, val: u64) {
-        let mut regs = regs.borrow_mut();
+    fn set_counter_reg(&self, regs: &mut HPETRegisters, shift: u32, len: u32, val: u64) {
         if regs.is_hpet_enabled() {
             // HPET spec says that writes to this register should only be
             // done while the counter is halted. So this is an undefined
@@ -844,23 +813,23 @@ impl HPETState {
     }
 
     fn reset_hold(&self, _type: ResetType) {
+        let mut regs = self.regs.borrow_mut();
         for timer in self.timers.iter().take(self.num_timers) {
-            timer.borrow_mut().reset(&self.regs);
+            timer.borrow_mut().reset(&mut regs);
         }
 
+        regs.counter = 0;
+        regs.config = 0;
+        HPETFwConfig::update_hpet_cfg(
+            self.hpet_id.get(),
+            regs.capability as u32,
+            self.mmio_addr(0).unwrap(),
+        );
+
         // pit_enabled.set(true) will call irq handler and access regs
         // again. We cannot borrow BqlRefCell twice at once. Minimize the
         // scope of regs to ensure it will be dropped before irq callback.
-        {
-            let mut regs = self.regs.borrow_mut();
-            regs.counter = 0;
-            regs.config = 0;
-            HPETFwConfig::update_hpet_cfg(
-                self.hpet_id.get(),
-                regs.capability as u32,
-                self.mmio_addr(0).unwrap(),
-            );
-        }
+        drop(regs);
 
         self.pit_enabled.set(true);
         self.hpet_offset.set(0);
@@ -897,17 +866,16 @@ impl HPETState {
         trace::trace_hpet_ram_read(addr);
 
         let HPETAddrDecode { shift, target, .. } = self.decode(addr, size);
-        let regs = &self.regs;
+        let regs = &self.regs.borrow();
 
         use DecodedRegister::*;
         use GlobalRegister::*;
         (match target {
             Timer(timer, tn_target) => timer.borrow_mut().read(tn_target, regs),
-            Global(CAP) => regs.borrow().capability, /* including HPET_PERIOD 0x004 */
-            Global(CFG) => regs.borrow().config,
-            Global(INT_STATUS) => regs.borrow().int_status,
+            Global(CAP) => regs.capability, /* including HPET_PERIOD 0x004 */
+            Global(CFG) => regs.config,
+            Global(INT_STATUS) => regs.int_status,
             Global(COUNTER) => {
-                let regs = regs.borrow();
                 let cur_tick = if regs.is_hpet_enabled() {
                     self.get_ticks()
                 } else {
@@ -927,18 +895,20 @@ impl HPETState {
 
     fn write(&self, addr: hwaddr, value: u64, size: u32) {
         let HPETAddrDecode { shift, len, target } = self.decode(addr, size);
-        let regs = &self.regs;
+        let mut regs = self.regs.borrow_mut();
 
         trace::trace_hpet_ram_write(addr, value);
 
         use DecodedRegister::*;
         use GlobalRegister::*;
         match target {
-            Timer(timer, tn_target) => timer.borrow_mut().write(tn_target, regs, value, shift, len),
+            Timer(timer, tn_target) => timer
+                .borrow_mut()
+                .write(tn_target, &mut regs, value, shift, len),
             Global(CAP) => {} // General Capabilities and ID Register: Read Only
-            Global(CFG) => self.set_cfg_reg(regs, shift, len, value),
-            Global(INT_STATUS) => self.set_int_status_reg(regs, shift, len, value),
-            Global(COUNTER) => self.set_counter_reg(regs, shift, len, value),
+            Global(CFG) => self.set_cfg_reg(&mut regs, shift, len, value),
+            Global(INT_STATUS) => self.set_int_status_reg(&mut regs, shift, len, value),
+            Global(COUNTER) => self.set_counter_reg(&mut regs, shift, len, value),
             Unknown(_) => trace::trace_hpet_ram_write_invalid(),
         }
     }
@@ -966,7 +936,7 @@ impl HPETState {
             let cnt = regs.counter;
             let cmp = regs.tn_regs[t.index as usize].cmp;
 
-            t.cmp64 = t.calculate_cmp64(&self.regs, cnt, cmp);
+            t.cmp64 = t.calculate_cmp64(&regs, cnt, cmp);
             t.last = CLOCK_VIRTUAL.get_ns() - NANOSECONDS_PER_SECOND;
         }