From: Danilo Krummrich Date: Fri, 14 Mar 2025 16:09:06 +0000 (+0100) Subject: rust: pci: fix unrestricted &mut pci::Device X-Git-Tag: v6.15-rc1~79^2~5 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=7b948a2af6b5;p=thirdparty%2Flinux.git rust: pci: fix unrestricted &mut pci::Device As by now, pci::Device is implemented as: #[derive(Clone)] pub struct Device(ARef); This may be convenient, but has the implication that drivers can call device methods that require a mutable reference concurrently at any point of time. Instead define pci::Device as pub struct Device( Opaque, PhantomData, ); and manually implement the AlwaysRefCounted trait. With this we can implement methods that should only be called from bus callbacks (such as probe()) for pci::Device. Consequently, we make this type accessible in bus callbacks only. Arbitrary references taken by the driver are still of type ARef and hence don't provide access to methods that are reserved for bus callbacks. Fixes: 1bd8b6b2c5d3 ("rust: pci: add basic PCI device / driver abstractions") Reviewed-by: Benno Lossin Signed-off-by: Danilo Krummrich Acked-by: Boqun Feng Link: https://lore.kernel.org/r/20250314160932.100165-4-dakr@kernel.org Signed-off-by: Greg Kroah-Hartman --- diff --git a/rust/kernel/pci.rs b/rust/kernel/pci.rs index 386484dcf36eb..0ac6cef74f815 100644 --- a/rust/kernel/pci.rs +++ b/rust/kernel/pci.rs @@ -6,7 +6,7 @@ use crate::{ alloc::flags::*, - bindings, container_of, device, + bindings, device, device_id::RawDeviceId, devres::Devres, driver, @@ -17,7 +17,11 @@ use crate::{ types::{ARef, ForeignOwnable, Opaque}, ThisModule, }; -use core::{ops::Deref, ptr::addr_of_mut}; +use core::{ + marker::PhantomData, + ops::Deref, + ptr::{addr_of_mut, NonNull}, +}; use kernel::prelude::*; /// An adapter for the registration of PCI drivers. @@ -60,17 +64,16 @@ impl Adapter { ) -> kernel::ffi::c_int { // SAFETY: The PCI bus only ever calls the probe callback with a valid pointer to a // `struct pci_dev`. - let dev = unsafe { device::Device::get_device(addr_of_mut!((*pdev).dev)) }; - // SAFETY: `dev` is guaranteed to be embedded in a valid `struct pci_dev` by the call - // above. - let mut pdev = unsafe { Device::from_dev(dev) }; + // + // INVARIANT: `pdev` is valid for the duration of `probe_callback()`. + let pdev = unsafe { &*pdev.cast::>() }; // SAFETY: `DeviceId` is a `#[repr(transparent)` wrapper of `struct pci_device_id` and // does not add additional invariants, so it's safe to transmute. let id = unsafe { &*id.cast::() }; let info = T::ID_TABLE.info(id.index()); - match T::probe(&mut pdev, info) { + match T::probe(pdev, info) { Ok(data) => { // Let the `struct pci_dev` own a reference of the driver's private data. // SAFETY: By the type invariant `pdev.as_raw` returns a valid pointer to a @@ -192,7 +195,7 @@ macro_rules! pci_device_table { /// # Example /// ///``` -/// # use kernel::{bindings, pci}; +/// # use kernel::{bindings, device::Core, pci}; /// /// struct MyDriver; /// @@ -210,7 +213,7 @@ macro_rules! pci_device_table { /// const ID_TABLE: pci::IdTable = &PCI_TABLE; /// /// fn probe( -/// _pdev: &mut pci::Device, +/// _pdev: &pci::Device, /// _id_info: &Self::IdInfo, /// ) -> Result>> { /// Err(ENODEV) @@ -234,20 +237,23 @@ pub trait Driver { /// /// Called when a new platform device is added or discovered. /// Implementers should attempt to initialize the device here. - fn probe(dev: &mut Device, id_info: &Self::IdInfo) -> Result>>; + fn probe(dev: &Device, id_info: &Self::IdInfo) -> Result>>; } /// The PCI device representation. /// -/// A PCI device is based on an always reference counted `device:Device` instance. Cloning a PCI -/// device, hence, also increments the base device' reference count. +/// This structure represents the Rust abstraction for a C `struct pci_dev`. The implementation +/// abstracts the usage of an already existing C `struct pci_dev` within Rust code that we get +/// passed from the C side. /// /// # Invariants /// -/// `Device` hold a valid reference of `ARef` whose underlying `struct device` is a -/// member of a `struct pci_dev`. -#[derive(Clone)] -pub struct Device(ARef); +/// A [`Device`] instance represents a valid `struct device` created by the C portion of the kernel. +#[repr(transparent)] +pub struct Device( + Opaque, + PhantomData, +); /// A PCI BAR to perform I/O-Operations on. /// @@ -256,13 +262,13 @@ pub struct Device(ARef); /// `Bar` always holds an `IoRaw` inststance that holds a valid pointer to the start of the I/O /// memory mapped PCI bar and its size. pub struct Bar { - pdev: Device, + pdev: ARef, io: IoRaw, num: i32, } impl Bar { - fn new(pdev: Device, num: u32, name: &CStr) -> Result { + fn new(pdev: &Device, num: u32, name: &CStr) -> Result { let len = pdev.resource_len(num)?; if len == 0 { return Err(ENOMEM); @@ -300,12 +306,16 @@ impl Bar { // `pdev` is valid by the invariants of `Device`. // `ioptr` is guaranteed to be the start of a valid I/O mapped memory region. // `num` is checked for validity by a previous call to `Device::resource_len`. - unsafe { Self::do_release(&pdev, ioptr, num) }; + unsafe { Self::do_release(pdev, ioptr, num) }; return Err(err); } }; - Ok(Bar { pdev, io, num }) + Ok(Bar { + pdev: pdev.into(), + io, + num, + }) } /// # Safety @@ -351,20 +361,8 @@ impl Deref for Bar { } impl Device { - /// Create a PCI Device instance from an existing `device::Device`. - /// - /// # Safety - /// - /// `dev` must be an `ARef` whose underlying `bindings::device` is a member of - /// a `bindings::pci_dev`. - pub unsafe fn from_dev(dev: ARef) -> Self { - Self(dev) - } - fn as_raw(&self) -> *mut bindings::pci_dev { - // SAFETY: By the type invariant `self.0.as_raw` is a pointer to the `struct device` - // embedded in `struct pci_dev`. - unsafe { container_of!(self.0.as_raw(), bindings::pci_dev, dev) as _ } + self.0.get() } /// Returns the PCI vendor ID. @@ -379,18 +377,6 @@ impl Device { unsafe { (*self.as_raw()).device } } - /// Enable memory resources for this device. - pub fn enable_device_mem(&self) -> Result { - // SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`. - to_result(unsafe { bindings::pci_enable_device_mem(self.as_raw()) }) - } - - /// Enable bus-mastering for this device. - pub fn set_master(&self) { - // SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`. - unsafe { bindings::pci_set_master(self.as_raw()) }; - } - /// Returns the size of the given PCI bar resource. pub fn resource_len(&self, bar: u32) -> Result { if !Bar::index_is_valid(bar) { @@ -410,7 +396,7 @@ impl Device { bar: u32, name: &CStr, ) -> Result>> { - let bar = Bar::::new(self.clone(), bar, name)?; + let bar = Bar::::new(self, bar, name)?; let devres = Devres::new(self.as_ref(), bar, GFP_KERNEL)?; Ok(devres) @@ -422,8 +408,60 @@ impl Device { } } +impl Device { + /// Enable memory resources for this device. + pub fn enable_device_mem(&self) -> Result { + // SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`. + to_result(unsafe { bindings::pci_enable_device_mem(self.as_raw()) }) + } + + /// Enable bus-mastering for this device. + pub fn set_master(&self) { + // SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`. + unsafe { bindings::pci_set_master(self.as_raw()) }; + } +} + +impl Deref for Device { + type Target = Device; + + fn deref(&self) -> &Self::Target { + let ptr: *const Self = self; + + // CAST: `Device` is a transparent wrapper of `Opaque`. + let ptr = ptr.cast::(); + + // SAFETY: `ptr` was derived from `&self`. + unsafe { &*ptr } + } +} + +impl From<&Device> for ARef { + fn from(dev: &Device) -> Self { + (&**dev).into() + } +} + +// SAFETY: Instances of `Device` are always reference-counted. +unsafe impl crate::types::AlwaysRefCounted for Device { + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero. + unsafe { bindings::pci_dev_get(self.as_raw()) }; + } + + unsafe fn dec_ref(obj: NonNull) { + // SAFETY: The safety requirements guarantee that the refcount is non-zero. + unsafe { bindings::pci_dev_put(obj.cast().as_ptr()) } + } +} + impl AsRef for Device { fn as_ref(&self) -> &device::Device { - &self.0 + // SAFETY: By the type invariant of `Self`, `self.as_raw()` is a pointer to a valid + // `struct pci_dev`. + let dev = unsafe { addr_of_mut!((*self.as_raw()).dev) }; + + // SAFETY: `dev` points to a valid `struct device`. + unsafe { device::Device::as_ref(dev) } } } diff --git a/samples/rust/rust_driver_pci.rs b/samples/rust/rust_driver_pci.rs index ddc52db71a82a..6cee912b6a66f 100644 --- a/samples/rust/rust_driver_pci.rs +++ b/samples/rust/rust_driver_pci.rs @@ -4,7 +4,7 @@ //! //! To make this driver probe, QEMU must be run with `-device pci-testdev`. -use kernel::{bindings, c_str, devres::Devres, pci, prelude::*}; +use kernel::{bindings, c_str, device::Core, devres::Devres, pci, prelude::*, types::ARef}; struct Regs; @@ -26,7 +26,7 @@ impl TestIndex { } struct SampleDriver { - pdev: pci::Device, + pdev: ARef, bar: Devres, } @@ -62,7 +62,7 @@ impl pci::Driver for SampleDriver { const ID_TABLE: pci::IdTable = &PCI_TABLE; - fn probe(pdev: &mut pci::Device, info: &Self::IdInfo) -> Result>> { + fn probe(pdev: &pci::Device, info: &Self::IdInfo) -> Result>> { dev_dbg!( pdev.as_ref(), "Probe Rust PCI driver sample (PCI ID: 0x{:x}, 0x{:x}).\n", @@ -77,7 +77,7 @@ impl pci::Driver for SampleDriver { let drvdata = KBox::new( Self { - pdev: pdev.clone(), + pdev: pdev.into(), bar, }, GFP_KERNEL,