]> git.ipfire.org Git - thirdparty/linux.git/commitdiff
rust: macros: convert `#[vtable]` macro to use `syn`
authorGary Guo <gary@garyguo.net>
Mon, 12 Jan 2026 17:07:14 +0000 (17:07 +0000)
committerMiguel Ojeda <ojeda@kernel.org>
Tue, 27 Jan 2026 23:55:24 +0000 (00:55 +0100)
`#[vtable]` is converted to use syn. This is more robust than the
previous heuristic-based searching of defined methods and functions.

When doing so, the trait and impl are split into two code paths as the
types are distinct when parsed by `syn`.

Reviewed-by: Tamir Duberstein <tamird@gmail.com>
Signed-off-by: Gary Guo <gary@garyguo.net>
Reviewed-by: Benno Lossin <lossin@kernel.org>
Link: https://patch.msgid.link/20260112170919.1888584-4-gary@kernel.org
Signed-off-by: Miguel Ojeda <ojeda@kernel.org>
rust/macros/lib.rs
rust/macros/vtable.rs

index b884ea17391b07b0eec5f1863e2b43b6573ce92f..0ecbb2e16da3cbf0adc99a86c5b06fd8846d8309 100644 (file)
@@ -22,6 +22,8 @@ mod vtable;
 
 use proc_macro::TokenStream;
 
+use syn::parse_macro_input;
+
 /// Declares a kernel module.
 ///
 /// The `type` argument should be a type which implements the [`Module`]
@@ -204,8 +206,11 @@ pub fn module(ts: TokenStream) -> TokenStream {
 ///
 /// [`kernel::error::VTABLE_DEFAULT_ERROR`]: ../kernel/error/constant.VTABLE_DEFAULT_ERROR.html
 #[proc_macro_attribute]
-pub fn vtable(attr: TokenStream, ts: TokenStream) -> TokenStream {
-    vtable::vtable(attr.into(), ts.into()).into()
+pub fn vtable(attr: TokenStream, input: TokenStream) -> TokenStream {
+    parse_macro_input!(attr as syn::parse::Nothing);
+    vtable::vtable(parse_macro_input!(input))
+        .unwrap_or_else(|e| e.into_compile_error())
+        .into()
 }
 
 /// Export a function so that C code can call it via a header file.
index a67d1cc81a2d3673cee63820037fbd03b9db4ea7..72ae0a1816a045d3f598a5b81418eb7dd974b691 100644 (file)
 // SPDX-License-Identifier: GPL-2.0
 
-use std::collections::HashSet;
-use std::fmt::Write;
+use std::{
+    collections::HashSet,
+    iter::Extend, //
+};
 
-use proc_macro2::{Delimiter, Group, TokenStream, TokenTree};
+use proc_macro2::{
+    Ident,
+    TokenStream, //
+};
+use quote::ToTokens;
+use syn::{
+    parse_quote,
+    Error,
+    ImplItem,
+    Item,
+    ItemImpl,
+    ItemTrait,
+    Result,
+    TraitItem, //
+};
 
-pub(crate) fn vtable(_attr: TokenStream, ts: TokenStream) -> TokenStream {
-    let mut tokens: Vec<_> = ts.into_iter().collect();
+fn handle_trait(mut item: ItemTrait) -> Result<ItemTrait> {
+    let mut gen_items = Vec::new();
+    let mut gen_consts = HashSet::new();
 
-    // Scan for the `trait` or `impl` keyword.
-    let is_trait = tokens
-        .iter()
-        .find_map(|token| match token {
-            TokenTree::Ident(ident) => match ident.to_string().as_str() {
-                "trait" => Some(true),
-                "impl" => Some(false),
-                _ => None,
-            },
-            _ => None,
-        })
-        .expect("#[vtable] attribute should only be applied to trait or impl block");
+    gen_items.push(parse_quote! {
+         /// A marker to prevent implementors from forgetting to use [`#[vtable]`](vtable)
+         /// attribute when implementing this trait.
+         const USE_VTABLE_ATTR: ();
+    });
 
-    // Retrieve the main body. The main body should be the last token tree.
-    let body = match tokens.pop() {
-        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,
-        _ => panic!("cannot locate main body of trait or impl block"),
-    };
-
-    let mut body_it = body.stream().into_iter();
-    let mut functions = Vec::new();
-    let mut consts = HashSet::new();
-    while let Some(token) = body_it.next() {
-        match token {
-            TokenTree::Ident(ident) if ident == "fn" => {
-                let fn_name = match body_it.next() {
-                    Some(TokenTree::Ident(ident)) => ident.to_string(),
-                    // Possibly we've encountered a fn pointer type instead.
-                    _ => continue,
-                };
-                functions.push(fn_name);
-            }
-            TokenTree::Ident(ident) if ident == "const" => {
-                let const_name = match body_it.next() {
-                    Some(TokenTree::Ident(ident)) => ident.to_string(),
-                    // Possibly we've encountered an inline const block instead.
-                    _ => continue,
-                };
-                consts.insert(const_name);
+    for item in &item.items {
+        if let TraitItem::Fn(fn_item) = item {
+            let name = &fn_item.sig.ident;
+            let gen_const_name = Ident::new(
+                &format!("HAS_{}", name.to_string().to_uppercase()),
+                name.span(),
+            );
+            // Skip if it's declared already -- this can happen if `#[cfg]` is used to selectively
+            // define functions.
+            // FIXME: `#[cfg]` should be copied and propagated to the generated consts.
+            if gen_consts.contains(&gen_const_name) {
+                continue;
             }
-            _ => (),
+
+            // We don't know on the implementation-site whether a method is required or provided
+            // so we have to generate a const for all methods.
+            let comment =
+                format!("Indicates if the `{name}` method is overridden by the implementor.");
+            gen_items.push(parse_quote! {
+                #[doc = #comment]
+                const #gen_const_name: bool = false;
+            });
+            gen_consts.insert(gen_const_name);
         }
     }
 
-    let mut const_items;
-    if is_trait {
-        const_items = "
-                /// A marker to prevent implementors from forgetting to use [`#[vtable]`](vtable)
-                /// attribute when implementing this trait.
-                const USE_VTABLE_ATTR: ();
-        "
-        .to_owned();
+    item.items.extend(gen_items);
+    Ok(item)
+}
 
-        for f in functions {
-            let gen_const_name = format!("HAS_{}", f.to_uppercase());
-            // Skip if it's declared already -- this allows user override.
-            if consts.contains(&gen_const_name) {
-                continue;
-            }
-            // We don't know on the implementation-site whether a method is required or provided
-            // so we have to generate a const for all methods.
-            write!(
-                const_items,
-                "/// Indicates if the `{f}` method is overridden by the implementor.
-                const {gen_const_name}: bool = false;",
-            )
-            .unwrap();
-            consts.insert(gen_const_name);
+fn handle_impl(mut item: ItemImpl) -> Result<ItemImpl> {
+    let mut gen_items = Vec::new();
+    let mut defined_consts = HashSet::new();
+
+    // Iterate over all user-defined constants to gather any possible explicit overrides.
+    for item in &item.items {
+        if let ImplItem::Const(const_item) = item {
+            defined_consts.insert(const_item.ident.clone());
         }
-    } else {
-        const_items = "const USE_VTABLE_ATTR: () = ();".to_owned();
+    }
+
+    gen_items.push(parse_quote! {
+        const USE_VTABLE_ATTR: () = ();
+    });
 
-        for f in functions {
-            let gen_const_name = format!("HAS_{}", f.to_uppercase());
-            if consts.contains(&gen_const_name) {
+    for item in &item.items {
+        if let ImplItem::Fn(fn_item) = item {
+            let name = &fn_item.sig.ident;
+            let gen_const_name = Ident::new(
+                &format!("HAS_{}", name.to_string().to_uppercase()),
+                name.span(),
+            );
+            // Skip if it's declared already -- this allows user override.
+            if defined_consts.contains(&gen_const_name) {
                 continue;
             }
-            write!(const_items, "const {gen_const_name}: bool = true;").unwrap();
+            gen_items.push(parse_quote! {
+                const #gen_const_name: bool = true;
+            });
+            defined_consts.insert(gen_const_name);
         }
     }
 
-    let new_body = vec![const_items.parse().unwrap(), body.stream()]
-        .into_iter()
-        .collect();
-    tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, new_body)));
-    tokens.into_iter().collect()
+    item.items.extend(gen_items);
+    Ok(item)
+}
+
+pub(crate) fn vtable(input: Item) -> Result<TokenStream> {
+    match input {
+        Item::Trait(item) => Ok(handle_trait(item)?.into_token_stream()),
+        Item::Impl(item) => Ok(handle_impl(item)?.into_token_stream()),
+        _ => Err(Error::new_spanned(
+            input,
+            "`#[vtable]` attribute should only be applied to trait or impl block",
+        ))?,
+    }
 }