1use rustc_abi::ExternAbi;
2use rustc_data_structures::stack::ensure_sufficient_stack;
3use rustc_errors::Applicability;
4use rustc_hir::LangItem;
5use rustc_hir::def::DefKind;
6use rustc_middle::span_bug;
7use rustc_middle::thir::visit::{self, Visitor};
8use rustc_middle::thir::{BodyTy, Expr, ExprId, ExprKind, Thir};
9use rustc_middle::ty::{self, Ty, TyCtxt};
10use rustc_span::def_id::{DefId, LocalDefId};
11use rustc_span::{DUMMY_SP, ErrorGuaranteed, Span};
12
13pub(crate) fn check_tail_calls(tcx: TyCtxt<'_>, def: LocalDefId) -> Result<(), ErrorGuaranteed> {
14 let (thir, expr) = tcx.thir_body(def)?;
15 let thir = &thir.borrow();
16
17 if thir.exprs.is_empty() {
19 return Ok(());
20 }
21
22 let is_closure = matches!(tcx.def_kind(def), DefKind::Closure);
23 let caller_ty = tcx.type_of(def).skip_binder();
24
25 let mut visitor = TailCallCkVisitor {
26 tcx,
27 thir,
28 found_errors: Ok(()),
29 typing_env: ty::TypingEnv::non_body_analysis(tcx, def),
31 is_closure,
32 caller_ty,
33 };
34
35 visitor.visit_expr(&thir[expr]);
36
37 visitor.found_errors
38}
39
40struct TailCallCkVisitor<'a, 'tcx> {
41 tcx: TyCtxt<'tcx>,
42 thir: &'a Thir<'tcx>,
43 typing_env: ty::TypingEnv<'tcx>,
44 is_closure: bool,
46 found_errors: Result<(), ErrorGuaranteed>,
49 caller_ty: Ty<'tcx>,
51}
52
53impl<'tcx> TailCallCkVisitor<'_, 'tcx> {
54 fn check_tail_call(&mut self, call: &Expr<'_>, expr: &Expr<'_>) {
55 if self.is_closure {
56 self.report_in_closure(expr);
57 return;
58 }
59
60 let BodyTy::Fn(caller_sig) = self.thir.body_type else {
61 span_bug!(
62 call.span,
63 "`become` outside of functions should have been disallowed by hir_typeck"
64 )
65 };
66 let caller_sig = self.tcx.erase_regions(caller_sig);
70
71 let ExprKind::Scope { value, .. } = call.kind else {
72 span_bug!(call.span, "expected scope, found: {call:?}")
73 };
74 let value = &self.thir[value];
75
76 if matches!(
77 value.kind,
78 ExprKind::Binary { .. }
79 | ExprKind::Unary { .. }
80 | ExprKind::AssignOp { .. }
81 | ExprKind::Index { .. }
82 ) {
83 self.report_builtin_op(call, expr);
84 return;
85 }
86
87 let ExprKind::Call { ty, fun, ref args, from_hir_call, fn_span } = value.kind else {
88 self.report_non_call(value, expr);
89 return;
90 };
91
92 if !from_hir_call {
93 self.report_op(ty, args, fn_span, expr);
94 }
95
96 if let &ty::FnDef(did, args) = ty.kind() {
97 let parent = self.tcx.parent(did);
101 if self.tcx.fn_trait_kind_from_def_id(parent).is_some()
102 && let Some(this) = args.first()
103 && let Some(this) = this.as_type()
104 {
105 if this.is_closure() {
106 self.report_calling_closure(&self.thir[fun], args[1].as_type().unwrap(), expr);
107 } else {
108 self.report_nonfn_callee(fn_span, self.thir[fun].span, this);
110 }
111
112 return;
115 };
116
117 if self.tcx.intrinsic(did).is_some() {
118 self.report_calling_intrinsic(expr);
119 }
120 }
121
122 let (ty::FnDef(..) | ty::FnPtr(..)) = ty.kind() else {
123 self.report_nonfn_callee(fn_span, self.thir[fun].span, ty);
124
125 return;
127 };
128
129 let callee_sig =
131 self.tcx.normalize_erasing_late_bound_regions(self.typing_env, ty.fn_sig(self.tcx));
132
133 if caller_sig.abi != callee_sig.abi {
134 self.report_abi_mismatch(expr.span, caller_sig.abi, callee_sig.abi);
135 }
136
137 if caller_sig.inputs_and_output != callee_sig.inputs_and_output {
138 if caller_sig.inputs() != callee_sig.inputs() {
139 self.report_arguments_mismatch(expr.span, caller_sig, callee_sig);
140 }
141
142 if caller_sig.output() != callee_sig.output() {
151 span_bug!(expr.span, "hir typeck should have checked the return type already");
152 }
153 }
154
155 {
156 let caller_needs_location = self.needs_location(self.caller_ty);
171
172 if caller_needs_location {
173 self.report_track_caller_caller(expr.span);
174 }
175 }
176
177 if caller_sig.c_variadic {
178 self.report_c_variadic_caller(expr.span);
179 }
180
181 if callee_sig.c_variadic {
182 self.report_c_variadic_callee(expr.span);
183 }
184 }
185
186 fn needs_location(&self, ty: Ty<'tcx>) -> bool {
191 if let &ty::FnDef(did, substs) = ty.kind() {
192 let instance =
193 ty::Instance::expect_resolve(self.tcx, self.typing_env, did, substs, DUMMY_SP);
194
195 instance.def.requires_caller_location(self.tcx)
196 } else {
197 false
198 }
199 }
200
201 fn report_in_closure(&mut self, expr: &Expr<'_>) {
202 let err = self.tcx.dcx().span_err(expr.span, "`become` is not allowed in closures");
203 self.found_errors = Err(err);
204 }
205
206 fn report_builtin_op(&mut self, value: &Expr<'_>, expr: &Expr<'_>) {
207 let err = self
208 .tcx
209 .dcx()
210 .struct_span_err(value.span, "`become` does not support operators")
211 .with_note("using `become` on a builtin operator is not useful")
212 .with_span_suggestion(
213 value.span.until(expr.span),
214 "try using `return` instead",
215 "return ",
216 Applicability::MachineApplicable,
217 )
218 .emit();
219 self.found_errors = Err(err);
220 }
221
222 fn report_op(&mut self, fun_ty: Ty<'_>, args: &[ExprId], fn_span: Span, expr: &Expr<'_>) {
223 let mut err =
224 self.tcx.dcx().struct_span_err(fn_span, "`become` does not support operators");
225
226 if let &ty::FnDef(did, _substs) = fun_ty.kind()
227 && let parent = self.tcx.parent(did)
228 && matches!(self.tcx.def_kind(parent), DefKind::Trait)
229 && let Some(method) = op_trait_as_method_name(self.tcx, parent)
230 {
231 match args {
232 &[arg] => {
233 let arg = &self.thir[arg];
234
235 err.multipart_suggestion(
236 "try using the method directly",
237 vec![
238 (fn_span.shrink_to_lo().until(arg.span), "(".to_owned()),
239 (arg.span.shrink_to_hi(), format!(").{method}()")),
240 ],
241 Applicability::MaybeIncorrect,
242 );
243 }
244 &[lhs, rhs] => {
245 let lhs = &self.thir[lhs];
246 let rhs = &self.thir[rhs];
247
248 err.multipart_suggestion(
249 "try using the method directly",
250 vec![
251 (lhs.span.shrink_to_lo(), format!("(")),
252 (lhs.span.between(rhs.span), format!(").{method}(")),
253 (rhs.span.between(expr.span.shrink_to_hi()), ")".to_owned()),
254 ],
255 Applicability::MaybeIncorrect,
256 );
257 }
258 _ => span_bug!(expr.span, "operator with more than 2 args? {args:?}"),
259 }
260 }
261
262 self.found_errors = Err(err.emit());
263 }
264
265 fn report_non_call(&mut self, value: &Expr<'_>, expr: &Expr<'_>) {
266 let err = self
267 .tcx
268 .dcx()
269 .struct_span_err(value.span, "`become` requires a function call")
270 .with_span_note(value.span, "not a function call")
271 .with_span_suggestion(
272 value.span.until(expr.span),
273 "try using `return` instead",
274 "return ",
275 Applicability::MaybeIncorrect,
276 )
277 .emit();
278 self.found_errors = Err(err);
279 }
280
281 fn report_calling_closure(&mut self, fun: &Expr<'_>, tupled_args: Ty<'_>, expr: &Expr<'_>) {
282 let underscored_args = match tupled_args.kind() {
283 ty::Tuple(tys) if tys.is_empty() => "".to_owned(),
284 ty::Tuple(tys) => std::iter::repeat("_, ").take(tys.len() - 1).chain(["_"]).collect(),
285 _ => "_".to_owned(),
286 };
287
288 let err = self
289 .tcx
290 .dcx()
291 .struct_span_err(expr.span, "tail calling closures directly is not allowed")
292 .with_multipart_suggestion(
293 "try casting the closure to a function pointer type",
294 vec![
295 (fun.span.shrink_to_lo(), "(".to_owned()),
296 (fun.span.shrink_to_hi(), format!(" as fn({underscored_args}) -> _)")),
297 ],
298 Applicability::MaybeIncorrect,
299 )
300 .emit();
301 self.found_errors = Err(err);
302 }
303
304 fn report_calling_intrinsic(&mut self, expr: &Expr<'_>) {
305 let err = self
306 .tcx
307 .dcx()
308 .struct_span_err(expr.span, "tail calling intrinsics is not allowed")
309 .emit();
310
311 self.found_errors = Err(err);
312 }
313
314 fn report_nonfn_callee(&mut self, call_sp: Span, fun_sp: Span, ty: Ty<'_>) {
315 let mut err = self
316 .tcx
317 .dcx()
318 .struct_span_err(
319 call_sp,
320 "tail calls can only be performed with function definitions or pointers",
321 )
322 .with_note(format!("callee has type `{ty}`"));
323
324 let mut ty = ty;
325 let mut refs = 0;
326 while ty.is_box() || ty.is_ref() {
327 ty = ty.builtin_deref(false).unwrap();
328 refs += 1;
329 }
330
331 if refs > 0 && ty.is_fn() {
332 let thing = if ty.is_fn_ptr() { "pointer" } else { "definition" };
333
334 let derefs =
335 std::iter::once('(').chain(std::iter::repeat_n('*', refs)).collect::<String>();
336
337 err.multipart_suggestion(
338 format!("consider dereferencing the expression to get a function {thing}"),
339 vec![(fun_sp.shrink_to_lo(), derefs), (fun_sp.shrink_to_hi(), ")".to_owned())],
340 Applicability::MachineApplicable,
341 );
342 }
343
344 let err = err.emit();
345 self.found_errors = Err(err);
346 }
347
348 fn report_abi_mismatch(&mut self, sp: Span, caller_abi: ExternAbi, callee_abi: ExternAbi) {
349 let err = self
350 .tcx
351 .dcx()
352 .struct_span_err(sp, "mismatched function ABIs")
353 .with_note("`become` requires caller and callee to have the same ABI")
354 .with_note(format!("caller ABI is `{caller_abi}`, while callee ABI is `{callee_abi}`"))
355 .emit();
356 self.found_errors = Err(err);
357 }
358
359 fn report_arguments_mismatch(
360 &mut self,
361 sp: Span,
362 caller_sig: ty::FnSig<'_>,
363 callee_sig: ty::FnSig<'_>,
364 ) {
365 let err = self
366 .tcx
367 .dcx()
368 .struct_span_err(sp, "mismatched signatures")
369 .with_note("`become` requires caller and callee to have matching signatures")
370 .with_note(format!("caller signature: `{caller_sig}`"))
371 .with_note(format!("callee signature: `{callee_sig}`"))
372 .emit();
373 self.found_errors = Err(err);
374 }
375
376 fn report_track_caller_caller(&mut self, sp: Span) {
377 let err = self
378 .tcx
379 .dcx()
380 .struct_span_err(
381 sp,
382 "a function marked with `#[track_caller]` cannot perform a tail-call",
383 )
384 .emit();
385
386 self.found_errors = Err(err);
387 }
388
389 fn report_c_variadic_caller(&mut self, sp: Span) {
390 let err = self
391 .tcx
392 .dcx()
393 .struct_span_err(sp, "tail-calls are not allowed in c-variadic functions")
395 .emit();
396
397 self.found_errors = Err(err);
398 }
399
400 fn report_c_variadic_callee(&mut self, sp: Span) {
401 let err = self
402 .tcx
403 .dcx()
404 .struct_span_err(sp, "c-variadic functions can't be tail-called")
406 .emit();
407
408 self.found_errors = Err(err);
409 }
410}
411
412impl<'a, 'tcx> Visitor<'a, 'tcx> for TailCallCkVisitor<'a, 'tcx> {
413 fn thir(&self) -> &'a Thir<'tcx> {
414 &self.thir
415 }
416
417 fn visit_expr(&mut self, expr: &'a Expr<'tcx>) {
418 ensure_sufficient_stack(|| {
419 if let ExprKind::Become { value } = expr.kind {
420 let call = &self.thir[value];
421 self.check_tail_call(call, expr);
422 }
423
424 visit::walk_expr(self, expr);
425 });
426 }
427}
428
429fn op_trait_as_method_name(tcx: TyCtxt<'_>, trait_did: DefId) -> Option<&'static str> {
430 let m = match tcx.as_lang_item(trait_did)? {
431 LangItem::Add => "add",
432 LangItem::Sub => "sub",
433 LangItem::Mul => "mul",
434 LangItem::Div => "div",
435 LangItem::Rem => "rem",
436 LangItem::Neg => "neg",
437 LangItem::Not => "not",
438 LangItem::BitXor => "bitxor",
439 LangItem::BitAnd => "bitand",
440 LangItem::BitOr => "bitor",
441 LangItem::Shl => "shl",
442 LangItem::Shr => "shr",
443 LangItem::AddAssign => "add_assign",
444 LangItem::SubAssign => "sub_assign",
445 LangItem::MulAssign => "mul_assign",
446 LangItem::DivAssign => "div_assign",
447 LangItem::RemAssign => "rem_assign",
448 LangItem::BitXorAssign => "bitxor_assign",
449 LangItem::BitAndAssign => "bitand_assign",
450 LangItem::BitOrAssign => "bitor_assign",
451 LangItem::ShlAssign => "shl_assign",
452 LangItem::ShrAssign => "shr_assign",
453 LangItem::Index => "index",
454 LangItem::IndexMut => "index_mut",
455 _ => return None,
456 };
457
458 Some(m)
459}