rustc_builtin_macros/
autodiff.rs

1//! This module contains the implementation of the `#[autodiff]` attribute.
2//! Currently our linter isn't smart enough to see that each import is used in one of the two
3//! configs (autodiff enabled or disabled), so we have to add cfg's to each import.
4//! FIXME(ZuseZ4): Remove this once we have a smarter linter.
5
6mod llvm_enzyme {
7    use std::str::FromStr;
8    use std::string::String;
9
10    use rustc_ast::expand::autodiff_attrs::{
11        AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
12        valid_ty_for_activity,
13    };
14    use rustc_ast::token::{Lit, LitKind, Token, TokenKind};
15    use rustc_ast::tokenstream::*;
16    use rustc_ast::visit::AssocCtxt::*;
17    use rustc_ast::{
18        self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
19        MetaItemInner, PatKind, QSelf, TyKind, Visibility,
20    };
21    use rustc_expand::base::{Annotatable, ExtCtxt};
22    use rustc_span::{Ident, Span, Symbol, kw, sym};
23    use thin_vec::{ThinVec, thin_vec};
24    use tracing::{debug, trace};
25
26    use crate::errors;
27
28    pub(crate) fn outer_normal_attr(
29        kind: &Box<rustc_ast::NormalAttr>,
30        id: rustc_ast::AttrId,
31        span: Span,
32    ) -> rustc_ast::Attribute {
33        let style = rustc_ast::AttrStyle::Outer;
34        let kind = rustc_ast::AttrKind::Normal(kind.clone());
35        rustc_ast::Attribute { kind, id, style, span }
36    }
37
38    // If we have a default `()` return type or explicitley `()` return type,
39    // then we often can skip doing some work.
40    fn has_ret(ty: &FnRetTy) -> bool {
41        match ty {
42            FnRetTy::Ty(ty) => !ty.kind.is_unit(),
43            FnRetTy::Default(_) => false,
44        }
45    }
46    fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {
47        if let Some(l) = x.lit() {
48            match l.kind {
49                ast::LitKind::Int(val, _) => {
50                    // get an Ident from a lit
51                    return rustc_span::Ident::from_str(val.get().to_string().as_str());
52                }
53                _ => {}
54            }
55        }
56
57        let segments = &x.meta_item().unwrap().path.segments;
58        assert!(segments.len() == 1);
59        segments[0].ident
60    }
61
62    fn name(x: &MetaItemInner) -> String {
63        first_ident(x).name.to_string()
64    }
65
66    fn width(x: &MetaItemInner) -> Option<u128> {
67        let lit = x.lit()?;
68        match lit.kind {
69            ast::LitKind::Int(x, _) => Some(x.get()),
70            _ => return None,
71        }
72    }
73
74    // Get information about the function the macro is applied to
75    fn extract_item_info(iitem: &Box<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {
76        match &iitem.kind {
77            ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
78                Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
79            }
80            _ => None,
81        }
82    }
83
84    pub(crate) fn from_ast(
85        ecx: &mut ExtCtxt<'_>,
86        meta_item: &ThinVec<MetaItemInner>,
87        has_ret: bool,
88        mode: DiffMode,
89    ) -> AutoDiffAttrs {
90        let dcx = ecx.sess.dcx();
91
92        // Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
93        // If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
94        let mut first_activity = 1;
95
96        let width = if let [_, x, ..] = &meta_item[..]
97            && let Some(x) = width(x)
98        {
99            first_activity = 2;
100            match x.try_into() {
101                Ok(x) => x,
102                Err(_) => {
103                    dcx.emit_err(errors::AutoDiffInvalidWidth {
104                        span: meta_item[1].span(),
105                        width: x,
106                    });
107                    return AutoDiffAttrs::error();
108                }
109            }
110        } else {
111            1
112        };
113
114        let mut activities: Vec<DiffActivity> = vec![];
115        let mut errors = false;
116        for x in &meta_item[first_activity..] {
117            let activity_str = name(&x);
118            let res = DiffActivity::from_str(&activity_str);
119            match res {
120                Ok(x) => activities.push(x),
121                Err(_) => {
122                    dcx.emit_err(errors::AutoDiffUnknownActivity {
123                        span: x.span(),
124                        act: activity_str,
125                    });
126                    errors = true;
127                }
128            };
129        }
130        if errors {
131            return AutoDiffAttrs::error();
132        }
133
134        // If a return type exist, we need to split the last activity,
135        // otherwise we return None as placeholder.
136        let (ret_activity, input_activity) = if has_ret {
137            let Some((last, rest)) = activities.split_last() else {
138                unreachable!(
139                    "should not be reachable because we counted the number of activities previously"
140                );
141            };
142            (last, rest)
143        } else {
144            (&DiffActivity::None, activities.as_slice())
145        };
146
147        AutoDiffAttrs {
148            mode,
149            width,
150            ret_activity: *ret_activity,
151            input_activity: input_activity.to_vec(),
152        }
153    }
154
155    fn meta_item_inner_to_ts(t: &MetaItemInner, ts: &mut Vec<TokenTree>) {
156        let comma: Token = Token::new(TokenKind::Comma, Span::default());
157        let val = first_ident(t);
158        let t = Token::from_ast_ident(val);
159        ts.push(TokenTree::Token(t, Spacing::Joint));
160        ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
161    }
162
163    pub(crate) fn expand_forward(
164        ecx: &mut ExtCtxt<'_>,
165        expand_span: Span,
166        meta_item: &ast::MetaItem,
167        item: Annotatable,
168    ) -> Vec<Annotatable> {
169        expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward)
170    }
171
172    pub(crate) fn expand_reverse(
173        ecx: &mut ExtCtxt<'_>,
174        expand_span: Span,
175        meta_item: &ast::MetaItem,
176        item: Annotatable,
177    ) -> Vec<Annotatable> {
178        expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse)
179    }
180
181    /// We expand the autodiff macro to generate a new placeholder function which passes
182    /// type-checking and can be called by users. The function body of the placeholder function will
183    /// later be replaced on LLVM-IR level, so the design of the body is less important and for now
184    /// should just prevent early inlining and optimizations which alter the function signature.
185    /// The exact signature of the generated function depends on the configuration provided by the
186    /// user, but here is an example:
187    ///
188    /// ```
189    /// #[autodiff(cos_box, Reverse, Duplicated, Active)]
190    /// fn sin(x: &Box<f32>) -> f32 {
191    ///     f32::sin(**x)
192    /// }
193    /// ```
194    /// which becomes expanded to:
195    /// ```
196    /// #[rustc_autodiff]
197    /// #[inline(never)]
198    /// fn sin(x: &Box<f32>) -> f32 {
199    ///     f32::sin(**x)
200    /// }
201    /// #[rustc_autodiff(Reverse, Duplicated, Active)]
202    /// #[inline(never)]
203    /// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
204    ///     unsafe {
205    ///         asm!("NOP");
206    ///     };
207    ///     ::core::hint::black_box(sin(x));
208    ///     ::core::hint::black_box((dx, dret));
209    ///     ::core::hint::black_box(sin(x))
210    /// }
211    /// ```
212    /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
213    /// in CI.
214    pub(crate) fn expand_with_mode(
215        ecx: &mut ExtCtxt<'_>,
216        expand_span: Span,
217        meta_item: &ast::MetaItem,
218        mut item: Annotatable,
219        mode: DiffMode,
220    ) -> Vec<Annotatable> {
221        if cfg!(not(llvm_enzyme)) {
222            ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
223            return vec![item];
224        }
225        let dcx = ecx.sess.dcx();
226
227        // first get information about the annotable item: visibility, signature, name and generic
228        // parameters.
229        // these will be used to generate the differentiated version of the function
230        let Some((vis, sig, primal, generics)) = (match &item {
231            Annotatable::Item(iitem) => extract_item_info(iitem),
232            Annotatable::Stmt(stmt) => match &stmt.kind {
233                ast::StmtKind::Item(iitem) => extract_item_info(iitem),
234                _ => None,
235            },
236            Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind {
237                ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
238                    Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
239                }
240                _ => None,
241            },
242            _ => None,
243        }) else {
244            dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
245            return vec![item];
246        };
247
248        let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
249            ast::MetaItemKind::List(ref vec) => vec.clone(),
250            _ => {
251                dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
252                return vec![item];
253            }
254        };
255
256        let has_ret = has_ret(&sig.decl.output);
257        let sig_span = ecx.with_call_site_ctxt(sig.span);
258
259        // create TokenStream from vec elemtents:
260        // meta_item doesn't have a .tokens field
261        let mut ts: Vec<TokenTree> = vec![];
262        if meta_item_vec.len() < 1 {
263            // At the bare minimum, we need a fnc name.
264            dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
265            return vec![item];
266        }
267
268        let mode_symbol = match mode {
269            DiffMode::Forward => sym::Forward,
270            DiffMode::Reverse => sym::Reverse,
271            _ => unreachable!("Unsupported mode: {:?}", mode),
272        };
273
274        // Insert mode token
275        let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());
276        ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));
277        ts.insert(
278            1,
279            TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),
280        );
281
282        // Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
283        // If it is not given, we default to 1 (scalar mode).
284        let start_position;
285        let kind: LitKind = LitKind::Integer;
286        let symbol;
287        if meta_item_vec.len() >= 2
288            && let Some(width) = width(&meta_item_vec[1])
289        {
290            start_position = 2;
291            symbol = Symbol::intern(&width.to_string());
292        } else {
293            start_position = 1;
294            symbol = sym::integer(1);
295        }
296
297        let l: Lit = Lit { kind, symbol, suffix: None };
298        let t = Token::new(TokenKind::Literal(l), Span::default());
299        let comma = Token::new(TokenKind::Comma, Span::default());
300        ts.push(TokenTree::Token(t, Spacing::Joint));
301        ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
302
303        for t in meta_item_vec.clone()[start_position..].iter() {
304            meta_item_inner_to_ts(t, &mut ts);
305        }
306
307        if !has_ret {
308            // We don't want users to provide a return activity if the function doesn't return anything.
309            // For simplicity, we just add a dummy token to the end of the list.
310            let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
311            ts.push(TokenTree::Token(t, Spacing::Joint));
312            ts.push(TokenTree::Token(comma, Spacing::Alone));
313        }
314        // We remove the last, trailing comma.
315        ts.pop();
316        let ts: TokenStream = TokenStream::from_iter(ts);
317
318        let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret, mode);
319        if !x.is_active() {
320            // We encountered an error, so we return the original item.
321            // This allows us to potentially parse other attributes.
322            return vec![item];
323        }
324        let span = ecx.with_def_site_ctxt(expand_span);
325
326        let n_active: u32 = x
327            .input_activity
328            .iter()
329            .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
330            .count() as u32;
331        let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
332        let d_body = gen_enzyme_body(
333            ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
334            &generics,
335        );
336
337        // The first element of it is the name of the function to be generated
338        let asdf = Box::new(ast::Fn {
339            defaultness: ast::Defaultness::Final,
340            sig: d_sig,
341            ident: first_ident(&meta_item_vec[0]),
342            generics,
343            contract: None,
344            body: Some(d_body),
345            define_opaque: None,
346        });
347        let mut rustc_ad_attr =
348            Box::new(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
349
350        let ts2: Vec<TokenTree> = vec![TokenTree::Token(
351            Token::new(TokenKind::Ident(sym::never, false.into()), span),
352            Spacing::Joint,
353        )];
354        let never_arg = ast::DelimArgs {
355            dspan: DelimSpan::from_single(span),
356            delim: ast::token::Delimiter::Parenthesis,
357            tokens: TokenStream::from_iter(ts2),
358        };
359        let inline_item = ast::AttrItem {
360            unsafety: ast::Safety::Default,
361            path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),
362            args: ast::AttrArgs::Delimited(never_arg),
363            tokens: None,
364        };
365        let inline_never_attr = Box::new(ast::NormalAttr { item: inline_item, tokens: None });
366        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
367        let attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
368        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
369        let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);
370
371        // We're avoid duplicating the attributes `#[rustc_autodiff]` and `#[inline(never)]`.
372        fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
373            match (attr, item) {
374                (ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {
375                    let a = &a.item.path;
376                    let b = &b.item.path;
377                    a.segments.len() == b.segments.len()
378                        && a.segments.iter().zip(b.segments.iter()).all(|(a, b)| a.ident == b.ident)
379                }
380                _ => false,
381            }
382        }
383
384        // Don't add it multiple times:
385        let orig_annotatable: Annotatable = match item {
386            Annotatable::Item(ref mut iitem) => {
387                if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
388                    iitem.attrs.push(attr);
389                }
390                if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
391                    iitem.attrs.push(inline_never.clone());
392                }
393                Annotatable::Item(iitem.clone())
394            }
395            Annotatable::AssocItem(ref mut assoc_item, i @ Impl { .. }) => {
396                if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
397                    assoc_item.attrs.push(attr);
398                }
399                if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
400                    assoc_item.attrs.push(inline_never.clone());
401                }
402                Annotatable::AssocItem(assoc_item.clone(), i)
403            }
404            Annotatable::Stmt(ref mut stmt) => {
405                match stmt.kind {
406                    ast::StmtKind::Item(ref mut iitem) => {
407                        if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
408                            iitem.attrs.push(attr);
409                        }
410                        if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind))
411                        {
412                            iitem.attrs.push(inline_never.clone());
413                        }
414                    }
415                    _ => unreachable!("stmt kind checked previously"),
416                };
417
418                Annotatable::Stmt(stmt.clone())
419            }
420            _ => {
421                unreachable!("annotatable kind checked previously")
422            }
423        };
424        // Now update for d_fn
425        rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
426            dspan: DelimSpan::dummy(),
427            delim: rustc_ast::token::Delimiter::Parenthesis,
428            tokens: ts,
429        });
430
431        let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
432        let d_annotatable = match &item {
433            Annotatable::AssocItem(_, _) => {
434                let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
435                let d_fn = Box::new(ast::AssocItem {
436                    attrs: thin_vec![d_attr, inline_never],
437                    id: ast::DUMMY_NODE_ID,
438                    span,
439                    vis,
440                    kind: assoc_item,
441                    tokens: None,
442                });
443                Annotatable::AssocItem(d_fn, Impl { of_trait: false })
444            }
445            Annotatable::Item(_) => {
446                let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
447                d_fn.vis = vis;
448
449                Annotatable::Item(d_fn)
450            }
451            Annotatable::Stmt(_) => {
452                let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
453                d_fn.vis = vis;
454
455                Annotatable::Stmt(Box::new(ast::Stmt {
456                    id: ast::DUMMY_NODE_ID,
457                    kind: ast::StmtKind::Item(d_fn),
458                    span,
459                }))
460            }
461            _ => {
462                unreachable!("item kind checked previously")
463            }
464        };
465
466        return vec![orig_annotatable, d_annotatable];
467    }
468
469    // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
470    // mutable references or ptrs, because Enzyme will write into them.
471    fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
472        let mut ty = ty.clone();
473        match ty.kind {
474            TyKind::Ptr(ref mut mut_ty) => {
475                mut_ty.mutbl = ast::Mutability::Mut;
476            }
477            TyKind::Ref(_, ref mut mut_ty) => {
478                mut_ty.mutbl = ast::Mutability::Mut;
479            }
480            _ => {
481                panic!("unsupported type: {:?}", ty);
482            }
483        }
484        ty
485    }
486
487    // Will generate a body of the type:
488    // ```
489    // {
490    //   unsafe {
491    //   asm!("NOP");
492    //   }
493    //   ::core::hint::black_box(primal(args));
494    //   ::core::hint::black_box((args, ret));
495    //   <This part remains to be done by following function>
496    // }
497    // ```
498    fn init_body_helper(
499        ecx: &ExtCtxt<'_>,
500        span: Span,
501        primal: Ident,
502        new_names: &[String],
503        sig_span: Span,
504        new_decl_span: Span,
505        idents: &[Ident],
506        errored: bool,
507        generics: &Generics,
508    ) -> (Box<ast::Block>, Box<ast::Expr>, Box<ast::Expr>, Box<ast::Expr>) {
509        let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
510        let noop = ast::InlineAsm {
511            asm_macro: ast::AsmMacro::Asm,
512            template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())],
513            template_strs: Box::new([]),
514            operands: vec![],
515            clobber_abis: vec![],
516            options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM,
517            line_spans: vec![],
518        };
519        let noop_expr = ecx.expr_asm(span, Box::new(noop));
520        let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated);
521        let unsf_block = ast::Block {
522            stmts: thin_vec![ecx.stmt_semi(noop_expr)],
523            id: ast::DUMMY_NODE_ID,
524            tokens: None,
525            rules: unsf,
526            span,
527        };
528        let unsf_expr = ecx.expr_block(Box::new(unsf_block));
529        let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
530        let primal_call = gen_primal_call(ecx, span, primal, idents, generics);
531        let black_box_primal_call = ecx.expr_call(
532            new_decl_span,
533            blackbox_call_expr.clone(),
534            thin_vec![primal_call.clone()],
535        );
536        let tup_args = new_names
537            .iter()
538            .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))))
539            .collect();
540
541        let black_box_remaining_args = ecx.expr_call(
542            sig_span,
543            blackbox_call_expr.clone(),
544            thin_vec![ecx.expr_tuple(sig_span, tup_args)],
545        );
546
547        let mut body = ecx.block(span, ThinVec::new());
548        body.stmts.push(ecx.stmt_semi(unsf_expr));
549
550        // This uses primal args which won't be available if we errored before
551        if !errored {
552            body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone()));
553        }
554        body.stmts.push(ecx.stmt_semi(black_box_remaining_args));
555
556        (body, primal_call, black_box_primal_call, blackbox_call_expr)
557    }
558
559    /// We only want this function to type-check, since we will replace the body
560    /// later on llvm level. Using `loop {}` does not cover all return types anymore,
561    /// so instead we manually build something that should pass the type checker.
562    /// We also add a inline_asm line, as one more barrier for rustc to prevent inlining
563    /// or const propagation. inline_asm will also triggers an Enzyme crash if due to another
564    /// bug would ever try to accidentally differentiate this placeholder function body.
565    /// Finally, we also add back_box usages of all input arguments, to prevent rustc
566    /// from optimizing any arguments away.
567    fn gen_enzyme_body(
568        ecx: &ExtCtxt<'_>,
569        x: &AutoDiffAttrs,
570        n_active: u32,
571        sig: &ast::FnSig,
572        d_sig: &ast::FnSig,
573        primal: Ident,
574        new_names: &[String],
575        span: Span,
576        sig_span: Span,
577        idents: Vec<Ident>,
578        errored: bool,
579        generics: &Generics,
580    ) -> Box<ast::Block> {
581        let new_decl_span = d_sig.span;
582
583        // Just adding some default inline-asm and black_box usages to prevent early inlining
584        // and optimizations which alter the function signature.
585        //
586        // The bb_primal_call is the black_box call of the primal function. We keep it around,
587        // since it has the convenient property of returning the type of the primal function,
588        // Remember, we only care to match types here.
589        // No matter which return we pick, we always wrap it into a std::hint::black_box call,
590        // to prevent rustc from propagating it into the caller.
591        let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper(
592            ecx,
593            span,
594            primal,
595            new_names,
596            sig_span,
597            new_decl_span,
598            &idents,
599            errored,
600            generics,
601        );
602
603        if !has_ret(&d_sig.decl.output) {
604            // there is no return type that we have to match, () works fine.
605            return body;
606        }
607
608        // Everything from here onwards just tries to fulfil the return type. Fun!
609
610        // having an active-only return means we'll drop the original return type.
611        // So that can be treated identical to not having one in the first place.
612        let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
613
614        if primal_ret && n_active == 0 && x.mode.is_rev() {
615            // We only have the primal ret.
616            body.stmts.push(ecx.stmt_expr(bb_primal_call));
617            return body;
618        }
619
620        if !primal_ret && n_active == 1 {
621            // Again no tuple return, so return default float val.
622            let ty = match d_sig.decl.output {
623                FnRetTy::Ty(ref ty) => ty.clone(),
624                FnRetTy::Default(span) => {
625                    panic!("Did not expect Default ret ty: {:?}", span);
626                }
627            };
628            let arg = ty.kind.is_simple_path().unwrap();
629            let tmp = ecx.def_site_path(&[arg, kw::Default]);
630            let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
631            let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
632            body.stmts.push(ecx.stmt_expr(default_call_expr));
633            return body;
634        }
635
636        let mut exprs: Box<ast::Expr> = primal_call;
637        let d_ret_ty = match d_sig.decl.output {
638            FnRetTy::Ty(ref ty) => ty.clone(),
639            FnRetTy::Default(span) => {
640                panic!("Did not expect Default ret ty: {:?}", span);
641            }
642        };
643        if x.mode.is_fwd() {
644            // Fwd mode is easy. If the return activity is Const, we support arbitrary types.
645            // Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
646            // We checked that (on a best-effort base) in the preceding gen_enzyme_decl function.
647            // In all three cases, we can return `std::hint::black_box(<T>::default())`.
648            if x.ret_activity == DiffActivity::Const {
649                // Here we call the primal function, since our dummy function has the same return
650                // type due to the Const return activity.
651                exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
652            } else {
653                let q = QSelf { ty: d_ret_ty, path_span: span, position: 0 };
654                let y = ExprKind::Path(
655                    Some(Box::new(q)),
656                    ecx.path_ident(span, Ident::with_dummy_span(kw::Default)),
657                );
658                let default_call_expr = ecx.expr(span, y);
659                let default_call_expr =
660                    ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
661                exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]);
662            }
663        } else if x.mode.is_rev() {
664            if x.width == 1 {
665                // We either have `-> ArbitraryType` or `-> (ArbitraryType, repeated_float_scalars)`.
666                match d_ret_ty.kind {
667                    TyKind::Tup(ref args) => {
668                        // We have a tuple return type. We need to create a tuple of the same size
669                        // and fill it with default values.
670                        let mut exprs2 = thin_vec![exprs];
671                        for arg in args.iter().skip(1) {
672                            let arg = arg.kind.is_simple_path().unwrap();
673                            let tmp = ecx.def_site_path(&[arg, kw::Default]);
674                            let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
675                            let default_call_expr =
676                                ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
677                            exprs2.push(default_call_expr);
678                        }
679                        exprs = ecx.expr_tuple(new_decl_span, exprs2);
680                    }
681                    _ => {
682                        // Interestingly, even the `-> ArbitraryType` case
683                        // ends up getting matched and handled correctly above,
684                        // so we don't have to handle any other case for now.
685                        panic!("Unsupported return type: {:?}", d_ret_ty);
686                    }
687                }
688            }
689            exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
690        } else {
691            unreachable!("Unsupported mode: {:?}", x.mode);
692        }
693
694        body.stmts.push(ecx.stmt_expr(exprs));
695
696        body
697    }
698
699    fn gen_primal_call(
700        ecx: &ExtCtxt<'_>,
701        span: Span,
702        primal: Ident,
703        idents: &[Ident],
704        generics: &Generics,
705    ) -> Box<ast::Expr> {
706        let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
707
708        if has_self {
709            let args: ThinVec<_> =
710                idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
711            let self_expr = ecx.expr_self(span);
712            ecx.expr_method_call(span, self_expr, primal, args)
713        } else {
714            let args: ThinVec<_> =
715                idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
716            let mut primal_path = ecx.path_ident(span, primal);
717
718            let is_generic = !generics.params.is_empty();
719
720            match (is_generic, primal_path.segments.last_mut()) {
721                (true, Some(function_path)) => {
722                    let primal_generic_types = generics
723                        .params
724                        .iter()
725                        .filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. }));
726
727                    let generated_generic_types = primal_generic_types
728                        .map(|type_param| {
729                            let generic_param = TyKind::Path(
730                                None,
731                                ast::Path {
732                                    span,
733                                    segments: thin_vec![ast::PathSegment {
734                                        ident: type_param.ident,
735                                        args: None,
736                                        id: ast::DUMMY_NODE_ID,
737                                    }],
738                                    tokens: None,
739                                },
740                            );
741
742                            ast::AngleBracketedArg::Arg(ast::GenericArg::Type(Box::new(ast::Ty {
743                                id: type_param.id,
744                                span,
745                                kind: generic_param,
746                                tokens: None,
747                            })))
748                        })
749                        .collect();
750
751                    function_path.args =
752                        Some(Box::new(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs {
753                            span,
754                            args: generated_generic_types,
755                        })));
756                }
757                _ => {}
758            }
759
760            let primal_call_expr = ecx.expr_path(primal_path);
761            ecx.expr_call(span, primal_call_expr, args)
762        }
763    }
764
765    // Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
766    // be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
767    // Active arguments must be scalars. Their shadow argument is added to the return type (and will be
768    // zero-initialized by Enzyme).
769    // Each argument of the primal function (and the return type if existing) must be annotated with an
770    // activity.
771    //
772    // Error handling: If the user provides an invalid configuration (incorrect numbers, types, or
773    // both), we emit an error and return the original signature. This allows us to continue parsing.
774    // FIXME(Sa4dUs): make individual activities' span available so errors
775    // can point to only the activity instead of the entire attribute
776    fn gen_enzyme_decl(
777        ecx: &ExtCtxt<'_>,
778        sig: &ast::FnSig,
779        x: &AutoDiffAttrs,
780        span: Span,
781    ) -> (ast::FnSig, Vec<String>, Vec<Ident>, bool) {
782        let dcx = ecx.sess.dcx();
783        let has_ret = has_ret(&sig.decl.output);
784        let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
785        let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 };
786        if sig_args != num_activities {
787            dcx.emit_err(errors::AutoDiffInvalidNumberActivities {
788                span,
789                expected: sig_args,
790                found: num_activities,
791            });
792            // This is not the right signature, but we can continue parsing.
793            return (sig.clone(), vec![], vec![], true);
794        }
795        assert!(sig.decl.inputs.len() == x.input_activity.len());
796        assert!(has_ret == x.has_ret_activity());
797        let mut d_decl = sig.decl.clone();
798        let mut d_inputs = Vec::new();
799        let mut new_inputs = Vec::new();
800        let mut idents = Vec::new();
801        let mut act_ret = ThinVec::new();
802
803        // We have two loops, a first one just to check the activities and types and possibly report
804        // multiple errors in one compilation session.
805        let mut errors = false;
806        for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
807            if !valid_input_activity(x.mode, *activity) {
808                dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct {
809                    span,
810                    mode: x.mode.to_string(),
811                    act: activity.to_string(),
812                });
813                errors = true;
814            }
815            if !valid_ty_for_activity(&arg.ty, *activity) {
816                dcx.emit_err(errors::AutoDiffInvalidTypeForActivity {
817                    span: arg.ty.span,
818                    act: activity.to_string(),
819                });
820                errors = true;
821            }
822        }
823
824        if has_ret && !valid_ret_activity(x.mode, x.ret_activity) {
825            dcx.emit_err(errors::AutoDiffInvalidRetAct {
826                span,
827                mode: x.mode.to_string(),
828                act: x.ret_activity.to_string(),
829            });
830            // We don't set `errors = true` to avoid annoying type errors relative
831            // to the expanded macro type signature
832        }
833
834        if errors {
835            // This is not the right signature, but we can continue parsing.
836            return (sig.clone(), new_inputs, idents, true);
837        }
838
839        let unsafe_activities = x
840            .input_activity
841            .iter()
842            .any(|&act| matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly));
843        for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
844            d_inputs.push(arg.clone());
845            match activity {
846                DiffActivity::Active => {
847                    act_ret.push(arg.ty.clone());
848                    // if width =/= 1, then push [arg.ty; width] to act_ret
849                }
850                DiffActivity::ActiveOnly => {
851                    // We will add the active scalar to the return type.
852                    // This is handled later.
853                }
854                DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
855                    for i in 0..x.width {
856                        let mut shadow_arg = arg.clone();
857                        // We += into the shadow in reverse mode.
858                        shadow_arg.ty = Box::new(assure_mut_ref(&arg.ty));
859                        let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
860                            ident.name
861                        } else {
862                            debug!("{:#?}", &shadow_arg.pat);
863                            panic!("not an ident?");
864                        };
865                        let name: String = format!("d{}_{}", old_name, i);
866                        new_inputs.push(name.clone());
867                        let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
868                        shadow_arg.pat = Box::new(ast::Pat {
869                            id: ast::DUMMY_NODE_ID,
870                            kind: PatKind::Ident(BindingMode::NONE, ident, None),
871                            span: shadow_arg.pat.span,
872                            tokens: shadow_arg.pat.tokens.clone(),
873                        });
874                        d_inputs.push(shadow_arg.clone());
875                    }
876                }
877                DiffActivity::Dual
878                | DiffActivity::DualOnly
879                | DiffActivity::Dualv
880                | DiffActivity::DualvOnly => {
881                    // the *v variants get lowered to enzyme_dupv and enzyme_dupnoneedv, which cause
882                    // Enzyme to not expect N arguments, but one argument (which is instead larger).
883                    let iterations =
884                        if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) {
885                            1
886                        } else {
887                            x.width
888                        };
889                    for i in 0..iterations {
890                        let mut shadow_arg = arg.clone();
891                        let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
892                            ident.name
893                        } else {
894                            debug!("{:#?}", &shadow_arg.pat);
895                            panic!("not an ident?");
896                        };
897                        let name: String = format!("b{}_{}", old_name, i);
898                        new_inputs.push(name.clone());
899                        let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
900                        shadow_arg.pat = Box::new(ast::Pat {
901                            id: ast::DUMMY_NODE_ID,
902                            kind: PatKind::Ident(BindingMode::NONE, ident, None),
903                            span: shadow_arg.pat.span,
904                            tokens: shadow_arg.pat.tokens.clone(),
905                        });
906                        d_inputs.push(shadow_arg.clone());
907                    }
908                }
909                DiffActivity::Const => {
910                    // Nothing to do here.
911                }
912                DiffActivity::None | DiffActivity::FakeActivitySize(_) => {
913                    panic!("Should not happen");
914                }
915            }
916            if let PatKind::Ident(_, ident, _) = arg.pat.kind {
917                idents.push(ident.clone());
918            } else {
919                panic!("not an ident?");
920            }
921        }
922
923        let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;
924        if active_only_ret {
925            assert!(x.mode.is_rev());
926        }
927
928        // If we return a scalar in the primal and the scalar is active,
929        // then add it as last arg to the inputs.
930        if x.mode.is_rev() {
931            match x.ret_activity {
932                DiffActivity::Active | DiffActivity::ActiveOnly => {
933                    let ty = match d_decl.output {
934                        FnRetTy::Ty(ref ty) => ty.clone(),
935                        FnRetTy::Default(span) => {
936                            panic!("Did not expect Default ret ty: {:?}", span);
937                        }
938                    };
939                    let name = "dret".to_string();
940                    let ident = Ident::from_str_and_span(&name, ty.span);
941                    let shadow_arg = ast::Param {
942                        attrs: ThinVec::new(),
943                        ty: ty.clone(),
944                        pat: Box::new(ast::Pat {
945                            id: ast::DUMMY_NODE_ID,
946                            kind: PatKind::Ident(BindingMode::NONE, ident, None),
947                            span: ty.span,
948                            tokens: None,
949                        }),
950                        id: ast::DUMMY_NODE_ID,
951                        span: ty.span,
952                        is_placeholder: false,
953                    };
954                    d_inputs.push(shadow_arg);
955                    new_inputs.push(name);
956                }
957                _ => {}
958            }
959        }
960        d_decl.inputs = d_inputs.into();
961
962        if x.mode.is_fwd() {
963            let ty = match d_decl.output {
964                FnRetTy::Ty(ref ty) => ty.clone(),
965                FnRetTy::Default(span) => {
966                    // We want to return std::hint::black_box(()).
967                    let kind = TyKind::Tup(ThinVec::new());
968                    let ty = Box::new(rustc_ast::Ty {
969                        kind,
970                        id: ast::DUMMY_NODE_ID,
971                        span,
972                        tokens: None,
973                    });
974                    d_decl.output = FnRetTy::Ty(ty.clone());
975                    assert!(matches!(x.ret_activity, DiffActivity::None));
976                    // this won't be used below, so any type would be fine.
977                    ty
978                }
979            };
980
981            if matches!(x.ret_activity, DiffActivity::Dual | DiffActivity::Dualv) {
982                let kind = if x.width == 1 || matches!(x.ret_activity, DiffActivity::Dualv) {
983                    // Dual can only be used for f32/f64 ret.
984                    // In that case we return now a tuple with two floats.
985                    TyKind::Tup(thin_vec![ty.clone(), ty.clone()])
986                } else {
987                    // We have to return [T; width+1], +1 for the primal return.
988                    let anon_const = rustc_ast::AnonConst {
989                        id: ast::DUMMY_NODE_ID,
990                        value: ecx.expr_usize(span, 1 + x.width as usize),
991                    };
992                    TyKind::Array(ty.clone(), anon_const)
993                };
994                let ty = Box::new(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
995                d_decl.output = FnRetTy::Ty(ty);
996            }
997            if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) {
998                // No need to change the return type,
999                // we will just return the shadow in place of the primal return.
1000                // However, if we have a width > 1, then we don't return -> T, but -> [T; width]
1001                if x.width > 1 {
1002                    let anon_const = rustc_ast::AnonConst {
1003                        id: ast::DUMMY_NODE_ID,
1004                        value: ecx.expr_usize(span, x.width as usize),
1005                    };
1006                    let kind = TyKind::Array(ty.clone(), anon_const);
1007                    let ty =
1008                        Box::new(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
1009                    d_decl.output = FnRetTy::Ty(ty);
1010                }
1011            }
1012        }
1013
1014        // If we use ActiveOnly, drop the original return value.
1015        d_decl.output =
1016            if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() };
1017
1018        trace!("act_ret: {:?}", act_ret);
1019
1020        // If we have an active input scalar, add it's gradient to the
1021        // return type. This might require changing the return type to a
1022        // tuple.
1023        if act_ret.len() > 0 {
1024            let ret_ty = match d_decl.output {
1025                FnRetTy::Ty(ref ty) => {
1026                    if !active_only_ret {
1027                        act_ret.insert(0, ty.clone());
1028                    }
1029                    let kind = TyKind::Tup(act_ret);
1030                    Box::new(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })
1031                }
1032                FnRetTy::Default(span) => {
1033                    if act_ret.len() == 1 {
1034                        act_ret[0].clone()
1035                    } else {
1036                        let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect());
1037                        Box::new(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None })
1038                    }
1039                }
1040            };
1041            d_decl.output = FnRetTy::Ty(ret_ty);
1042        }
1043
1044        let mut d_header = sig.header.clone();
1045        if unsafe_activities {
1046            d_header.safety = rustc_ast::Safety::Unsafe(span);
1047        }
1048        let d_sig = FnSig { header: d_header, decl: d_decl, span };
1049        trace!("Generated signature: {:?}", d_sig);
1050        (d_sig, new_inputs, idents, false)
1051    }
1052}
1053
1054pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};