use either::Right;
use rustc_const_eval::const_eval::CheckAlignment;
use rustc_const_eval::ReportErrorExt;
use rustc_data_structures::fx::FxHashSet;
use rustc_hir::def::DefKind;
use rustc_index::bit_set::BitSet;
use rustc_index::{IndexSlice, IndexVec};
use rustc_middle::mir::visit::{
MutVisitor, MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor,
};
use rustc_middle::mir::*;
use rustc_middle::ty::layout::{LayoutError, LayoutOf, LayoutOfHelpers, TyAndLayout};
use rustc_middle::ty::{self, GenericArgs, Instance, ParamEnv, Ty, TyCtxt, TypeVisitableExt};
use rustc_span::{def_id::DefId, Span};
use rustc_target::abi::{self, Align, HasDataLayout, Size, TargetDataLayout};
use rustc_target::spec::abi::Abi as CallAbi;
use crate::dataflow_const_prop::Patch;
use crate::MirPass;
use rustc_const_eval::interpret::{
self, compile_time_machine, AllocId, ConstAllocation, FnArg, Frame, ImmTy, Immediate, InterpCx,
InterpResult, MemoryKind, OpTy, PlaceTy, Pointer, Scalar, StackPopCleanup,
};
/// Size threshold, in bytes, used throughout this pass.
///
/// The return place only gets its real layout when that layout is smaller than this limit
/// (see `ConstPropagator::new`), and locals whose layout is not smaller than this limit are
/// marked `ConstPropMode::NoPropagation` (see `CanConstProp::check`).
const MAX_ALLOC_LIMIT: u64 = 1024;
/// Throws a machine-stop interpreter error whose message is built from the given format
/// arguments.
///
/// Every expansion declares a fresh local unit struct `Zst`, gives it a `Display` impl that
/// renders the message, implements `MachineStopType` for it, and passes it to
/// `throw_machine_stop!`.
pub(crate) macro throw_machine_stop_str($($tt:tt)*) {{
    // The struct itself stores nothing; the message lives entirely in its `Display` impl.
    #[derive(Debug)]
    struct Zst;
    // Formatting the value yields the message described by the macro arguments.
    impl std::fmt::Display for Zst {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, $($tt)*)
        }
    }
    impl rustc_middle::mir::interpret::MachineStopType for Zst {
        // The diagnostic message is simply the `Display` output.
        fn diagnostic_message(&self) -> rustc_errors::DiagnosticMessage {
            self.to_string().into()
        }
        // No diagnostic arguments are contributed.
        fn add_args(
            self: Box<Self>,
            _: &mut dyn FnMut(std::borrow::Cow<'static, str>, rustc_errors::DiagnosticArgValue<'static>),
        ) {}
    }
    throw_machine_stop!(Zst)
}}
/// MIR pass that interprets a body with `ConstPropagator` and patches the constants it finds
/// back into that body.
pub struct ConstProp;
impl<'tcx> MirPass<'tcx> for ConstProp {
    /// The pass is only enabled at MIR optimization level 2 and above.
    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
        sess.mir_opt_level() >= 2
    }

    /// Runs constant propagation over `body` and applies the resulting patch in place.
    ///
    /// Promoted bodies, items that are neither fn-like nor associated constants, and
    /// generators are left untouched.
    #[instrument(skip(self, tcx), level = "debug")]
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        // Promoteds are not processed by this pass.
        if body.source.promoted.is_some() {
            return;
        }

        let def_id = body.source.def_id().expect_local();
        let def_kind = tcx.def_kind(def_id);

        // Only fn-like items and associated constants are candidates.
        if !(def_kind.is_fn_like() || def_kind == DefKind::AssocConst) {
            trace!("ConstProp skipped for {:?}", def_id);
            return;
        }

        // Generators are fn-like but are explicitly excluded.
        if matches!(def_kind, DefKind::Generator) {
            trace!("ConstProp skipped for generator {:?}", def_id);
            return;
        }

        trace!("ConstProp starting for {:?}", def_id);

        // Walk the blocks in reverse post-order, collecting replacements into the patch.
        let mut propagator = ConstPropagator::new(body, tcx);
        for &bb in body.basic_blocks.reverse_postorder() {
            propagator.visit_basic_block_data(bb, &body.basic_blocks[bb]);
        }

        // Move the patch out of the propagator and apply it to the body.
        let mut patch = propagator.patch;
        patch.visit_body_preserves_cfg(body);

        trace!("ConstProp done for {:?}", def_id);
    }
}
/// Interpreter machine used by the constant-propagation pass.
pub struct ConstPropMachine<'mir, 'tcx> {
    /// The interpreter's call stack, exposed through `stack` / `stack_mut`.
    stack: Vec<Frame<'mir, 'tcx>>,
    /// `OnlyInsideOwnBlock` locals that were accessed mutably since the set was last drained;
    /// filled in `before_access_local_mut`.
    pub written_only_inside_own_block_locals: FxHashSet<Local>,
    /// Propagation mode of every local of the body being interpreted.
    pub can_const_prop: IndexVec<Local, ConstPropMode>,
}
impl ConstPropMachine<'_, '_> {
pub fn new(can_const_prop: IndexVec<Local, ConstPropMode>) -> Self {
Self {
stack: Vec::new(),
written_only_inside_own_block_locals: Default::default(),
can_const_prop,
}
}
}
impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx> {
    // NOTE(review): presumably supplies the machine items shared with compile-time
    // evaluation; confirm against the macro definition in `rustc_const_eval`.
    compile_time_machine!(<'mir, 'tcx>);
    // Allocation failure panics, post-monomorphization checks are off, and there are no
    // machine-specific memory kinds (`!` is uninhabited).
    const PANIC_ON_ALLOC_FAIL: bool = true; const POST_MONO_CHECKS: bool = false; type MemoryKind = !;
    /// Alignment is never checked by this machine.
    #[inline(always)]
    fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> CheckAlignment {
        CheckAlignment::No
    }
    /// Validity is never enforced by this machine.
    #[inline(always)]
    fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool {
        false }
    /// Reports a compiler bug if called: `enforce_alignment` never requests a check.
    fn alignment_check_failed(
        ecx: &InterpCx<'mir, 'tcx, Self>,
        _has: Align,
        _required: Align,
        _check: CheckAlignment,
    ) -> InterpResult<'tcx, ()> {
        span_bug!(
            ecx.cur_span(),
            "`alignment_check_failed` called when no alignment check requested"
        )
    }
    /// Loading MIR for a callee always stops the interpreter.
    fn load_mir(
        _ecx: &InterpCx<'mir, 'tcx, Self>,
        _instance: ty::InstanceDef<'tcx>,
    ) -> InterpResult<'tcx, &'tcx Body<'tcx>> {
        throw_machine_stop_str!("calling functions isn't supported in ConstProp")
    }
    /// Panicking always stops the interpreter.
    fn panic_nounwind(_ecx: &mut InterpCx<'mir, 'tcx, Self>, _msg: &str) -> InterpResult<'tcx> {
        throw_machine_stop_str!("panicking isn't supported in ConstProp")
    }
    /// Always returns `Ok(None)`: no callee body is ever handed back to the interpreter.
    fn find_mir_or_eval_fn(
        _ecx: &mut InterpCx<'mir, 'tcx, Self>,
        _instance: ty::Instance<'tcx>,
        _abi: CallAbi,
        _args: &[FnArg<'tcx>],
        _destination: &PlaceTy<'tcx>,
        _target: Option<BasicBlock>,
        _unwind: UnwindAction,
    ) -> InterpResult<'tcx, Option<(&'mir Body<'tcx>, ty::Instance<'tcx>)>> {
        Ok(None)
    }
    /// Intrinsic calls always stop the interpreter.
    fn call_intrinsic(
        _ecx: &mut InterpCx<'mir, 'tcx, Self>,
        _instance: ty::Instance<'tcx>,
        _args: &[OpTy<'tcx>],
        _destination: &PlaceTy<'tcx>,
        _target: Option<BasicBlock>,
        _unwind: UnwindAction,
    ) -> InterpResult<'tcx> {
        throw_machine_stop_str!("calling intrinsics isn't supported in ConstProp")
    }
    /// Reports a compiler bug if called: this pass is not expected to evaluate assert
    /// terminators.
    fn assert_panic(
        _ecx: &mut InterpCx<'mir, 'tcx, Self>,
        _msg: &rustc_middle::mir::AssertMessage<'tcx>,
        _unwind: rustc_middle::mir::UnwindAction,
    ) -> InterpResult<'tcx> {
        bug!("panics terminators are not evaluated in ConstProp")
    }
    /// Binary operations on pointers always stop the interpreter.
    fn binary_ptr_op(
        _ecx: &InterpCx<'mir, 'tcx, Self>,
        _bin_op: BinOp,
        _left: &ImmTy<'tcx>,
        _right: &ImmTy<'tcx>,
    ) -> InterpResult<'tcx, (ImmTy<'tcx>, bool)> {
        throw_machine_stop_str!("pointer arithmetic or comparisons aren't supported in ConstProp")
    }
    /// Gate for mutable access to a local, keyed on the local's propagation mode.
    ///
    /// `NoPropagation` locals stop the interpreter, `OnlyInsideOwnBlock` locals are recorded
    /// in `written_only_inside_own_block_locals`, and `FullConstProp` locals pass through.
    fn before_access_local_mut<'a>(
        ecx: &'a mut InterpCx<'mir, 'tcx, Self>,
        frame: usize,
        local: Local,
    ) -> InterpResult<'tcx> {
        // Only the bottom frame (index 0) is expected to be accessed.
        assert_eq!(frame, 0);
        match ecx.machine.can_const_prop[local] {
            ConstPropMode::NoPropagation => {
                throw_machine_stop_str!(
                    "tried to write to a local that is marked as not propagatable"
                )
            }
            ConstPropMode::OnlyInsideOwnBlock => {
                ecx.machine.written_only_inside_own_block_locals.insert(local);
            }
            ConstPropMode::FullConstProp => {}
        }
        Ok(())
    }
    /// Gate for access to global allocations: writes, and any access to an allocation that
    /// is mutable, stop the interpreter.
    fn before_access_global(
        _tcx: TyCtxt<'tcx>,
        _machine: &Self,
        _alloc_id: AllocId,
        alloc: ConstAllocation<'tcx>,
        _static_def_id: Option<DefId>,
        is_write: bool,
    ) -> InterpResult<'tcx> {
        if is_write {
            throw_machine_stop_str!("can't write to global");
        }
        if alloc.inner().mutability.is_mut() {
            throw_machine_stop_str!("can't access mutable globals in ConstProp");
        }
        Ok(())
    }
    /// Exposing a pointer always stops the interpreter.
    #[inline(always)]
    fn expose_ptr(
        _ecx: &mut InterpCx<'mir, 'tcx, Self>,
        _ptr: Pointer<AllocId>,
    ) -> InterpResult<'tcx> {
        throw_machine_stop_str!("exposing pointers isn't supported in ConstProp")
    }
    /// Frames carry no extra data; the frame is returned unchanged.
    #[inline(always)]
    fn init_frame_extra(
        _ecx: &mut InterpCx<'mir, 'tcx, Self>,
        frame: Frame<'mir, 'tcx>,
    ) -> InterpResult<'tcx, Frame<'mir, 'tcx>> {
        Ok(frame)
    }
    /// Read-only view of the call stack stored in the machine.
    #[inline(always)]
    fn stack<'a>(
        ecx: &'a InterpCx<'mir, 'tcx, Self>,
    ) -> &'a [Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>] {
        &ecx.machine.stack
    }
    /// Mutable access to the call stack stored in the machine.
    #[inline(always)]
    fn stack_mut<'a>(
        ecx: &'a mut InterpCx<'mir, 'tcx, Self>,
    ) -> &'a mut Vec<Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>> {
        &mut ecx.machine.stack
    }
}
/// Walks one MIR body with an interpreter and records the constants it discovers.
struct ConstPropagator<'mir, 'tcx> {
    /// Interpreter whose single frame holds the currently known value of each local.
    ecx: InterpCx<'mir, 'tcx, ConstPropMachine<'mir, 'tcx>>,
    tcx: TyCtxt<'tcx>,
    /// Parameter environment of the body being processed.
    param_env: ParamEnv<'tcx>,
    /// Local declarations of the body being processed.
    local_decls: &'mir IndexSlice<Local, LocalDecl<'tcx>>,
    /// Replacements found so far; the pass applies them to the body after the traversal.
    patch: Patch<'tcx>,
}
/// Layout queries made through the propagator return a plain `Result`.
impl<'tcx> LayoutOfHelpers<'tcx> for ConstPropagator<'_, 'tcx> {
    type LayoutOfResult = Result<TyAndLayout<'tcx>, LayoutError<'tcx>>;
    /// Layout errors are passed through unchanged; the span and type are ignored.
    #[inline]
    fn handle_layout_err(&self, err: LayoutError<'tcx>, _: Span, _: Ty<'tcx>) -> LayoutError<'tcx> {
        err
    }
}
impl HasDataLayout for ConstPropagator<'_, '_> {
    /// Returns the data layout stored in the type context.
    #[inline]
    fn data_layout(&self) -> &TargetDataLayout {
        &self.tcx.data_layout
    }
}
impl<'tcx> ty::layout::HasTyCtxt<'tcx> for ConstPropagator<'_, 'tcx> {
    /// Returns the type context the propagator was created with.
    #[inline]
    fn tcx(&self) -> TyCtxt<'tcx> {
        self.tcx
    }
}
impl<'tcx> ty::layout::HasParamEnv<'tcx> for ConstPropagator<'_, 'tcx> {
    /// Returns the parameter environment of the body being processed.
    #[inline]
    fn param_env(&self) -> ty::ParamEnv<'tcx> {
        self.param_env
    }
}
impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
/// Builds a propagator for `body`.
///
/// Computes the per-local propagation modes, creates an interpreter whose only stack frame
/// is `body` itself, and marks every local of that frame as live but uninitialized.
fn new(body: &'mir Body<'tcx>, tcx: TyCtxt<'tcx>) -> ConstPropagator<'mir, 'tcx> {
    let def_id = body.source.def_id();
    let args = &GenericArgs::identity_for_item(tcx, def_id);
    let param_env = tcx.param_env_reveal_all_normalized(def_id);

    let machine = ConstPropMachine::new(CanConstProp::check(tcx, param_env, body));
    let mut ecx = InterpCx::new(tcx, tcx.def_span(def_id), param_env, machine);

    // Use the real return layout only when it can be computed, is sized, and is smaller than
    // `MAX_ALLOC_LIMIT` bytes; in every other case fall back to the layout of `()`.
    let ret_layout = match ecx.layout_of(body.bound_return_ty().instantiate(tcx, args)) {
        Ok(layout) if layout.is_sized() && layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) => {
            layout
        }
        _ => ecx.layout_of(tcx.types.unit).unwrap(),
    };

    let ret_place = ecx
        .allocate(ret_layout, MemoryKind::Stack)
        .expect("couldn't perform small allocation")
        .into();

    ecx.push_stack_frame(
        Instance::new(def_id, args),
        body,
        &ret_place,
        StackPopCleanup::Root { cleanup: false },
    )
    .expect("failed to push initial stack frame");

    // Every local starts out live and uninitialized.
    let frame = ecx.frame_mut();
    for local in body.local_decls.indices() {
        frame.locals[local].make_live_uninit();
    }

    ConstPropagator { ecx, tcx, param_env, local_decls: &body.local_decls, patch: Patch::new(tcx) }
}
/// Returns the value currently known for `place`, if any.
///
/// Yields `None` when the place cannot be evaluated or when it evaluates to an uninitialized
/// immediate. Otherwise the operand is returned in immediate form when a raw immediate read
/// succeeds, and unchanged when it does not.
fn get_const(&self, place: Place<'tcx>) -> Option<OpTy<'tcx>> {
    let op = match self.ecx.eval_place_to_op(place, None) {
        Err(e) => {
            trace!("get_const failed: {:?}", e.into_kind().debug());
            return None;
        }
        Ok(op) => op,
    };

    // An uninitialized immediate carries no usable value.
    if let Right(imm) = op.as_mplace_or_imm() {
        if matches!(*imm, Immediate::Uninit) {
            return None;
        }
    }

    // Prefer the immediate representation when the operand can be read as one.
    match self.ecx.read_immediate_raw(&op) {
        Ok(Right(imm)) => Some(imm.into()),
        _ => Some(op),
    }
}
/// Forgets whatever is known about `local`: its slot in the current frame is reset to
/// live-but-uninitialized and it is dropped from the set of locals written in this block.
fn remove_const(ecx: &mut InterpCx<'mir, 'tcx, ConstPropMachine<'mir, 'tcx>>, local: Local) {
    let frame = ecx.frame_mut();
    frame.locals[local].make_live_uninit();

    let written = &mut ecx.machine.written_only_inside_own_block_locals;
    written.remove(&local);
}
/// Decides whether `rvalue` may be evaluated at all.
///
/// Returns `None` for reference and raw-address creation (after forgetting the value of the
/// referenced local), for thread-local references, for rvalues for which `has_param()` is
/// true, and for rvalues whose type is not sized. Returns `Some(())` otherwise.
fn check_rvalue(&mut self, rvalue: &Rvalue<'tcx>) -> Option<()> {
    match rvalue {
        Rvalue::ThreadLocalRef(def_id) => {
            trace!("skipping ThreadLocalRef({:?})", def_id);
            return None;
        }
        Rvalue::Ref(_, _, place) | Rvalue::AddressOf(_, place) => {
            trace!("skipping AddressOf | Ref for {:?}", place);
            // The referenced local can no longer be tracked, so drop its value.
            Self::remove_const(&mut self.ecx, place.local);
            return None;
        }
        // Nothing special to do for the remaining kinds; the match stays exhaustive so
        // that new `Rvalue` variants must be considered here.
        Rvalue::Use(..)
        | Rvalue::CopyForDeref(..)
        | Rvalue::Repeat(..)
        | Rvalue::Len(..)
        | Rvalue::Cast(..)
        | Rvalue::BinaryOp(..)
        | Rvalue::CheckedBinaryOp(..)
        | Rvalue::NullaryOp(..)
        | Rvalue::UnaryOp(..)
        | Rvalue::Discriminant(..)
        | Rvalue::Aggregate(..)
        | Rvalue::ShallowInitBox(..) => {}
    }

    if rvalue.has_param() {
        return None;
    }

    // Only rvalues of sized type are evaluated.
    let tcx = *self.ecx.tcx;
    let ty = rvalue.ty(&self.ecx.frame().body.local_decls, tcx);
    ty.is_sized(tcx, self.param_env).then_some(())
}
fn eval_rvalue_with_identities(
&mut self,
rvalue: &Rvalue<'tcx>,
place: Place<'tcx>,
) -> Option<()> {
match rvalue {
Rvalue::BinaryOp(op, box (left, right))
| Rvalue::CheckedBinaryOp(op, box (left, right)) => {
let l = self.ecx.eval_operand(left, None).and_then(|x| self.ecx.read_immediate(&x));
let r =
self.ecx.eval_operand(right, None).and_then(|x| self.ecx.read_immediate(&x));
let const_arg = match (l, r) {
(Ok(x), Err(_)) | (Err(_), Ok(x)) => x, (Err(_), Err(_)) => return None, (Ok(_), Ok(_)) => return self.ecx.eval_rvalue_into_place(rvalue, place).ok(), };
if !matches!(const_arg.layout.abi, abi::Abi::Scalar(..)) {
return None;
}
let arg_value = const_arg.to_scalar().to_bits(const_arg.layout.size).ok()?;
let dest = self.ecx.eval_place(place).ok()?;
match op {
BinOp::BitAnd if arg_value == 0 => {
self.ecx.write_immediate(*const_arg, &dest).ok()
}
BinOp::BitOr
if arg_value == const_arg.layout.size.truncate(u128::MAX)
|| (const_arg.layout.ty.is_bool() && arg_value == 1) =>
{
self.ecx.write_immediate(*const_arg, &dest).ok()
}
BinOp::Mul if const_arg.layout.ty.is_integral() && arg_value == 0 => {
if let Rvalue::CheckedBinaryOp(_, _) = rvalue {
let val = Immediate::ScalarPair(
const_arg.to_scalar(),
Scalar::from_bool(false),
);
self.ecx.write_immediate(val, &dest).ok()
} else {
self.ecx.write_immediate(*const_arg, &dest).ok()
}
}
_ => None,
}
}
_ => self.ecx.eval_rvalue_into_place(rvalue, place).ok(),
}
}
fn replace_with_const(&mut self, place: Place<'tcx>) -> Option<Const<'tcx>> {
let value = self.get_const(place)?;
if !self.tcx.consider_optimizing(|| format!("ConstantPropagation - {value:?}")) {
return None;
}
trace!("replacing {:?} with {:?}", place, value);
let imm = self.ecx.read_immediate_raw(&value).ok()?;
let Right(imm) = imm else { return None };
match *imm {
Immediate::Scalar(scalar) if scalar.try_to_int().is_ok() => {
Some(Const::from_scalar(self.tcx, scalar, value.layout.ty))
}
Immediate::ScalarPair(l, r) if l.try_to_int().is_ok() && r.try_to_int().is_ok() => {
let alloc_id = self
.ecx
.intern_with_temp_alloc(value.layout, |ecx, dest| {
ecx.write_immediate(*imm, dest)
})
.ok()?;
Some(Const::Val(
ConstValue::Indirect { alloc_id, offset: Size::ZERO },
value.layout.ty,
))
}
_ => None,
}
}
/// Debug-build sanity check that no value is being tracked for `local`.
///
/// The check passes when `get_const` yields nothing, when the local's layout cannot be
/// computed, or when the local is zero-sized. It does nothing without `debug_assertions`.
fn ensure_not_propagated(&self, local: Local) {
    if !cfg!(debug_assertions) {
        return;
    }

    let nothing_tracked = match self.get_const(local.into()) {
        None => true,
        // A tracked value is tolerated only for ZSTs or when the layout is unavailable.
        Some(_) => {
            self.layout_of(self.local_decls[local].ty).map_or(true, |layout| layout.is_zst())
        }
    };
    assert!(
        nothing_tracked,
        "failed to remove values for `{local:?}`, value={:?}",
        self.get_const(local.into()),
    )
}
}
/// How a given local may take part in constant propagation.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ConstPropMode {
    /// No restriction; this is the mode every local starts with in `CanConstProp::check`.
    FullConstProp,
    /// Values written to the local are forgotten at the end of the basic block in which
    /// they were written.
    OnlyInsideOwnBlock,
    /// The local takes no part in propagation; the machine stops on mutable access to it.
    NoPropagation,
}
/// MIR visitor that assigns a `ConstPropMode` to every local of a body.
pub struct CanConstProp {
    /// Mode computed so far for each local.
    can_const_prop: IndexVec<Local, ConstPropMode>,
    /// Locals for which an assignment-like use has already been seen (arguments are
    /// inserted up front by `check`).
    found_assignment: BitSet<Local>,
}
impl CanConstProp {
    /// Computes the propagation mode of every local in `body`.
    ///
    /// Locals whose layout cannot be computed, or whose size is not below `MAX_ALLOC_LIMIT`
    /// bytes, are marked `NoPropagation`. Arguments are recorded as already assigned. The
    /// remaining classification is done by visiting the body.
    pub fn check<'tcx>(
        tcx: TyCtxt<'tcx>,
        param_env: ParamEnv<'tcx>,
        body: &Body<'tcx>,
    ) -> IndexVec<Local, ConstPropMode> {
        let mut cpv = CanConstProp {
            found_assignment: BitSet::new_empty(body.local_decls.len()),
            can_const_prop: IndexVec::from_elem(ConstPropMode::FullConstProp, &body.local_decls),
        };

        // Rule out locals with an unknown or too-large layout.
        for (local, mode) in cpv.can_const_prop.iter_enumerated_mut() {
            let layout = tcx.layout_of(param_env.and(body.local_decls[local].ty));
            let fits =
                matches!(layout, Ok(layout) if layout.size < Size::from_bytes(MAX_ALLOC_LIMIT));
            if !fits {
                *mode = ConstPropMode::NoPropagation;
            }
        }

        // Arguments count as assigned before the body starts.
        body.args_iter().for_each(|arg| {
            cpv.found_assignment.insert(arg);
        });

        cpv.visit_body(body);
        cpv.can_const_prop
    }
}
impl<'tcx> Visitor<'tcx> for CanConstProp {
fn visit_place(&mut self, place: &Place<'tcx>, mut context: PlaceContext, loc: Location) {
use rustc_middle::mir::visit::PlaceContext::*;
if place.projection.first() == Some(&PlaceElem::Deref) {
context = NonMutatingUse(NonMutatingUseContext::Copy);
}
self.visit_local(place.local, context, loc);
self.visit_projection(place.as_ref(), context, loc);
}
fn visit_local(&mut self, local: Local, context: PlaceContext, _: Location) {
use rustc_middle::mir::visit::PlaceContext::*;
match context {
| MutatingUse(MutatingUseContext::Call)
| MutatingUse(MutatingUseContext::AsmOutput)
| MutatingUse(MutatingUseContext::Deinit)
| MutatingUse(MutatingUseContext::Store)
| MutatingUse(MutatingUseContext::SetDiscriminant) => {
if !self.found_assignment.insert(local) {
match &mut self.can_const_prop[local] {
ConstPropMode::OnlyInsideOwnBlock => {}
ConstPropMode::NoPropagation => {}
other @ ConstPropMode::FullConstProp => {
trace!(
"local {:?} can't be propagated because of multiple assignments. Previous state: {:?}",
local, other,
);
*other = ConstPropMode::OnlyInsideOwnBlock;
}
}
}
}
NonMutatingUse(NonMutatingUseContext::Copy)
| NonMutatingUse(NonMutatingUseContext::Move)
| NonMutatingUse(NonMutatingUseContext::Inspect)
| NonMutatingUse(NonMutatingUseContext::PlaceMention)
| NonUse(_) => {}
MutatingUse(MutatingUseContext::Yield)
| MutatingUse(MutatingUseContext::Drop)
| MutatingUse(MutatingUseContext::Retag)
| NonMutatingUse(NonMutatingUseContext::SharedBorrow)
| NonMutatingUse(NonMutatingUseContext::ShallowBorrow)
| NonMutatingUse(NonMutatingUseContext::AddressOf)
| MutatingUse(MutatingUseContext::Borrow)
| MutatingUse(MutatingUseContext::AddressOf) => {
trace!("local {:?} can't be propagated because it's used: {:?}", local, context);
self.can_const_prop[local] = ConstPropMode::NoPropagation;
}
MutatingUse(MutatingUseContext::Projection)
| NonMutatingUse(NonMutatingUseContext::Projection) => bug!("visit_place should not pass {context:?} for {local:?}"),
}
}
}
impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> {
/// After the default traversal, records a replacement for the operand's place when a
/// constant is known for it.
fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) {
    self.super_operand(operand, location);

    // Constant operands have no place and need no replacement.
    let Some(place) = operand.place() else { return };
    if let Some(value) = self.replace_with_const(place) {
        self.patch.before_effect.insert((location, place), value);
    }
}
/// Records a replacement for the index local of an `Index` projection when a constant is
/// known for it; all other projection elements are ignored.
fn visit_projection_elem(
    &mut self,
    _: PlaceRef<'tcx>,
    elem: PlaceElem<'tcx>,
    _: PlaceContext,
    location: Location,
) {
    let PlaceElem::Index(local) = elem else { return };
    let Some(value) = self.replace_with_const(local.into()) else { return };
    self.patch.before_effect.insert((location, local.into()), value);
}
/// Evaluates an assignment and, when the result is a known constant, records it in the
/// patch.
///
/// Nothing is done for rvalues rejected by `check_rvalue` or for indirect places. For
/// `NoPropagation` locals only the debug sanity check runs. When evaluation fails, the
/// value of the destination local is forgotten.
fn visit_assign(&mut self, place: &Place<'tcx>, rvalue: &Rvalue<'tcx>, location: Location) {
    self.super_assign(place, rvalue, location);

    if self.check_rvalue(rvalue).is_none() {
        return;
    }

    let mode = self.ecx.machine.can_const_prop[place.local];
    if place.is_indirect() {
        return;
    }
    match mode {
        ConstPropMode::NoPropagation => {
            self.ensure_not_propagated(place.local);
            return;
        }
        ConstPropMode::FullConstProp | ConstPropMode::OnlyInsideOwnBlock => {}
    }

    if self.eval_rvalue_with_identities(rvalue, *place).is_none() {
        // Evaluation failed: whatever was known about the destination is now stale.
        trace!(
            "propagation into {:?} failed.
Nuking the entire site from orbit, it's the only way to be sure",
            place,
        );
        Self::remove_const(&mut self.ecx, place.local);
        return;
    }

    // An rvalue that already is an evaluated constant is left as it is.
    if let Rvalue::Use(Operand::Constant(c)) = rvalue {
        if let Const::Val(..) = c.const_ {
            trace!("skipping replace of Rvalue::Use({:?} because it is already a const", c);
            return;
        }
    }

    if let Some(operand) = self.replace_with_const(*place) {
        self.patch.assignments.insert(location, operand);
    }
}
/// Visits a statement, then applies the effects of `SetDiscriminant` and `StorageLive`.
///
/// `SetDiscriminant` on a direct, propagatable place is executed in the interpreter; if
/// that fails, the local's value is forgotten. `StorageLive` always forgets the local's
/// value. Other statement kinds get no extra handling here.
fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
    trace!("visit_statement: {:?}", statement);

    // Recurse first, so operands are visited before the statement's own effect is applied.
    self.super_statement(statement, location);

    match &statement.kind {
        StatementKind::StorageLive(local) => {
            Self::remove_const(&mut self.ecx, *local);
        }
        StatementKind::SetDiscriminant { place, .. } => {
            let mode = self.ecx.machine.can_const_prop[place.local];
            if place.is_indirect() {
                return;
            }
            match mode {
                ConstPropMode::NoPropagation => self.ensure_not_propagated(place.local),
                ConstPropMode::OnlyInsideOwnBlock | ConstPropMode::FullConstProp => {
                    if self.ecx.statement(statement).is_err() {
                        Self::remove_const(&mut self.ecx, place.local);
                    } else {
                        trace!("propped discriminant into {:?}", place);
                    }
                }
            }
        }
        _ => {}
    }
}
/// Visits a basic block and then forgets the values of all block-restricted locals that
/// were written while visiting it.
///
/// In debug builds it also checks that no `NoPropagation` or `OnlyInsideOwnBlock` local
/// still has a tracked value afterwards.
fn visit_basic_block_data(&mut self, block: BasicBlock, data: &BasicBlockData<'tcx>) {
    self.super_basic_block_data(block, data);

    // Move the set out of the machine so `remove_const` can borrow the interpreter mutably.
    let mut touched = std::mem::take(&mut self.ecx.machine.written_only_inside_own_block_locals);
    for local in touched.drain() {
        debug_assert_eq!(
            self.ecx.machine.can_const_prop[local],
            ConstPropMode::OnlyInsideOwnBlock
        );
        Self::remove_const(&mut self.ecx, local);
    }
    // Put the (now empty) set back into the machine.
    self.ecx.machine.written_only_inside_own_block_locals = touched;

    if !cfg!(debug_assertions) {
        return;
    }
    for (local, &mode) in self.ecx.machine.can_const_prop.iter_enumerated() {
        match mode {
            ConstPropMode::NoPropagation | ConstPropMode::OnlyInsideOwnBlock => {
                self.ensure_not_propagated(local);
            }
            ConstPropMode::FullConstProp => {}
        }
    }
}
}