Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix usage of autodiff macro with inner functions #138314

Merged
merged 4 commits into from
Apr 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 77 additions & 48 deletions compiler/rustc_builtin_macros/src/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ mod llvm_enzyme {
use rustc_ast::visit::AssocCtxt::*;
use rustc_ast::{
self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
MetaItemInner, PatKind, QSelf, TyKind,
MetaItemInner, PatKind, QSelf, TyKind, Visibility,
};
use rustc_expand::base::{Annotatable, ExtCtxt};
use rustc_span::{Ident, Span, Symbol, kw, sym};
Expand Down Expand Up @@ -72,6 +72,16 @@ mod llvm_enzyme {
}
}

// Get information about the function the macro is applied to
fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident)> {
match &iitem.kind {
ItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
Some((iitem.vis.clone(), sig.clone(), ident.clone()))
}
_ => None,
}
}

pub(crate) fn from_ast(
ecx: &mut ExtCtxt<'_>,
meta_item: &ThinVec<MetaItemInner>,
Expand Down Expand Up @@ -199,32 +209,26 @@ mod llvm_enzyme {
return vec![item];
}
let dcx = ecx.sess.dcx();
// first get the annotable item:
let (primal, sig, is_impl): (Ident, FnSig, bool) = match &item {
Annotatable::Item(iitem) => {
let (ident, sig) = match &iitem.kind {
ItemKind::Fn(box ast::Fn { ident, sig, .. }) => (ident, sig),
_ => {
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
return vec![item];
}
};
(*ident, sig.clone(), false)
}

// first get information about the annotable item:
let Some((vis, sig, primal)) = (match &item {
Annotatable::Item(iitem) => extract_item_info(iitem),
Annotatable::Stmt(stmt) => match &stmt.kind {
ast::StmtKind::Item(iitem) => extract_item_info(iitem),
_ => None,
},
Annotatable::AssocItem(assoc_item, Impl { of_trait: false }) => {
let (ident, sig) = match &assoc_item.kind {
ast::AssocItemKind::Fn(box ast::Fn { ident, sig, .. }) => (ident, sig),
_ => {
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
return vec![item];
match &assoc_item.kind {
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
Some((assoc_item.vis.clone(), sig.clone(), ident.clone()))
}
};
(*ident, sig.clone(), true)
}
_ => {
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
return vec![item];
_ => None,
}
}
_ => None,
}) else {
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
return vec![item];
};

let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
Expand All @@ -238,15 +242,6 @@ mod llvm_enzyme {
let has_ret = has_ret(&sig.decl.output);
let sig_span = ecx.with_call_site_ctxt(sig.span);

let vis = match &item {
Annotatable::Item(iitem) => iitem.vis.clone(),
Annotatable::AssocItem(assoc_item, _) => assoc_item.vis.clone(),
_ => {
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
return vec![item];
}
};

// create TokenStream from vec elemtents:
// meta_item doesn't have a .tokens field
let mut ts: Vec<TokenTree> = vec![];
Expand Down Expand Up @@ -379,6 +374,22 @@ mod llvm_enzyme {
}
Annotatable::AssocItem(assoc_item.clone(), i)
}
Annotatable::Stmt(ref mut stmt) => {
match stmt.kind {
ast::StmtKind::Item(ref mut iitem) => {
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
iitem.attrs.push(attr);
}
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind))
{
iitem.attrs.push(inline_never.clone());
}
}
_ => unreachable!("stmt kind checked previously"),
};

Annotatable::Stmt(stmt.clone())
}
_ => {
unreachable!("annotatable kind checked previously")
}
Expand All @@ -389,22 +400,40 @@ mod llvm_enzyme {
delim: rustc_ast::token::Delimiter::Parenthesis,
tokens: ts,
});

let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
let d_annotatable = if is_impl {
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
let d_fn = P(ast::AssocItem {
attrs: thin_vec![d_attr, inline_never],
id: ast::DUMMY_NODE_ID,
span,
vis,
kind: assoc_item,
tokens: None,
});
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
} else {
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
d_fn.vis = vis;
Annotatable::Item(d_fn)
let d_annotatable = match &item {
Annotatable::AssocItem(_, _) => {
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
let d_fn = P(ast::AssocItem {
attrs: thin_vec![d_attr, inline_never],
id: ast::DUMMY_NODE_ID,
span,
vis,
kind: assoc_item,
tokens: None,
});
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
}
Annotatable::Item(_) => {
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
d_fn.vis = vis;

Annotatable::Item(d_fn)
}
Annotatable::Stmt(_) => {
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
d_fn.vis = vis;

Annotatable::Stmt(P(ast::Stmt {
id: ast::DUMMY_NODE_ID,
kind: ast::StmtKind::Item(d_fn),
span,
}))
}
_ => {
unreachable!("item kind checked previously")
}
};

return vec![orig_annotatable, d_annotatable];
Expand Down
23 changes: 23 additions & 0 deletions tests/pretty/autodiff_forward.pp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
// Make sure, that we add the None for the default return.


// We want to make sure that we can use the macro for functions defined inside of functions

::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
Expand Down Expand Up @@ -158,4 +160,25 @@
::core::hint::black_box((bx_0,));
::core::hint::black_box(<f32>::default())
}
pub fn f9() {
#[rustc_autodiff]
#[inline(never)]
fn inner(x: f32) -> f32 { x * x }
#[rustc_autodiff(Forward, 1, Dual, Dual)]
#[inline(never)]
fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(inner(x));
::core::hint::black_box((bx_0,));
::core::hint::black_box(<(f32, f32)>::default())
}
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
#[inline(never)]
fn d_inner_1(x: f32, bx_0: f32) -> f32 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(inner(x));
::core::hint::black_box((bx_0,));
::core::hint::black_box(<f32>::default())
}
}
fn main() {}
9 changes: 9 additions & 0 deletions tests/pretty/autodiff_forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,13 @@ fn f8(x: &f32) -> f32 {
unimplemented!()
}

// We want to make sure that we can use the macro for functions defined inside of functions
pub fn f9() {
#[autodiff(d_inner_1, Forward, Dual, DualOnly)]
#[autodiff(d_inner_2, Forward, Dual, Dual)]
fn inner(x: f32) -> f32 {
x * x
}
}

fn main() {}
Loading