]> git.ipfire.org Git - thirdparty/kernel/linux.git/commitdiff
poll: rust: allow poll_table ptrs to be null
authorAlice Ryhl <aliceryhl@google.com>
Mon, 23 Jun 2025 13:57:27 +0000 (13:57 +0000)
committerChristian Brauner <brauner@kernel.org>
Mon, 14 Jul 2025 12:12:24 +0000 (14:12 +0200)
It's possible for a poll_table to be null. This can happen if an
end-user just wants to know if a resource has events right now without
registering a waiter for when events become available. Furthermore,
these null pointers should be handled transparently by the API, so we
should not change `from_ptr` to return an `Option`. Thus, change
`PollTable` to wrap a raw pointer rather than use a reference so that
you can pass null.

Comments mentioning `struct poll_table` are changed to just `poll_table`
since `poll_table` is a typedef. (It's a typedef because it's supposed
to be opaque.)

Reviewed-by: Benno Lossin <lossin@kernel.org>
Signed-off-by: Alice Ryhl <aliceryhl@google.com>
rust/helpers/helpers.c
rust/helpers/poll.c [new file with mode: 0644]
rust/kernel/sync/poll.rs

index 0f1b5d11598591bc62bb6439747211af164b76d6..ff65073fe3a5220e7792306fc9ae74c416935a52 100644 (file)
@@ -30,6 +30,7 @@
 #include "platform.c"
 #include "pci.c"
 #include "pid_namespace.c"
+#include "poll.c"
 #include "rbtree.c"
 #include "rcu.c"
 #include "refcount.c"
diff --git a/rust/helpers/poll.c b/rust/helpers/poll.c
new file mode 100644 (file)
index 0000000..7e5b175
--- /dev/null
@@ -0,0 +1,10 @@
+// SPDX-License-Identifier: GPL-2.0
+
+#include <linux/export.h>
+#include <linux/poll.h>
+
+void rust_helper_poll_wait(struct file *filp, wait_queue_head_t *wait_address,
+                          poll_table *p)
+{
+       poll_wait(filp, wait_address, p);
+}
index d7e6e59e124b09d5f0d62244ca46089e0748d544..69f1368a2151c489bdb95705d1ddac1c79bf566b 100644 (file)
@@ -9,9 +9,8 @@ use crate::{
     fs::File,
     prelude::*,
     sync::{CondVar, LockClassKey},
-    types::Opaque,
 };
-use core::ops::Deref;
+use core::{marker::PhantomData, ops::Deref};
 
 /// Creates a [`PollCondVar`] initialiser with the given name and a newly-created lock class.
 #[macro_export]
@@ -23,58 +22,43 @@ macro_rules! new_poll_condvar {
     };
 }
 
-/// Wraps the kernel's `struct poll_table`.
+/// Wraps the kernel's `poll_table`.
 ///
 /// # Invariants
 ///
-/// This struct contains a valid `struct poll_table`.
-///
-/// For a `struct poll_table` to be valid, its `_qproc` function must follow the safety
-/// requirements of `_qproc` functions:
-///
-/// * The `_qproc` function is given permission to enqueue a waiter to the provided `poll_table`
-///   during the call. Once the waiter is removed and an rcu grace period has passed, it must no
-///   longer access the `wait_queue_head`.
+/// The pointer must be null or reference a valid `poll_table`.
 #[repr(transparent)]
-pub struct PollTable(Opaque<bindings::poll_table>);
+pub struct PollTable<'a> {
+    table: *mut bindings::poll_table,
+    _lifetime: PhantomData<&'a bindings::poll_table>,
+}
 
-impl PollTable {
-    /// Creates a reference to a [`PollTable`] from a valid pointer.
+impl<'a> PollTable<'a> {
+    /// Creates a [`PollTable`] from a valid pointer.
     ///
     /// # Safety
     ///
-    /// The caller must ensure that for the duration of `'a`, the pointer will point at a valid poll
-    /// table (as defined in the type invariants).
-    ///
-    /// The caller must also ensure that the `poll_table` is only accessed via the returned
-    /// reference for the duration of `'a`.
-    pub unsafe fn from_ptr<'a>(ptr: *mut bindings::poll_table) -> &'a mut PollTable {
-        // SAFETY: The safety requirements guarantee the validity of the dereference, while the
-        // `PollTable` type being transparent makes the cast ok.
-        unsafe { &mut *ptr.cast() }
-    }
-
-    fn get_qproc(&self) -> bindings::poll_queue_proc {
-        let ptr = self.0.get();
-        // SAFETY: The `ptr` is valid because it originates from a reference, and the `_qproc`
-        // field is not modified concurrently with this call since we have an immutable reference.
-        unsafe { (*ptr)._qproc }
+    /// The pointer must be null or reference a valid `poll_table` for the duration of `'a`.
+    pub unsafe fn from_raw(table: *mut bindings::poll_table) -> Self {
+        // INVARIANTS: The safety requirements are the same as the struct invariants.
+        PollTable {
+            table,
+            _lifetime: PhantomData,
+        }
     }
 
     /// Register this [`PollTable`] with the provided [`PollCondVar`], so that it can be notified
     /// using the condition variable.
-    pub fn register_wait(&mut self, file: &File, cv: &PollCondVar) {
-        if let Some(qproc) = self.get_qproc() {
-            // SAFETY: The pointers to `file` and `self` need to be valid for the duration of this
-            // call to `qproc`, which they are because they are references.
-            //
-            // The `cv.wait_queue_head` pointer must be valid until an rcu grace period after the
-            // waiter is removed. The `PollCondVar` is pinned, so before `cv.wait_queue_head` can
-            // be destroyed, the destructor must run. That destructor first removes all waiters,
-            // and then waits for an rcu grace period. Therefore, `cv.wait_queue_head` is valid for
-            // long enough.
-            unsafe { qproc(file.as_ptr() as _, cv.wait_queue_head.get(), self.0.get()) };
-        }
+    pub fn register_wait(&self, file: &File, cv: &PollCondVar) {
+        // SAFETY:
+        // * `file.as_ptr()` references a valid file for the duration of this call.
+        // * `self.table` is null or references a valid poll_table for the duration of this call.
+        // * Since `PollCondVar` is pinned, its destructor is guaranteed to run before the memory
+        //   containing `cv.wait_queue_head` is invalidated. Since the destructor clears all
+        //   waiters and then waits for an rcu grace period, it's guaranteed that
+        //   `cv.wait_queue_head` remains valid for at least an rcu grace period after the removal
+        //   of the last waiter.
+        unsafe { bindings::poll_wait(file.as_ptr(), cv.wait_queue_head.get(), self.table) }
     }
 }