1mod llvm_enzyme {
7 use std::str::FromStr;
8 use std::string::String;
9
10 use rustc_ast::expand::autodiff_attrs::{
11 AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
12 valid_ty_for_activity,
13 };
14 use rustc_ast::token::{Lit, LitKind, Token, TokenKind};
15 use rustc_ast::tokenstream::*;
16 use rustc_ast::visit::AssocCtxt::*;
17 use rustc_ast::{
18 self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
19 MetaItemInner, PatKind, QSelf, TyKind, Visibility,
20 };
21 use rustc_expand::base::{Annotatable, ExtCtxt};
22 use rustc_span::{Ident, Span, Symbol, kw, sym};
23 use thin_vec::{ThinVec, thin_vec};
24 use tracing::{debug, trace};
25
26 use crate::errors;
27
28 pub(crate) fn outer_normal_attr(
29 kind: &Box<rustc_ast::NormalAttr>,
30 id: rustc_ast::AttrId,
31 span: Span,
32 ) -> rustc_ast::Attribute {
33 let style = rustc_ast::AttrStyle::Outer;
34 let kind = rustc_ast::AttrKind::Normal(kind.clone());
35 rustc_ast::Attribute { kind, id, style, span }
36 }
37
38 fn has_ret(ty: &FnRetTy) -> bool {
41 match ty {
42 FnRetTy::Ty(ty) => !ty.kind.is_unit(),
43 FnRetTy::Default(_) => false,
44 }
45 }
46 fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {
47 if let Some(l) = x.lit() {
48 match l.kind {
49 ast::LitKind::Int(val, _) => {
50 return rustc_span::Ident::from_str(val.get().to_string().as_str());
52 }
53 _ => {}
54 }
55 }
56
57 let segments = &x.meta_item().unwrap().path.segments;
58 assert!(segments.len() == 1);
59 segments[0].ident
60 }
61
62 fn name(x: &MetaItemInner) -> String {
63 first_ident(x).name.to_string()
64 }
65
66 fn width(x: &MetaItemInner) -> Option<u128> {
67 let lit = x.lit()?;
68 match lit.kind {
69 ast::LitKind::Int(x, _) => Some(x.get()),
70 _ => return None,
71 }
72 }
73
74 fn extract_item_info(iitem: &Box<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {
76 match &iitem.kind {
77 ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
78 Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
79 }
80 _ => None,
81 }
82 }
83
84 pub(crate) fn from_ast(
85 ecx: &mut ExtCtxt<'_>,
86 meta_item: &ThinVec<MetaItemInner>,
87 has_ret: bool,
88 mode: DiffMode,
89 ) -> AutoDiffAttrs {
90 let dcx = ecx.sess.dcx();
91
92 let mut first_activity = 1;
95
96 let width = if let [_, x, ..] = &meta_item[..]
97 && let Some(x) = width(x)
98 {
99 first_activity = 2;
100 match x.try_into() {
101 Ok(x) => x,
102 Err(_) => {
103 dcx.emit_err(errors::AutoDiffInvalidWidth {
104 span: meta_item[1].span(),
105 width: x,
106 });
107 return AutoDiffAttrs::error();
108 }
109 }
110 } else {
111 1
112 };
113
114 let mut activities: Vec<DiffActivity> = vec![];
115 let mut errors = false;
116 for x in &meta_item[first_activity..] {
117 let activity_str = name(&x);
118 let res = DiffActivity::from_str(&activity_str);
119 match res {
120 Ok(x) => activities.push(x),
121 Err(_) => {
122 dcx.emit_err(errors::AutoDiffUnknownActivity {
123 span: x.span(),
124 act: activity_str,
125 });
126 errors = true;
127 }
128 };
129 }
130 if errors {
131 return AutoDiffAttrs::error();
132 }
133
134 let (ret_activity, input_activity) = if has_ret {
137 let Some((last, rest)) = activities.split_last() else {
138 unreachable!(
139 "should not be reachable because we counted the number of activities previously"
140 );
141 };
142 (last, rest)
143 } else {
144 (&DiffActivity::None, activities.as_slice())
145 };
146
147 AutoDiffAttrs {
148 mode,
149 width,
150 ret_activity: *ret_activity,
151 input_activity: input_activity.to_vec(),
152 }
153 }
154
155 fn meta_item_inner_to_ts(t: &MetaItemInner, ts: &mut Vec<TokenTree>) {
156 let comma: Token = Token::new(TokenKind::Comma, Span::default());
157 let val = first_ident(t);
158 let t = Token::from_ast_ident(val);
159 ts.push(TokenTree::Token(t, Spacing::Joint));
160 ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
161 }
162
163 pub(crate) fn expand_forward(
164 ecx: &mut ExtCtxt<'_>,
165 expand_span: Span,
166 meta_item: &ast::MetaItem,
167 item: Annotatable,
168 ) -> Vec<Annotatable> {
169 expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward)
170 }
171
172 pub(crate) fn expand_reverse(
173 ecx: &mut ExtCtxt<'_>,
174 expand_span: Span,
175 meta_item: &ast::MetaItem,
176 item: Annotatable,
177 ) -> Vec<Annotatable> {
178 expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse)
179 }
180
181 pub(crate) fn expand_with_mode(
215 ecx: &mut ExtCtxt<'_>,
216 expand_span: Span,
217 meta_item: &ast::MetaItem,
218 mut item: Annotatable,
219 mode: DiffMode,
220 ) -> Vec<Annotatable> {
221 if cfg!(not(llvm_enzyme)) {
222 ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
223 return vec![item];
224 }
225 let dcx = ecx.sess.dcx();
226
227 let Some((vis, sig, primal, generics)) = (match &item {
231 Annotatable::Item(iitem) => extract_item_info(iitem),
232 Annotatable::Stmt(stmt) => match &stmt.kind {
233 ast::StmtKind::Item(iitem) => extract_item_info(iitem),
234 _ => None,
235 },
236 Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind {
237 ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
238 Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
239 }
240 _ => None,
241 },
242 _ => None,
243 }) else {
244 dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
245 return vec![item];
246 };
247
248 let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
249 ast::MetaItemKind::List(ref vec) => vec.clone(),
250 _ => {
251 dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
252 return vec![item];
253 }
254 };
255
256 let has_ret = has_ret(&sig.decl.output);
257 let sig_span = ecx.with_call_site_ctxt(sig.span);
258
259 let mut ts: Vec<TokenTree> = vec![];
262 if meta_item_vec.len() < 1 {
263 dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
265 return vec![item];
266 }
267
268 let mode_symbol = match mode {
269 DiffMode::Forward => sym::Forward,
270 DiffMode::Reverse => sym::Reverse,
271 _ => unreachable!("Unsupported mode: {:?}", mode),
272 };
273
274 let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());
276 ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));
277 ts.insert(
278 1,
279 TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),
280 );
281
282 let start_position;
285 let kind: LitKind = LitKind::Integer;
286 let symbol;
287 if meta_item_vec.len() >= 2
288 && let Some(width) = width(&meta_item_vec[1])
289 {
290 start_position = 2;
291 symbol = Symbol::intern(&width.to_string());
292 } else {
293 start_position = 1;
294 symbol = sym::integer(1);
295 }
296
297 let l: Lit = Lit { kind, symbol, suffix: None };
298 let t = Token::new(TokenKind::Literal(l), Span::default());
299 let comma = Token::new(TokenKind::Comma, Span::default());
300 ts.push(TokenTree::Token(t, Spacing::Joint));
301 ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
302
303 for t in meta_item_vec.clone()[start_position..].iter() {
304 meta_item_inner_to_ts(t, &mut ts);
305 }
306
307 if !has_ret {
308 let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
311 ts.push(TokenTree::Token(t, Spacing::Joint));
312 ts.push(TokenTree::Token(comma, Spacing::Alone));
313 }
314 ts.pop();
316 let ts: TokenStream = TokenStream::from_iter(ts);
317
318 let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret, mode);
319 if !x.is_active() {
320 return vec![item];
323 }
324 let span = ecx.with_def_site_ctxt(expand_span);
325
326 let n_active: u32 = x
327 .input_activity
328 .iter()
329 .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
330 .count() as u32;
331 let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
332 let d_body = gen_enzyme_body(
333 ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
334 &generics,
335 );
336
337 let asdf = Box::new(ast::Fn {
339 defaultness: ast::Defaultness::Final,
340 sig: d_sig,
341 ident: first_ident(&meta_item_vec[0]),
342 generics,
343 contract: None,
344 body: Some(d_body),
345 define_opaque: None,
346 });
347 let mut rustc_ad_attr =
348 Box::new(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
349
350 let ts2: Vec<TokenTree> = vec![TokenTree::Token(
351 Token::new(TokenKind::Ident(sym::never, false.into()), span),
352 Spacing::Joint,
353 )];
354 let never_arg = ast::DelimArgs {
355 dspan: DelimSpan::from_single(span),
356 delim: ast::token::Delimiter::Parenthesis,
357 tokens: TokenStream::from_iter(ts2),
358 };
359 let inline_item = ast::AttrItem {
360 unsafety: ast::Safety::Default,
361 path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),
362 args: ast::AttrArgs::Delimited(never_arg),
363 tokens: None,
364 };
365 let inline_never_attr = Box::new(ast::NormalAttr { item: inline_item, tokens: None });
366 let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
367 let attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
368 let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
369 let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);
370
371 fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
373 match (attr, item) {
374 (ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {
375 let a = &a.item.path;
376 let b = &b.item.path;
377 a.segments.len() == b.segments.len()
378 && a.segments.iter().zip(b.segments.iter()).all(|(a, b)| a.ident == b.ident)
379 }
380 _ => false,
381 }
382 }
383
384 let orig_annotatable: Annotatable = match item {
386 Annotatable::Item(ref mut iitem) => {
387 if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
388 iitem.attrs.push(attr);
389 }
390 if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
391 iitem.attrs.push(inline_never.clone());
392 }
393 Annotatable::Item(iitem.clone())
394 }
395 Annotatable::AssocItem(ref mut assoc_item, i @ Impl { .. }) => {
396 if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
397 assoc_item.attrs.push(attr);
398 }
399 if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
400 assoc_item.attrs.push(inline_never.clone());
401 }
402 Annotatable::AssocItem(assoc_item.clone(), i)
403 }
404 Annotatable::Stmt(ref mut stmt) => {
405 match stmt.kind {
406 ast::StmtKind::Item(ref mut iitem) => {
407 if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
408 iitem.attrs.push(attr);
409 }
410 if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind))
411 {
412 iitem.attrs.push(inline_never.clone());
413 }
414 }
415 _ => unreachable!("stmt kind checked previously"),
416 };
417
418 Annotatable::Stmt(stmt.clone())
419 }
420 _ => {
421 unreachable!("annotatable kind checked previously")
422 }
423 };
424 rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
426 dspan: DelimSpan::dummy(),
427 delim: rustc_ast::token::Delimiter::Parenthesis,
428 tokens: ts,
429 });
430
431 let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
432 let d_annotatable = match &item {
433 Annotatable::AssocItem(_, _) => {
434 let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
435 let d_fn = Box::new(ast::AssocItem {
436 attrs: thin_vec![d_attr, inline_never],
437 id: ast::DUMMY_NODE_ID,
438 span,
439 vis,
440 kind: assoc_item,
441 tokens: None,
442 });
443 Annotatable::AssocItem(d_fn, Impl { of_trait: false })
444 }
445 Annotatable::Item(_) => {
446 let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
447 d_fn.vis = vis;
448
449 Annotatable::Item(d_fn)
450 }
451 Annotatable::Stmt(_) => {
452 let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
453 d_fn.vis = vis;
454
455 Annotatable::Stmt(Box::new(ast::Stmt {
456 id: ast::DUMMY_NODE_ID,
457 kind: ast::StmtKind::Item(d_fn),
458 span,
459 }))
460 }
461 _ => {
462 unreachable!("item kind checked previously")
463 }
464 };
465
466 return vec![orig_annotatable, d_annotatable];
467 }
468
469 fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
472 let mut ty = ty.clone();
473 match ty.kind {
474 TyKind::Ptr(ref mut mut_ty) => {
475 mut_ty.mutbl = ast::Mutability::Mut;
476 }
477 TyKind::Ref(_, ref mut mut_ty) => {
478 mut_ty.mutbl = ast::Mutability::Mut;
479 }
480 _ => {
481 panic!("unsupported type: {:?}", ty);
482 }
483 }
484 ty
485 }
486
487 fn init_body_helper(
499 ecx: &ExtCtxt<'_>,
500 span: Span,
501 primal: Ident,
502 new_names: &[String],
503 sig_span: Span,
504 new_decl_span: Span,
505 idents: &[Ident],
506 errored: bool,
507 generics: &Generics,
508 ) -> (Box<ast::Block>, Box<ast::Expr>, Box<ast::Expr>, Box<ast::Expr>) {
509 let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
510 let noop = ast::InlineAsm {
511 asm_macro: ast::AsmMacro::Asm,
512 template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())],
513 template_strs: Box::new([]),
514 operands: vec![],
515 clobber_abis: vec![],
516 options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM,
517 line_spans: vec![],
518 };
519 let noop_expr = ecx.expr_asm(span, Box::new(noop));
520 let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated);
521 let unsf_block = ast::Block {
522 stmts: thin_vec![ecx.stmt_semi(noop_expr)],
523 id: ast::DUMMY_NODE_ID,
524 tokens: None,
525 rules: unsf,
526 span,
527 };
528 let unsf_expr = ecx.expr_block(Box::new(unsf_block));
529 let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
530 let primal_call = gen_primal_call(ecx, span, primal, idents, generics);
531 let black_box_primal_call = ecx.expr_call(
532 new_decl_span,
533 blackbox_call_expr.clone(),
534 thin_vec![primal_call.clone()],
535 );
536 let tup_args = new_names
537 .iter()
538 .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))))
539 .collect();
540
541 let black_box_remaining_args = ecx.expr_call(
542 sig_span,
543 blackbox_call_expr.clone(),
544 thin_vec![ecx.expr_tuple(sig_span, tup_args)],
545 );
546
547 let mut body = ecx.block(span, ThinVec::new());
548 body.stmts.push(ecx.stmt_semi(unsf_expr));
549
550 if !errored {
552 body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone()));
553 }
554 body.stmts.push(ecx.stmt_semi(black_box_remaining_args));
555
556 (body, primal_call, black_box_primal_call, blackbox_call_expr)
557 }
558
559 fn gen_enzyme_body(
568 ecx: &ExtCtxt<'_>,
569 x: &AutoDiffAttrs,
570 n_active: u32,
571 sig: &ast::FnSig,
572 d_sig: &ast::FnSig,
573 primal: Ident,
574 new_names: &[String],
575 span: Span,
576 sig_span: Span,
577 idents: Vec<Ident>,
578 errored: bool,
579 generics: &Generics,
580 ) -> Box<ast::Block> {
581 let new_decl_span = d_sig.span;
582
583 let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper(
592 ecx,
593 span,
594 primal,
595 new_names,
596 sig_span,
597 new_decl_span,
598 &idents,
599 errored,
600 generics,
601 );
602
603 if !has_ret(&d_sig.decl.output) {
604 return body;
606 }
607
608 let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
613
614 if primal_ret && n_active == 0 && x.mode.is_rev() {
615 body.stmts.push(ecx.stmt_expr(bb_primal_call));
617 return body;
618 }
619
620 if !primal_ret && n_active == 1 {
621 let ty = match d_sig.decl.output {
623 FnRetTy::Ty(ref ty) => ty.clone(),
624 FnRetTy::Default(span) => {
625 panic!("Did not expect Default ret ty: {:?}", span);
626 }
627 };
628 let arg = ty.kind.is_simple_path().unwrap();
629 let tmp = ecx.def_site_path(&[arg, kw::Default]);
630 let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
631 let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
632 body.stmts.push(ecx.stmt_expr(default_call_expr));
633 return body;
634 }
635
636 let mut exprs: Box<ast::Expr> = primal_call;
637 let d_ret_ty = match d_sig.decl.output {
638 FnRetTy::Ty(ref ty) => ty.clone(),
639 FnRetTy::Default(span) => {
640 panic!("Did not expect Default ret ty: {:?}", span);
641 }
642 };
643 if x.mode.is_fwd() {
644 if x.ret_activity == DiffActivity::Const {
649 exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
652 } else {
653 let q = QSelf { ty: d_ret_ty, path_span: span, position: 0 };
654 let y = ExprKind::Path(
655 Some(Box::new(q)),
656 ecx.path_ident(span, Ident::with_dummy_span(kw::Default)),
657 );
658 let default_call_expr = ecx.expr(span, y);
659 let default_call_expr =
660 ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
661 exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]);
662 }
663 } else if x.mode.is_rev() {
664 if x.width == 1 {
665 match d_ret_ty.kind {
667 TyKind::Tup(ref args) => {
668 let mut exprs2 = thin_vec![exprs];
671 for arg in args.iter().skip(1) {
672 let arg = arg.kind.is_simple_path().unwrap();
673 let tmp = ecx.def_site_path(&[arg, kw::Default]);
674 let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
675 let default_call_expr =
676 ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
677 exprs2.push(default_call_expr);
678 }
679 exprs = ecx.expr_tuple(new_decl_span, exprs2);
680 }
681 _ => {
682 panic!("Unsupported return type: {:?}", d_ret_ty);
686 }
687 }
688 }
689 exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
690 } else {
691 unreachable!("Unsupported mode: {:?}", x.mode);
692 }
693
694 body.stmts.push(ecx.stmt_expr(exprs));
695
696 body
697 }
698
699 fn gen_primal_call(
700 ecx: &ExtCtxt<'_>,
701 span: Span,
702 primal: Ident,
703 idents: &[Ident],
704 generics: &Generics,
705 ) -> Box<ast::Expr> {
706 let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
707
708 if has_self {
709 let args: ThinVec<_> =
710 idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
711 let self_expr = ecx.expr_self(span);
712 ecx.expr_method_call(span, self_expr, primal, args)
713 } else {
714 let args: ThinVec<_> =
715 idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
716 let mut primal_path = ecx.path_ident(span, primal);
717
718 let is_generic = !generics.params.is_empty();
719
720 match (is_generic, primal_path.segments.last_mut()) {
721 (true, Some(function_path)) => {
722 let primal_generic_types = generics
723 .params
724 .iter()
725 .filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. }));
726
727 let generated_generic_types = primal_generic_types
728 .map(|type_param| {
729 let generic_param = TyKind::Path(
730 None,
731 ast::Path {
732 span,
733 segments: thin_vec![ast::PathSegment {
734 ident: type_param.ident,
735 args: None,
736 id: ast::DUMMY_NODE_ID,
737 }],
738 tokens: None,
739 },
740 );
741
742 ast::AngleBracketedArg::Arg(ast::GenericArg::Type(Box::new(ast::Ty {
743 id: type_param.id,
744 span,
745 kind: generic_param,
746 tokens: None,
747 })))
748 })
749 .collect();
750
751 function_path.args =
752 Some(Box::new(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs {
753 span,
754 args: generated_generic_types,
755 })));
756 }
757 _ => {}
758 }
759
760 let primal_call_expr = ecx.expr_path(primal_path);
761 ecx.expr_call(span, primal_call_expr, args)
762 }
763 }
764
765 fn gen_enzyme_decl(
777 ecx: &ExtCtxt<'_>,
778 sig: &ast::FnSig,
779 x: &AutoDiffAttrs,
780 span: Span,
781 ) -> (ast::FnSig, Vec<String>, Vec<Ident>, bool) {
782 let dcx = ecx.sess.dcx();
783 let has_ret = has_ret(&sig.decl.output);
784 let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
785 let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 };
786 if sig_args != num_activities {
787 dcx.emit_err(errors::AutoDiffInvalidNumberActivities {
788 span,
789 expected: sig_args,
790 found: num_activities,
791 });
792 return (sig.clone(), vec![], vec![], true);
794 }
795 assert!(sig.decl.inputs.len() == x.input_activity.len());
796 assert!(has_ret == x.has_ret_activity());
797 let mut d_decl = sig.decl.clone();
798 let mut d_inputs = Vec::new();
799 let mut new_inputs = Vec::new();
800 let mut idents = Vec::new();
801 let mut act_ret = ThinVec::new();
802
803 let mut errors = false;
806 for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
807 if !valid_input_activity(x.mode, *activity) {
808 dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct {
809 span,
810 mode: x.mode.to_string(),
811 act: activity.to_string(),
812 });
813 errors = true;
814 }
815 if !valid_ty_for_activity(&arg.ty, *activity) {
816 dcx.emit_err(errors::AutoDiffInvalidTypeForActivity {
817 span: arg.ty.span,
818 act: activity.to_string(),
819 });
820 errors = true;
821 }
822 }
823
824 if has_ret && !valid_ret_activity(x.mode, x.ret_activity) {
825 dcx.emit_err(errors::AutoDiffInvalidRetAct {
826 span,
827 mode: x.mode.to_string(),
828 act: x.ret_activity.to_string(),
829 });
830 }
833
834 if errors {
835 return (sig.clone(), new_inputs, idents, true);
837 }
838
839 let unsafe_activities = x
840 .input_activity
841 .iter()
842 .any(|&act| matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly));
843 for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
844 d_inputs.push(arg.clone());
845 match activity {
846 DiffActivity::Active => {
847 act_ret.push(arg.ty.clone());
848 }
850 DiffActivity::ActiveOnly => {
851 }
854 DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
855 for i in 0..x.width {
856 let mut shadow_arg = arg.clone();
857 shadow_arg.ty = Box::new(assure_mut_ref(&arg.ty));
859 let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
860 ident.name
861 } else {
862 debug!("{:#?}", &shadow_arg.pat);
863 panic!("not an ident?");
864 };
865 let name: String = format!("d{}_{}", old_name, i);
866 new_inputs.push(name.clone());
867 let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
868 shadow_arg.pat = Box::new(ast::Pat {
869 id: ast::DUMMY_NODE_ID,
870 kind: PatKind::Ident(BindingMode::NONE, ident, None),
871 span: shadow_arg.pat.span,
872 tokens: shadow_arg.pat.tokens.clone(),
873 });
874 d_inputs.push(shadow_arg.clone());
875 }
876 }
877 DiffActivity::Dual
878 | DiffActivity::DualOnly
879 | DiffActivity::Dualv
880 | DiffActivity::DualvOnly => {
881 let iterations =
884 if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) {
885 1
886 } else {
887 x.width
888 };
889 for i in 0..iterations {
890 let mut shadow_arg = arg.clone();
891 let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
892 ident.name
893 } else {
894 debug!("{:#?}", &shadow_arg.pat);
895 panic!("not an ident?");
896 };
897 let name: String = format!("b{}_{}", old_name, i);
898 new_inputs.push(name.clone());
899 let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
900 shadow_arg.pat = Box::new(ast::Pat {
901 id: ast::DUMMY_NODE_ID,
902 kind: PatKind::Ident(BindingMode::NONE, ident, None),
903 span: shadow_arg.pat.span,
904 tokens: shadow_arg.pat.tokens.clone(),
905 });
906 d_inputs.push(shadow_arg.clone());
907 }
908 }
909 DiffActivity::Const => {
910 }
912 DiffActivity::None | DiffActivity::FakeActivitySize(_) => {
913 panic!("Should not happen");
914 }
915 }
916 if let PatKind::Ident(_, ident, _) = arg.pat.kind {
917 idents.push(ident.clone());
918 } else {
919 panic!("not an ident?");
920 }
921 }
922
923 let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;
924 if active_only_ret {
925 assert!(x.mode.is_rev());
926 }
927
928 if x.mode.is_rev() {
931 match x.ret_activity {
932 DiffActivity::Active | DiffActivity::ActiveOnly => {
933 let ty = match d_decl.output {
934 FnRetTy::Ty(ref ty) => ty.clone(),
935 FnRetTy::Default(span) => {
936 panic!("Did not expect Default ret ty: {:?}", span);
937 }
938 };
939 let name = "dret".to_string();
940 let ident = Ident::from_str_and_span(&name, ty.span);
941 let shadow_arg = ast::Param {
942 attrs: ThinVec::new(),
943 ty: ty.clone(),
944 pat: Box::new(ast::Pat {
945 id: ast::DUMMY_NODE_ID,
946 kind: PatKind::Ident(BindingMode::NONE, ident, None),
947 span: ty.span,
948 tokens: None,
949 }),
950 id: ast::DUMMY_NODE_ID,
951 span: ty.span,
952 is_placeholder: false,
953 };
954 d_inputs.push(shadow_arg);
955 new_inputs.push(name);
956 }
957 _ => {}
958 }
959 }
960 d_decl.inputs = d_inputs.into();
961
962 if x.mode.is_fwd() {
963 let ty = match d_decl.output {
964 FnRetTy::Ty(ref ty) => ty.clone(),
965 FnRetTy::Default(span) => {
966 let kind = TyKind::Tup(ThinVec::new());
968 let ty = Box::new(rustc_ast::Ty {
969 kind,
970 id: ast::DUMMY_NODE_ID,
971 span,
972 tokens: None,
973 });
974 d_decl.output = FnRetTy::Ty(ty.clone());
975 assert!(matches!(x.ret_activity, DiffActivity::None));
976 ty
978 }
979 };
980
981 if matches!(x.ret_activity, DiffActivity::Dual | DiffActivity::Dualv) {
982 let kind = if x.width == 1 || matches!(x.ret_activity, DiffActivity::Dualv) {
983 TyKind::Tup(thin_vec![ty.clone(), ty.clone()])
986 } else {
987 let anon_const = rustc_ast::AnonConst {
989 id: ast::DUMMY_NODE_ID,
990 value: ecx.expr_usize(span, 1 + x.width as usize),
991 };
992 TyKind::Array(ty.clone(), anon_const)
993 };
994 let ty = Box::new(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
995 d_decl.output = FnRetTy::Ty(ty);
996 }
997 if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) {
998 if x.width > 1 {
1002 let anon_const = rustc_ast::AnonConst {
1003 id: ast::DUMMY_NODE_ID,
1004 value: ecx.expr_usize(span, x.width as usize),
1005 };
1006 let kind = TyKind::Array(ty.clone(), anon_const);
1007 let ty =
1008 Box::new(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
1009 d_decl.output = FnRetTy::Ty(ty);
1010 }
1011 }
1012 }
1013
1014 d_decl.output =
1016 if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() };
1017
1018 trace!("act_ret: {:?}", act_ret);
1019
1020 if act_ret.len() > 0 {
1024 let ret_ty = match d_decl.output {
1025 FnRetTy::Ty(ref ty) => {
1026 if !active_only_ret {
1027 act_ret.insert(0, ty.clone());
1028 }
1029 let kind = TyKind::Tup(act_ret);
1030 Box::new(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })
1031 }
1032 FnRetTy::Default(span) => {
1033 if act_ret.len() == 1 {
1034 act_ret[0].clone()
1035 } else {
1036 let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect());
1037 Box::new(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None })
1038 }
1039 }
1040 };
1041 d_decl.output = FnRetTy::Ty(ret_ty);
1042 }
1043
1044 let mut d_header = sig.header.clone();
1045 if unsafe_activities {
1046 d_header.safety = rustc_ast::Safety::Unsafe(span);
1047 }
1048 let d_sig = FnSig { header: d_header, decl: d_decl, span };
1049 trace!("Generated signature: {:?}", d_sig);
1050 (d_sig, new_inputs, idents, false)
1051 }
1052}
1053
1054pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};