diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 7f99f75b2b9df..351413dea493c 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -17,7 +17,7 @@ mod llvm_enzyme { use rustc_ast::visit::AssocCtxt::*; use rustc_ast::{ self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind, - MetaItemInner, PatKind, QSelf, TyKind, + MetaItemInner, PatKind, QSelf, TyKind, Visibility, }; use rustc_expand::base::{Annotatable, ExtCtxt}; use rustc_span::{Ident, Span, Symbol, kw, sym}; @@ -72,6 +72,16 @@ mod llvm_enzyme { } } + // Get information about the function the macro is applied to + fn extract_item_info(iitem: &P) -> Option<(Visibility, FnSig, Ident)> { + match &iitem.kind { + ItemKind::Fn(box ast::Fn { sig, ident, .. }) => { + Some((iitem.vis.clone(), sig.clone(), ident.clone())) + } + _ => None, + } + } + pub(crate) fn from_ast( ecx: &mut ExtCtxt<'_>, meta_item: &ThinVec, @@ -199,32 +209,26 @@ mod llvm_enzyme { return vec![item]; } let dcx = ecx.sess.dcx(); - // first get the annotable item: - let (primal, sig, is_impl): (Ident, FnSig, bool) = match &item { - Annotatable::Item(iitem) => { - let (ident, sig) = match &iitem.kind { - ItemKind::Fn(box ast::Fn { ident, sig, .. }) => (ident, sig), - _ => { - dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); - return vec![item]; - } - }; - (*ident, sig.clone(), false) - } + + // first get information about the annotable item: + let Some((vis, sig, primal)) = (match &item { + Annotatable::Item(iitem) => extract_item_info(iitem), + Annotatable::Stmt(stmt) => match &stmt.kind { + ast::StmtKind::Item(iitem) => extract_item_info(iitem), + _ => None, + }, Annotatable::AssocItem(assoc_item, Impl { of_trait: false }) => { - let (ident, sig) = match &assoc_item.kind { - ast::AssocItemKind::Fn(box ast::Fn { ident, sig, .. }) => (ident, sig), - _ => { - dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); - return vec![item]; + match &assoc_item.kind { + ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => { + Some((assoc_item.vis.clone(), sig.clone(), ident.clone())) } - }; - (*ident, sig.clone(), true) - } - _ => { - dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); - return vec![item]; + _ => None, + } } + _ => None, + }) else { + dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + return vec![item]; }; let meta_item_vec: ThinVec = match meta_item.kind { @@ -238,15 +242,6 @@ mod llvm_enzyme { let has_ret = has_ret(&sig.decl.output); let sig_span = ecx.with_call_site_ctxt(sig.span); - let vis = match &item { - Annotatable::Item(iitem) => iitem.vis.clone(), - Annotatable::AssocItem(assoc_item, _) => assoc_item.vis.clone(), - _ => { - dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); - return vec![item]; - } - }; - // create TokenStream from vec elemtents: // meta_item doesn't have a .tokens field let mut ts: Vec = vec![]; @@ -379,6 +374,22 @@ mod llvm_enzyme { } Annotatable::AssocItem(assoc_item.clone(), i) } + Annotatable::Stmt(ref mut stmt) => { + match stmt.kind { + ast::StmtKind::Item(ref mut iitem) => { + if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) { + iitem.attrs.push(attr); + } + if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) + { + iitem.attrs.push(inline_never.clone()); + } + } + _ => unreachable!("stmt kind checked previously"), + }; + + Annotatable::Stmt(stmt.clone()) + } _ => { unreachable!("annotatable kind checked previously") } @@ -389,22 +400,40 @@ mod llvm_enzyme { delim: rustc_ast::token::Delimiter::Parenthesis, tokens: ts, }); + let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span); - let d_annotatable = if is_impl { - let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf); - let d_fn = P(ast::AssocItem { - attrs: thin_vec![d_attr, inline_never], - id: ast::DUMMY_NODE_ID, - span, - vis, - kind: assoc_item, - tokens: None, - }); - Annotatable::AssocItem(d_fn, Impl { of_trait: false }) - } else { - let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf)); - d_fn.vis = vis; - Annotatable::Item(d_fn) + let d_annotatable = match &item { + Annotatable::AssocItem(_, _) => { + let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf); + let d_fn = P(ast::AssocItem { + attrs: thin_vec![d_attr, inline_never], + id: ast::DUMMY_NODE_ID, + span, + vis, + kind: assoc_item, + tokens: None, + }); + Annotatable::AssocItem(d_fn, Impl { of_trait: false }) + } + Annotatable::Item(_) => { + let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf)); + d_fn.vis = vis; + + Annotatable::Item(d_fn) + } + Annotatable::Stmt(_) => { + let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf)); + d_fn.vis = vis; + + Annotatable::Stmt(P(ast::Stmt { + id: ast::DUMMY_NODE_ID, + kind: ast::StmtKind::Item(d_fn), + span, + })) + } + _ => { + unreachable!("item kind checked previously") + } }; return vec![orig_annotatable, d_annotatable]; diff --git a/tests/pretty/autodiff_forward.pp b/tests/pretty/autodiff_forward.pp index 4b2fb6166ff7e..713b8f541ae0d 100644 --- a/tests/pretty/autodiff_forward.pp +++ b/tests/pretty/autodiff_forward.pp @@ -29,6 +29,8 @@ // Make sure, that we add the None for the default return. + // We want to make sure that we can use the macro for functions defined inside of functions + ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Dual, Const, Dual)] @@ -158,4 +160,25 @@ ::core::hint::black_box((bx_0,)); ::core::hint::black_box(::default()) } +pub fn f9() { + #[rustc_autodiff] + #[inline(never)] + fn inner(x: f32) -> f32 { x * x } + #[rustc_autodiff(Forward, 1, Dual, Dual)] + #[inline(never)] + fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) { + unsafe { asm!("NOP", options(pure, nomem)); }; + ::core::hint::black_box(inner(x)); + ::core::hint::black_box((bx_0,)); + ::core::hint::black_box(<(f32, f32)>::default()) + } + #[rustc_autodiff(Forward, 1, Dual, DualOnly)] + #[inline(never)] + fn d_inner_1(x: f32, bx_0: f32) -> f32 { + unsafe { asm!("NOP", options(pure, nomem)); }; + ::core::hint::black_box(inner(x)); + ::core::hint::black_box((bx_0,)); + ::core::hint::black_box(::default()) + } +} fn main() {} diff --git a/tests/pretty/autodiff_forward.rs b/tests/pretty/autodiff_forward.rs index a765738c2a815..5a0660a08e526 100644 --- a/tests/pretty/autodiff_forward.rs +++ b/tests/pretty/autodiff_forward.rs @@ -54,4 +54,13 @@ fn f8(x: &f32) -> f32 { unimplemented!() } +// We want to make sure that we can use the macro for functions defined inside of functions +pub fn f9() { + #[autodiff(d_inner_1, Forward, Dual, DualOnly)] + #[autodiff(d_inner_2, Forward, Dual, Dual)] + fn inner(x: f32) -> f32 { + x * x + } +} + fn main() {}