rustc_mir_build/check_tail_calls.rs

1use rustc_abi::ExternAbi;
2use rustc_data_structures::stack::ensure_sufficient_stack;
3use rustc_errors::Applicability;
4use rustc_hir::LangItem;
5use rustc_hir::def::DefKind;
6use rustc_middle::span_bug;
7use rustc_middle::thir::visit::{self, Visitor};
8use rustc_middle::thir::{BodyTy, Expr, ExprId, ExprKind, Thir};
9use rustc_middle::ty::{self, Ty, TyCtxt};
10use rustc_span::def_id::{DefId, LocalDefId};
11use rustc_span::{DUMMY_SP, ErrorGuaranteed, Span};
12
13pub(crate) fn check_tail_calls(tcx: TyCtxt<'_>, def: LocalDefId) -> Result<(), ErrorGuaranteed> {
14    let (thir, expr) = tcx.thir_body(def)?;
15    let thir = &thir.borrow();
16
17    // If `thir` is empty, a type error occurred, skip this body.
18    if thir.exprs.is_empty() {
19        return Ok(());
20    }
21
22    let is_closure = matches!(tcx.def_kind(def), DefKind::Closure);
23    let caller_ty = tcx.type_of(def).skip_binder();
24
25    let mut visitor = TailCallCkVisitor {
26        tcx,
27        thir,
28        found_errors: Ok(()),
29        // FIXME(#132279): we're clearly in a body here.
30        typing_env: ty::TypingEnv::non_body_analysis(tcx, def),
31        is_closure,
32        caller_ty,
33    };
34
35    visitor.visit_expr(&thir[expr]);
36
37    visitor.found_errors
38}
39
40struct TailCallCkVisitor<'a, 'tcx> {
41    tcx: TyCtxt<'tcx>,
42    thir: &'a Thir<'tcx>,
43    typing_env: ty::TypingEnv<'tcx>,
44    /// Whatever the currently checked body is one of a closure
45    is_closure: bool,
46    /// The result of the checks, `Err(_)` if there was a problem with some
47    /// tail call, `Ok(())` if all of them were fine.
48    found_errors: Result<(), ErrorGuaranteed>,
49    /// Type of the caller function.
50    caller_ty: Ty<'tcx>,
51}
52
53impl<'tcx> TailCallCkVisitor<'_, 'tcx> {
54    fn check_tail_call(&mut self, call: &Expr<'_>, expr: &Expr<'_>) {
55        if self.is_closure {
56            self.report_in_closure(expr);
57            return;
58        }
59
60        let BodyTy::Fn(caller_sig) = self.thir.body_type else {
61            span_bug!(
62                call.span,
63                "`become` outside of functions should have been disallowed by hir_typeck"
64            )
65        };
66        // While the `caller_sig` does have its regions erased, it does not have its
67        // binders anonymized. We call `erase_regions` once again to anonymize any binders
68        // within the signature, such as in function pointer or `dyn Trait` args.
69        let caller_sig = self.tcx.erase_regions(caller_sig);
70
71        let ExprKind::Scope { value, .. } = call.kind else {
72            span_bug!(call.span, "expected scope, found: {call:?}")
73        };
74        let value = &self.thir[value];
75
76        if matches!(
77            value.kind,
78            ExprKind::Binary { .. }
79                | ExprKind::Unary { .. }
80                | ExprKind::AssignOp { .. }
81                | ExprKind::Index { .. }
82        ) {
83            self.report_builtin_op(call, expr);
84            return;
85        }
86
87        let ExprKind::Call { ty, fun, ref args, from_hir_call, fn_span } = value.kind else {
88            self.report_non_call(value, expr);
89            return;
90        };
91
92        if !from_hir_call {
93            self.report_op(ty, args, fn_span, expr);
94        }
95
96        if let &ty::FnDef(did, args) = ty.kind() {
97            // Closures in thir look something akin to
98            // `for<'a> extern "rust-call" fn(&'a [closure@...], ()) -> <[closure@...] as FnOnce<()>>::Output {<[closure@...] as Fn<()>>::call}`
99            // So we have to check for them in this weird way...
100            let parent = self.tcx.parent(did);
101            if self.tcx.fn_trait_kind_from_def_id(parent).is_some()
102                && let Some(this) = args.first()
103                && let Some(this) = this.as_type()
104            {
105                if this.is_closure() {
106                    self.report_calling_closure(&self.thir[fun], args[1].as_type().unwrap(), expr);
107                } else {
108                    // This can happen when tail calling `Box` that wraps a function
109                    self.report_nonfn_callee(fn_span, self.thir[fun].span, this);
110                }
111
112                // Tail calling is likely to cause unrelated errors (ABI, argument mismatches),
113                // skip them, producing an error about calling a closure is enough.
114                return;
115            };
116
117            if self.tcx.intrinsic(did).is_some() {
118                self.report_calling_intrinsic(expr);
119            }
120        }
121
122        let (ty::FnDef(..) | ty::FnPtr(..)) = ty.kind() else {
123            self.report_nonfn_callee(fn_span, self.thir[fun].span, ty);
124
125            // `fn_sig` below panics otherwise
126            return;
127        };
128
129        // Erase regions since tail calls don't care about lifetimes
130        let callee_sig =
131            self.tcx.normalize_erasing_late_bound_regions(self.typing_env, ty.fn_sig(self.tcx));
132
133        if caller_sig.abi != callee_sig.abi {
134            self.report_abi_mismatch(expr.span, caller_sig.abi, callee_sig.abi);
135        }
136
137        if caller_sig.inputs_and_output != callee_sig.inputs_and_output {
138            if caller_sig.inputs() != callee_sig.inputs() {
139                self.report_arguments_mismatch(expr.span, caller_sig, callee_sig);
140            }
141
142            // FIXME(explicit_tail_calls): this currently fails for cases where opaques are used.
143            // e.g.
144            // ```
145            // fn a() -> impl Sized { become b() } // ICE
146            // fn b() -> u8 { 0 }
147            // ```
148            // we should think what is the expected behavior here.
149            // (we should probably just accept this by revealing opaques?)
150            if caller_sig.output() != callee_sig.output() {
151                span_bug!(expr.span, "hir typeck should have checked the return type already");
152            }
153        }
154
155        {
156            // `#[track_caller]` affects the ABI of a function (by adding a location argument),
157            // so a `track_caller` can only tail call other `track_caller` functions.
158            //
159            // The issue is however that we can't know if a function is `track_caller` or not at
160            // this point (THIR can be polymorphic, we may have an unresolved trait function).
161            // We could only allow functions that we *can* resolve and *are* `track_caller`,
162            // but that would turn changing `track_caller`-ness into a breaking change,
163            // which is probably undesirable.
164            //
165            // Also note that we don't check callee's `track_caller`-ness at all, mostly for the
166            // reasons above, but also because we can always tailcall the shim we'd generate for
167            // coercing the function to an `fn()` pointer. (although in that case the tailcall is
168            // basically useless -- the shim calls the actual function, so tailcalling the shim is
169            // equivalent to calling the function)
170            let caller_needs_location = self.needs_location(self.caller_ty);
171
172            if caller_needs_location {
173                self.report_track_caller_caller(expr.span);
174            }
175        }
176
177        if caller_sig.c_variadic {
178            self.report_c_variadic_caller(expr.span);
179        }
180
181        if callee_sig.c_variadic {
182            self.report_c_variadic_callee(expr.span);
183        }
184    }
185
186    /// Returns true if function of type `ty` needs location argument
187    /// (i.e. if a function is marked as `#[track_caller]`).
188    ///
189    /// Panics if the function's instance can't be immediately resolved.
190    fn needs_location(&self, ty: Ty<'tcx>) -> bool {
191        if let &ty::FnDef(did, substs) = ty.kind() {
192            let instance =
193                ty::Instance::expect_resolve(self.tcx, self.typing_env, did, substs, DUMMY_SP);
194
195            instance.def.requires_caller_location(self.tcx)
196        } else {
197            false
198        }
199    }
200
201    fn report_in_closure(&mut self, expr: &Expr<'_>) {
202        let err = self.tcx.dcx().span_err(expr.span, "`become` is not allowed in closures");
203        self.found_errors = Err(err);
204    }
205
206    fn report_builtin_op(&mut self, value: &Expr<'_>, expr: &Expr<'_>) {
207        let err = self
208            .tcx
209            .dcx()
210            .struct_span_err(value.span, "`become` does not support operators")
211            .with_note("using `become` on a builtin operator is not useful")
212            .with_span_suggestion(
213                value.span.until(expr.span),
214                "try using `return` instead",
215                "return ",
216                Applicability::MachineApplicable,
217            )
218            .emit();
219        self.found_errors = Err(err);
220    }
221
222    fn report_op(&mut self, fun_ty: Ty<'_>, args: &[ExprId], fn_span: Span, expr: &Expr<'_>) {
223        let mut err =
224            self.tcx.dcx().struct_span_err(fn_span, "`become` does not support operators");
225
226        if let &ty::FnDef(did, _substs) = fun_ty.kind()
227            && let parent = self.tcx.parent(did)
228            && matches!(self.tcx.def_kind(parent), DefKind::Trait)
229            && let Some(method) = op_trait_as_method_name(self.tcx, parent)
230        {
231            match args {
232                &[arg] => {
233                    let arg = &self.thir[arg];
234
235                    err.multipart_suggestion(
236                        "try using the method directly",
237                        vec![
238                            (fn_span.shrink_to_lo().until(arg.span), "(".to_owned()),
239                            (arg.span.shrink_to_hi(), format!(").{method}()")),
240                        ],
241                        Applicability::MaybeIncorrect,
242                    );
243                }
244                &[lhs, rhs] => {
245                    let lhs = &self.thir[lhs];
246                    let rhs = &self.thir[rhs];
247
248                    err.multipart_suggestion(
249                        "try using the method directly",
250                        vec![
251                            (lhs.span.shrink_to_lo(), format!("(")),
252                            (lhs.span.between(rhs.span), format!(").{method}(")),
253                            (rhs.span.between(expr.span.shrink_to_hi()), ")".to_owned()),
254                        ],
255                        Applicability::MaybeIncorrect,
256                    );
257                }
258                _ => span_bug!(expr.span, "operator with more than 2 args? {args:?}"),
259            }
260        }
261
262        self.found_errors = Err(err.emit());
263    }
264
265    fn report_non_call(&mut self, value: &Expr<'_>, expr: &Expr<'_>) {
266        let err = self
267            .tcx
268            .dcx()
269            .struct_span_err(value.span, "`become` requires a function call")
270            .with_span_note(value.span, "not a function call")
271            .with_span_suggestion(
272                value.span.until(expr.span),
273                "try using `return` instead",
274                "return ",
275                Applicability::MaybeIncorrect,
276            )
277            .emit();
278        self.found_errors = Err(err);
279    }
280
281    fn report_calling_closure(&mut self, fun: &Expr<'_>, tupled_args: Ty<'_>, expr: &Expr<'_>) {
282        let underscored_args = match tupled_args.kind() {
283            ty::Tuple(tys) if tys.is_empty() => "".to_owned(),
284            ty::Tuple(tys) => std::iter::repeat("_, ").take(tys.len() - 1).chain(["_"]).collect(),
285            _ => "_".to_owned(),
286        };
287
288        let err = self
289            .tcx
290            .dcx()
291            .struct_span_err(expr.span, "tail calling closures directly is not allowed")
292            .with_multipart_suggestion(
293                "try casting the closure to a function pointer type",
294                vec![
295                    (fun.span.shrink_to_lo(), "(".to_owned()),
296                    (fun.span.shrink_to_hi(), format!(" as fn({underscored_args}) -> _)")),
297                ],
298                Applicability::MaybeIncorrect,
299            )
300            .emit();
301        self.found_errors = Err(err);
302    }
303
304    fn report_calling_intrinsic(&mut self, expr: &Expr<'_>) {
305        let err = self
306            .tcx
307            .dcx()
308            .struct_span_err(expr.span, "tail calling intrinsics is not allowed")
309            .emit();
310
311        self.found_errors = Err(err);
312    }
313
314    fn report_nonfn_callee(&mut self, call_sp: Span, fun_sp: Span, ty: Ty<'_>) {
315        let mut err = self
316            .tcx
317            .dcx()
318            .struct_span_err(
319                call_sp,
320                "tail calls can only be performed with function definitions or pointers",
321            )
322            .with_note(format!("callee has type `{ty}`"));
323
324        let mut ty = ty;
325        let mut refs = 0;
326        while ty.is_box() || ty.is_ref() {
327            ty = ty.builtin_deref(false).unwrap();
328            refs += 1;
329        }
330
331        if refs > 0 && ty.is_fn() {
332            let thing = if ty.is_fn_ptr() { "pointer" } else { "definition" };
333
334            let derefs =
335                std::iter::once('(').chain(std::iter::repeat_n('*', refs)).collect::<String>();
336
337            err.multipart_suggestion(
338                format!("consider dereferencing the expression to get a function {thing}"),
339                vec![(fun_sp.shrink_to_lo(), derefs), (fun_sp.shrink_to_hi(), ")".to_owned())],
340                Applicability::MachineApplicable,
341            );
342        }
343
344        let err = err.emit();
345        self.found_errors = Err(err);
346    }
347
348    fn report_abi_mismatch(&mut self, sp: Span, caller_abi: ExternAbi, callee_abi: ExternAbi) {
349        let err = self
350            .tcx
351            .dcx()
352            .struct_span_err(sp, "mismatched function ABIs")
353            .with_note("`become` requires caller and callee to have the same ABI")
354            .with_note(format!("caller ABI is `{caller_abi}`, while callee ABI is `{callee_abi}`"))
355            .emit();
356        self.found_errors = Err(err);
357    }
358
359    fn report_arguments_mismatch(
360        &mut self,
361        sp: Span,
362        caller_sig: ty::FnSig<'_>,
363        callee_sig: ty::FnSig<'_>,
364    ) {
365        let err = self
366            .tcx
367            .dcx()
368            .struct_span_err(sp, "mismatched signatures")
369            .with_note("`become` requires caller and callee to have matching signatures")
370            .with_note(format!("caller signature: `{caller_sig}`"))
371            .with_note(format!("callee signature: `{callee_sig}`"))
372            .emit();
373        self.found_errors = Err(err);
374    }
375
376    fn report_track_caller_caller(&mut self, sp: Span) {
377        let err = self
378            .tcx
379            .dcx()
380            .struct_span_err(
381                sp,
382                "a function marked with `#[track_caller]` cannot perform a tail-call",
383            )
384            .emit();
385
386        self.found_errors = Err(err);
387    }
388
389    fn report_c_variadic_caller(&mut self, sp: Span) {
390        let err = self
391            .tcx
392            .dcx()
393            // FIXME(explicit_tail_calls): highlight the `...`
394            .struct_span_err(sp, "tail-calls are not allowed in c-variadic functions")
395            .emit();
396
397        self.found_errors = Err(err);
398    }
399
400    fn report_c_variadic_callee(&mut self, sp: Span) {
401        let err = self
402            .tcx
403            .dcx()
404            // FIXME(explicit_tail_calls): highlight the function or something...
405            .struct_span_err(sp, "c-variadic functions can't be tail-called")
406            .emit();
407
408        self.found_errors = Err(err);
409    }
410}
411
412impl<'a, 'tcx> Visitor<'a, 'tcx> for TailCallCkVisitor<'a, 'tcx> {
413    fn thir(&self) -> &'a Thir<'tcx> {
414        &self.thir
415    }
416
417    fn visit_expr(&mut self, expr: &'a Expr<'tcx>) {
418        ensure_sufficient_stack(|| {
419            if let ExprKind::Become { value } = expr.kind {
420                let call = &self.thir[value];
421                self.check_tail_call(call, expr);
422            }
423
424            visit::walk_expr(self, expr);
425        });
426    }
427}
428
429fn op_trait_as_method_name(tcx: TyCtxt<'_>, trait_did: DefId) -> Option<&'static str> {
430    let m = match tcx.as_lang_item(trait_did)? {
431        LangItem::Add => "add",
432        LangItem::Sub => "sub",
433        LangItem::Mul => "mul",
434        LangItem::Div => "div",
435        LangItem::Rem => "rem",
436        LangItem::Neg => "neg",
437        LangItem::Not => "not",
438        LangItem::BitXor => "bitxor",
439        LangItem::BitAnd => "bitand",
440        LangItem::BitOr => "bitor",
441        LangItem::Shl => "shl",
442        LangItem::Shr => "shr",
443        LangItem::AddAssign => "add_assign",
444        LangItem::SubAssign => "sub_assign",
445        LangItem::MulAssign => "mul_assign",
446        LangItem::DivAssign => "div_assign",
447        LangItem::RemAssign => "rem_assign",
448        LangItem::BitXorAssign => "bitxor_assign",
449        LangItem::BitAndAssign => "bitand_assign",
450        LangItem::BitOrAssign => "bitor_assign",
451        LangItem::ShlAssign => "shl_assign",
452        LangItem::ShrAssign => "shr_assign",
453        LangItem::Index => "index",
454        LangItem::IndexMut => "index_mut",
455        _ => return None,
456    };
457
458    Some(m)
459}