From 3f2a5ba784b808109cac0aac921213e43143a216 Mon Sep 17 00:00:00 2001 From: Lyude Paul Date: Thu, 21 Aug 2025 15:32:44 -0400 Subject: [PATCH] rust: hrtimer: Add HrTimerCallbackContext and ::forward() With Linux's hrtimer API, there's a number of methods that can only be called in two situations: * When we have exclusive access to the hrtimer and it is not currently active * When we're within the context of an hrtimer callback context This commit handles the second situation and implements hrtimer_forward() support in the context of a timer callback. We do this by introducing a HrTimerCallbackContext type which is provided to users during the RawHrTimerCallback::run() callback, and then add a forward() function to the type. Signed-off-by: Lyude Paul Reviewed-by: Daniel Almeida Reviewed-by: Andreas Hindborg Link: https://lore.kernel.org/r/20250821193259.964504-5-lyude@redhat.com Signed-off-by: Andreas Hindborg --- rust/kernel/time/hrtimer.rs | 63 +++++++++++++++++++++++++++-- rust/kernel/time/hrtimer/arc.rs | 9 ++++- rust/kernel/time/hrtimer/pin.rs | 9 ++++- rust/kernel/time/hrtimer/pin_mut.rs | 12 ++++-- rust/kernel/time/hrtimer/tbox.rs | 9 ++++- 5 files changed, 93 insertions(+), 9 deletions(-) diff --git a/rust/kernel/time/hrtimer.rs b/rust/kernel/time/hrtimer.rs index 79fed14b2d98e..1e8839d277292 100644 --- a/rust/kernel/time/hrtimer.rs +++ b/rust/kernel/time/hrtimer.rs @@ -69,7 +69,7 @@ use super::{ClockSource, Delta, Instant}; use crate::{prelude::*, types::Opaque}; -use core::marker::PhantomData; +use core::{marker::PhantomData, ptr::NonNull}; use pin_init::PinInit; /// A type-alias to refer to the [`Instant`] for a given `T` from [`HrTimer`]. @@ -196,6 +196,10 @@ impl HrTimer { /// expires after `now` and then returns the number of times the timer was forwarded by /// `interval`. /// + /// This function is mainly useful for timer types which can provide exclusive access to the + /// timer when the timer is not running. For forwarding the timer from within the timer callback + /// context, see [`HrTimerCallbackContext::forward()`]. + /// /// Returns the number of overruns that occurred as a result of the timer expiry change. pub fn forward(self: Pin<&mut Self>, now: HrTimerInstant, interval: Delta) -> u64 where @@ -345,9 +349,13 @@ pub trait HrTimerCallback { type Pointer<'a>: RawHrTimerCallback; /// Called by the timer logic when the timer fires. - fn run(this: as RawHrTimerCallback>::CallbackTarget<'_>) -> HrTimerRestart + fn run( + this: as RawHrTimerCallback>::CallbackTarget<'_>, + ctx: HrTimerCallbackContext<'_, Self>, + ) -> HrTimerRestart where - Self: Sized; + Self: Sized, + Self: HasHrTimer; } /// A handle representing a potentially running timer. @@ -632,6 +640,55 @@ impl HrTimerMode for RelativePinnedHardMode { type Expires = Delta; } +/// Privileged smart-pointer for a [`HrTimer`] callback context. +/// +/// Many [`HrTimer`] methods can only be called in two situations: +/// +/// * When the caller has exclusive access to the `HrTimer` and the `HrTimer` is guaranteed not to +/// be running. +/// * From within the context of an `HrTimer`'s callback method. +/// +/// This type provides access to said methods from within a timer callback context. +/// +/// # Invariants +/// +/// * The existence of this type means the caller is currently within the callback for an +/// [`HrTimer`]. +/// * `self.0` always points to a live instance of [`HrTimer`]. +pub struct HrTimerCallbackContext<'a, T: HasHrTimer>(NonNull>, PhantomData<&'a ()>); + +impl<'a, T: HasHrTimer> HrTimerCallbackContext<'a, T> { + /// Create a new [`HrTimerCallbackContext`]. + /// + /// # Safety + /// + /// This function relies on the caller being within the context of a timer callback, so it must + /// not be used anywhere except for within implementations of [`RawHrTimerCallback::run`]. The + /// caller promises that `timer` points to a valid initialized instance of + /// [`bindings::hrtimer`]. + /// + /// The returned `Self` must not outlive the function context of [`RawHrTimerCallback::run`] + /// where this function is called. + pub(crate) unsafe fn from_raw(timer: *mut HrTimer) -> Self { + // SAFETY: The caller guarantees `timer` is a valid pointer to an initialized + // `bindings::hrtimer` + // INVARIANT: Our safety contract ensures that we're within the context of a timer callback + // and that `timer` points to a live instance of `HrTimer`. + Self(unsafe { NonNull::new_unchecked(timer) }, PhantomData) + } + + /// Conditionally forward the timer. + /// + /// This function is identical to [`HrTimer::forward()`] except that it may only be used from + /// within the context of a [`HrTimer`] callback. + pub fn forward(&mut self, now: HrTimerInstant, interval: Delta) -> u64 { + // SAFETY: + // - We are guaranteed to be within the context of a timer callback by our type invariants + // - By our type invariants, `self.0` always points to a valid `HrTimer` + unsafe { HrTimer::::raw_forward(self.0.as_ptr(), now, interval) } + } +} + /// Use to implement the [`HasHrTimer`] trait. /// /// See [`module`] documentation for an example. diff --git a/rust/kernel/time/hrtimer/arc.rs b/rust/kernel/time/hrtimer/arc.rs index ed490a7a89503..7be82bcb352ac 100644 --- a/rust/kernel/time/hrtimer/arc.rs +++ b/rust/kernel/time/hrtimer/arc.rs @@ -3,6 +3,7 @@ use super::HasHrTimer; use super::HrTimer; use super::HrTimerCallback; +use super::HrTimerCallbackContext; use super::HrTimerHandle; use super::HrTimerMode; use super::HrTimerPointer; @@ -99,6 +100,12 @@ where // allocation from other `Arc` clones. let receiver = unsafe { ArcBorrow::from_raw(data_ptr) }; - T::run(receiver).into_c() + // SAFETY: + // - By C API contract `timer_ptr` is the pointer that we passed when queuing the timer, so + // it is a valid pointer to a `HrTimer` embedded in a `T`. + // - We are within `RawHrTimerCallback::run` + let context = unsafe { HrTimerCallbackContext::from_raw(timer_ptr) }; + + T::run(receiver, context).into_c() } } diff --git a/rust/kernel/time/hrtimer/pin.rs b/rust/kernel/time/hrtimer/pin.rs index aef16d9ee2f0c..4d39ef7816971 100644 --- a/rust/kernel/time/hrtimer/pin.rs +++ b/rust/kernel/time/hrtimer/pin.rs @@ -3,6 +3,7 @@ use super::HasHrTimer; use super::HrTimer; use super::HrTimerCallback; +use super::HrTimerCallbackContext; use super::HrTimerHandle; use super::HrTimerMode; use super::RawHrTimerCallback; @@ -103,6 +104,12 @@ where // here. let receiver_pin = unsafe { Pin::new_unchecked(receiver_ref) }; - T::run(receiver_pin).into_c() + // SAFETY: + // - By C API contract `timer_ptr` is the pointer that we passed when queuing the timer, so + // it is a valid pointer to a `HrTimer` embedded in a `T`. + // - We are within `RawHrTimerCallback::run` + let context = unsafe { HrTimerCallbackContext::from_raw(timer_ptr) }; + + T::run(receiver_pin, context).into_c() } } diff --git a/rust/kernel/time/hrtimer/pin_mut.rs b/rust/kernel/time/hrtimer/pin_mut.rs index 767d0a4e8a2c1..9d9447d4d57e8 100644 --- a/rust/kernel/time/hrtimer/pin_mut.rs +++ b/rust/kernel/time/hrtimer/pin_mut.rs @@ -1,8 +1,8 @@ // SPDX-License-Identifier: GPL-2.0 use super::{ - HasHrTimer, HrTimer, HrTimerCallback, HrTimerHandle, HrTimerMode, RawHrTimerCallback, - UnsafeHrTimerPointer, + HasHrTimer, HrTimer, HrTimerCallback, HrTimerCallbackContext, HrTimerHandle, HrTimerMode, + RawHrTimerCallback, UnsafeHrTimerPointer, }; use core::{marker::PhantomData, pin::Pin, ptr::NonNull}; @@ -107,6 +107,12 @@ where // here. let receiver_pin = unsafe { Pin::new_unchecked(receiver_ref) }; - T::run(receiver_pin).into_c() + // SAFETY: + // - By C API contract `timer_ptr` is the pointer that we passed when queuing the timer, so + // it is a valid pointer to a `HrTimer` embedded in a `T`. + // - We are within `RawHrTimerCallback::run` + let context = unsafe { HrTimerCallbackContext::from_raw(timer_ptr) }; + + T::run(receiver_pin, context).into_c() } } diff --git a/rust/kernel/time/hrtimer/tbox.rs b/rust/kernel/time/hrtimer/tbox.rs index ec08303315f28..aa1ee31a71953 100644 --- a/rust/kernel/time/hrtimer/tbox.rs +++ b/rust/kernel/time/hrtimer/tbox.rs @@ -3,6 +3,7 @@ use super::HasHrTimer; use super::HrTimer; use super::HrTimerCallback; +use super::HrTimerCallbackContext; use super::HrTimerHandle; use super::HrTimerMode; use super::HrTimerPointer; @@ -119,6 +120,12 @@ where // `data_ptr` exist. let data_mut_ref = unsafe { Pin::new_unchecked(&mut *data_ptr) }; - T::run(data_mut_ref).into_c() + // SAFETY: + // - By C API contract `timer_ptr` is the pointer that we passed when queuing the timer, so + // it is a valid pointer to a `HrTimer` embedded in a `T`. + // - We are within `RawHrTimerCallback::run` + let context = unsafe { HrTimerCallbackContext::from_raw(timer_ptr) }; + + T::run(data_mut_ref, context).into_c() } } -- 2.47.3