use rustc_ast::ast::{AttrKind, Attribute, IntTy, LitIntType, LitKind, StrStyle, UintTy};
use rustc_ast::token::CommentKind;
use rustc_ast::AttrStyle;
use rustc_hir::intravisit::FnKind;
use rustc_hir::{
Block, BlockCheckMode, Body, Closure, Destination, Expr, ExprKind, FieldDef, FnHeader, HirId, Impl, ImplItem,
ImplItemKind, IsAuto, Item, ItemKind, LoopSource, MatchSource, MutTy, Node, QPath, TraitItem, TraitItemKind, Ty,
TyKind, UnOp, UnsafeSource, Unsafety, Variant, VariantData, YieldSource,
};
use rustc_lint::{LateContext, LintContext};
use rustc_middle::ty::TyCtxt;
use rustc_session::Session;
use rustc_span::symbol::Ident;
use rustc_span::{Span, Symbol};
use rustc_target::spec::abi::Abi;
/// A text pattern checked against one edge (the start or the end) of a span's source text.
///
/// `Pat::Str("")` matches any text and is used wherever no reliable pattern is known.
#[derive(Clone)]
pub enum Pat {
    /// A static string.
    Str(&'static str),
    /// A string built at runtime.
    OwnedStr(String),
    /// A set of static strings; the pattern matches if any of them does.
    MultiStr(&'static [&'static str]),
    /// A set of runtime-built strings; the pattern matches if any of them does.
    OwnedMultiStr(Vec<String>),
    /// The text of an interned symbol, such as an identifier.
    Sym(Symbol),
    /// A numeric literal: checked as "first byte is an ASCII digit" at the start and
    /// "last byte is an ASCII hex digit" at the end (so `0xFF` and `1u8` both qualify).
    Num,
}
/// Checks whether the source text covered by `span` starts with `start_pat` and ends with
/// `end_pat`.
///
/// Leading whitespace and `(`, as well as trailing whitespace, `)` and `,`, are ignored before
/// matching, since the node may be wrapped in parentheses or followed by a separator.
///
/// Returns `false` when the source file's text is not loaded, or when the span's byte range is
/// not a valid slice of that text.
fn span_matches_pat(sess: &Session, span: Span, start_pat: Pat, end_pat: Pat) -> bool {
    let pos = sess.source_map().lookup_byte_offset(span.lo());
    let Some(ref src) = pos.sf.src else {
        return false;
    };
    let end = span.hi() - pos.sf.start_pos;
    src.get(pos.pos.0 as usize..end.0 as usize).map_or(false, |s| {
        let start_str = s.trim_start_matches(|c: char| c.is_whitespace() || c == '(');
        let end_str = s.trim_end_matches(|c: char| c.is_whitespace() || c == ')' || c == ',');
        (match start_pat {
            Pat::Str(text) => start_str.starts_with(text),
            Pat::OwnedStr(text) => start_str.starts_with(&text),
            Pat::MultiStr(texts) => texts.iter().any(|text| start_str.starts_with(text)),
            Pat::OwnedMultiStr(texts) => texts.iter().any(|text| start_str.starts_with(text)),
            Pat::Sym(sym) => start_str.starts_with(sym.as_str()),
            Pat::Num => start_str.as_bytes().first().map_or(false, u8::is_ascii_digit),
        } && match end_pat {
            // Every end pattern is checked as a suffix of the trimmed end text.
            Pat::Str(text) => end_str.ends_with(text),
            Pat::OwnedStr(text) => end_str.ends_with(&text),
            Pat::MultiStr(texts) => texts.iter().any(|text| end_str.ends_with(text)),
            Pat::OwnedMultiStr(texts) => texts.iter().any(|text| end_str.ends_with(text)),
            Pat::Sym(sym) => end_str.ends_with(sym.as_str()),
            // Hex digits so that literals like `0xFF` are accepted; suffixes such as `u8` end in a digit.
            Pat::Num => end_str.as_bytes().last().map_or(false, u8::is_ascii_hexdigit),
        })
    })
}
fn lit_search_pat(lit: &LitKind) -> (Pat, Pat) {
match lit {
LitKind::Str(_, StrStyle::Cooked) => (Pat::Str("\""), Pat::Str("\"")),
LitKind::Str(_, StrStyle::Raw(0)) => (Pat::Str("r"), Pat::Str("\"")),
LitKind::Str(_, StrStyle::Raw(_)) => (Pat::Str("r#"), Pat::Str("#")),
LitKind::ByteStr(_, StrStyle::Cooked) => (Pat::Str("b\""), Pat::Str("\"")),
LitKind::ByteStr(_, StrStyle::Raw(0)) => (Pat::Str("br\""), Pat::Str("\"")),
LitKind::ByteStr(_, StrStyle::Raw(_)) => (Pat::Str("br#\""), Pat::Str("#")),
LitKind::Byte(_) => (Pat::Str("b'"), Pat::Str("'")),
LitKind::Char(_) => (Pat::Str("'"), Pat::Str("'")),
LitKind::Int(_, LitIntType::Signed(IntTy::Isize)) => (Pat::Num, Pat::Str("isize")),
LitKind::Int(_, LitIntType::Unsigned(UintTy::Usize)) => (Pat::Num, Pat::Str("usize")),
LitKind::Int(..) => (Pat::Num, Pat::Num),
LitKind::Float(..) => (Pat::Num, Pat::Str("")),
LitKind::Bool(true) => (Pat::Str("true"), Pat::Str("true")),
LitKind::Bool(false) => (Pat::Str("false"), Pat::Str("false")),
_ => (Pat::Str(""), Pat::Str("")),
}
}
/// Returns the `(start, end)` text patterns for a path appearing in a type or expression.
///
/// A resolved path with a qualified self type (`<T as Trait>::x`) starts with `<`; otherwise it
/// starts with its first segment's name. It ends with `>` when the last segment carries generic
/// arguments, and with the last segment's name otherwise. A type-relative path only constrains
/// its end, and a lang-item path is not constrained at all.
fn qpath_search_pat(path: &QPath<'_>) -> (Pat, Pat) {
    match path {
        QPath::Resolved(self_ty, resolved) => {
            let start = match (self_ty.is_some(), resolved.segments.first()) {
                (true, _) => Pat::Str("<"),
                (false, Some(first)) => Pat::Sym(first.ident.name),
                (false, None) => Pat::Str(""),
            };
            let end = match resolved.segments.last() {
                Some(last) if last.args.is_some() => Pat::Str(">"),
                Some(last) => Pat::Sym(last.ident.name),
                None => Pat::Str(""),
            };
            (start, end)
        },
        QPath::TypeRelative(_, segment) => (Pat::Str(""), Pat::Sym(segment.ident.name)),
        QPath::LangItem(..) => (Pat::Str(""), Pat::Str("")),
    }
}
/// Returns the `(start, end)` text patterns expected at the edges of an expression's source.
///
/// Compound expressions reuse the start pattern of their leftmost sub-expression and the end
/// pattern of their rightmost one. Expression kinds not listed fall back to
/// `(Pat::Str(""), Pat::Str(""))`, which matches any text.
fn expr_search_pat(tcx: TyCtxt<'_>, e: &Expr<'_>) -> (Pat, Pat) {
    match e.kind {
        ExprKind::ConstBlock(_) => (Pat::Str("const"), Pat::Str("}")),
        // `span_matches_pat` trims leading `(` and trailing `)` before matching, so the text `()`
        // is seen as `)` from the start and `(` from the end.
        ExprKind::Tup([]) => (Pat::Str(")"), Pat::Str("(")),
        ExprKind::Unary(UnOp::Deref, e) => (Pat::Str("*"), expr_search_pat(tcx, e).1),
        ExprKind::Unary(UnOp::Not, e) => (Pat::Str("!"), expr_search_pat(tcx, e).1),
        ExprKind::Unary(UnOp::Neg, e) => (Pat::Str("-"), expr_search_pat(tcx, e).1),
        ExprKind::Lit(lit) => lit_search_pat(&lit.node),
        ExprKind::Array(_) | ExprKind::Repeat(..) => (Pat::Str("["), Pat::Str("]")),
        // A call with no arguments ends in `()`; after the trailing `)` is trimmed, `(` remains.
        ExprKind::Call(e, []) | ExprKind::MethodCall(_, e, [], _) => (expr_search_pat(tcx, e).0, Pat::Str("(")),
        // The trailing `)` of a call is trimmed, so the last argument's end pattern is used.
        ExprKind::Call(first, [.., last])
        | ExprKind::MethodCall(_, first, [.., last], _)
        | ExprKind::Binary(_, first, last)
        | ExprKind::Tup([first, .., last])
        | ExprKind::Assign(first, last, _)
        | ExprKind::AssignOp(_, first, last) => (expr_search_pat(tcx, first).0, expr_search_pat(tcx, last).1),
        ExprKind::Tup([e]) | ExprKind::DropTemps(e) => expr_search_pat(tcx, e),
        ExprKind::Cast(e, _) | ExprKind::Type(e, _) => (expr_search_pat(tcx, e).0, Pat::Str("")),
        ExprKind::Let(let_expr) => (Pat::Str("let"), expr_search_pat(tcx, let_expr.init).1),
        ExprKind::If(..) => (Pat::Str("if"), Pat::Str("}")),
        // Labelled loops and blocks start with the label's leading `'`.
        ExprKind::Loop(_, Some(_), _, _) | ExprKind::Block(_, Some(_)) => (Pat::Str("'"), Pat::Str("}")),
        ExprKind::Loop(_, None, LoopSource::Loop, _) => (Pat::Str("loop"), Pat::Str("}")),
        ExprKind::Loop(_, None, LoopSource::While, _) => (Pat::Str("while"), Pat::Str("}")),
        ExprKind::Loop(_, None, LoopSource::ForLoop, _) | ExprKind::Match(_, _, MatchSource::ForLoopDesugar) => {
            (Pat::Str("for"), Pat::Str("}"))
        },
        ExprKind::Match(_, _, MatchSource::Normal) => (Pat::Str("match"), Pat::Str("}")),
        // Desugared `expr?`: starts like the operand, ends with the `?` operator.
        ExprKind::Match(e, _, MatchSource::TryDesugar(_)) => (expr_search_pat(tcx, e).0, Pat::Str("?")),
        // Desugared `expr.await`: starts like the operand, ends with the `await` keyword.
        ExprKind::Match(e, _, MatchSource::AwaitDesugar) | ExprKind::Yield(e, YieldSource::Await { .. }) => {
            (expr_search_pat(tcx, e).0, Pat::Str("await"))
        },
        // A closure's start (`|`, `move`, `async`, ...) is not checked; its end is the end of its body.
        ExprKind::Closure(&Closure { body, .. }) => (Pat::Str(""), expr_search_pat(tcx, tcx.hir().body(body).value).1),
        ExprKind::Block(
            Block {
                rules: BlockCheckMode::UnsafeBlock(UnsafeSource::UserProvided),
                ..
            },
            None,
        ) => (Pat::Str("unsafe"), Pat::Str("}")),
        ExprKind::Block(_, None) => (Pat::Str("{"), Pat::Str("}")),
        ExprKind::Field(e, name) => (expr_search_pat(tcx, e).0, Pat::Sym(name.name)),
        ExprKind::Index(e, _, _) => (expr_search_pat(tcx, e).0, Pat::Str("]")),
        ExprKind::Path(ref path) => qpath_search_pat(path),
        ExprKind::AddrOf(_, _, e) => (Pat::Str("&"), expr_search_pat(tcx, e).1),
        ExprKind::Break(Destination { label: None, .. }, None) => (Pat::Str("break"), Pat::Str("break")),
        ExprKind::Break(Destination { label: Some(name), .. }, None) => (Pat::Str("break"), Pat::Sym(name.ident.name)),
        ExprKind::Break(_, Some(e)) => (Pat::Str("break"), expr_search_pat(tcx, e).1),
        ExprKind::Continue(Destination { label: None, .. }) => (Pat::Str("continue"), Pat::Str("continue")),
        ExprKind::Continue(Destination { label: Some(name), .. }) => (Pat::Str("continue"), Pat::Sym(name.ident.name)),
        ExprKind::Ret(None) => (Pat::Str("return"), Pat::Str("return")),
        ExprKind::Ret(Some(e)) => (Pat::Str("return"), expr_search_pat(tcx, e).1),
        ExprKind::Struct(path, _, _) => (qpath_search_pat(path).0, Pat::Str("}")),
        ExprKind::Yield(e, YieldSource::Yield) => (Pat::Str("yield"), expr_search_pat(tcx, e).1),
        // Unhandled kinds: empty patterns always match, so they are never reported.
        _ => (Pat::Str(""), Pat::Str("")),
    }
}
/// Returns the pattern a function's source is expected to start with, based on its header.
///
/// Visibility is not considered here; callers substitute `pub` themselves. The first matching
/// qualifier wins, checked as `async`, `const`, `unsafe`, then a non-Rust ABI (`extern`).
/// A plain Rust-ABI function may start with `fn` or with an explicit `extern`.
fn fn_header_search_pat(header: FnHeader) -> Pat {
    match header {
        h if h.is_async() => Pat::Str("async"),
        h if h.is_const() => Pat::Str("const"),
        h if h.is_unsafe() => Pat::Str("unsafe"),
        h if h.abi != Abi::Rust => Pat::Str("extern"),
        _ => Pat::MultiStr(&["fn", "extern"]),
    }
}
/// Returns the `(start, end)` text patterns expected at the edges of an item's source.
///
/// The start is normally the item's leading keyword. If the item has a non-empty visibility
/// span, the start is `pub` instead. Item kinds not listed return `(Pat::Str(""), Pat::Str(""))`
/// immediately, which matches any text.
fn item_search_pat(item: &Item<'_>) -> (Pat, Pat) {
    let (start_pat, end_pat) = match &item.kind {
        ItemKind::ExternCrate(_) => (Pat::Str("extern"), Pat::Str(";")),
        ItemKind::Static(..) => (Pat::Str("static"), Pat::Str(";")),
        ItemKind::Const(..) => (Pat::Str("const"), Pat::Str(";")),
        // A function ends with its body or signature; the end is not checked.
        ItemKind::Fn(sig, ..) => (fn_header_search_pat(sig.header), Pat::Str("")),
        ItemKind::ForeignMod { .. } => (Pat::Str("extern"), Pat::Str("}")),
        ItemKind::TyAlias(..) | ItemKind::OpaqueTy(_) => (Pat::Str("type"), Pat::Str(";")),
        ItemKind::Enum(..) => (Pat::Str("enum"), Pat::Str("}")),
        // Braced structs end with `}`; tuple and unit structs end with `;`.
        ItemKind::Struct(VariantData::Struct(..), _) => (Pat::Str("struct"), Pat::Str("}")),
        ItemKind::Struct(..) => (Pat::Str("struct"), Pat::Str(";")),
        ItemKind::Union(..) => (Pat::Str("union"), Pat::Str("}")),
        // Must come before the `auto`/plain trait and impl arms: `unsafe` is written first.
        ItemKind::Trait(_, Unsafety::Unsafe, ..)
        | ItemKind::Impl(Impl {
            unsafety: Unsafety::Unsafe,
            ..
        }) => (Pat::Str("unsafe"), Pat::Str("}")),
        ItemKind::Trait(IsAuto::Yes, ..) => (Pat::Str("auto"), Pat::Str("}")),
        ItemKind::Trait(..) => (Pat::Str("trait"), Pat::Str("}")),
        ItemKind::Impl(_) => (Pat::Str("impl"), Pat::Str("}")),
        _ => return (Pat::Str(""), Pat::Str("")),
    };
    if item.vis_span.is_empty() {
        (start_pat, end_pat)
    } else {
        (Pat::Str("pub"), end_pat)
    }
}
/// Returns the `(start, end)` text patterns expected at the edges of a trait item's source.
///
/// Trait items cannot carry a visibility qualifier, so the start is always derived from the
/// item kind. Methods leave their end unchecked; associated consts and types end with `;`.
fn trait_item_search_pat(item: &TraitItem<'_>) -> (Pat, Pat) {
    let (keyword, terminator) = match &item.kind {
        TraitItemKind::Fn(sig, ..) => return (fn_header_search_pat(sig.header), Pat::Str("")),
        TraitItemKind::Const(..) => ("const", ";"),
        TraitItemKind::Type(..) => ("type", ";"),
    };
    (Pat::Str(keyword), Pat::Str(terminator))
}
/// Returns the `(start, end)` text patterns expected at the edges of an impl item's source.
///
/// The start is `pub` when the item has a non-empty visibility span, and the kind's leading
/// keyword otherwise. Methods leave their end unchecked; consts and types end with `;`.
fn impl_item_search_pat(item: &ImplItem<'_>) -> (Pat, Pat) {
    let (kind_start, end_pat) = match &item.kind {
        ImplItemKind::Fn(sig, ..) => (fn_header_search_pat(sig.header), Pat::Str("")),
        ImplItemKind::Const(..) => (Pat::Str("const"), Pat::Str(";")),
        ImplItemKind::Type(..) => (Pat::Str("type"), Pat::Str(";")),
    };
    let has_visibility = !item.vis_span.is_empty();
    let start_pat = if has_visibility { Pat::Str("pub") } else { kind_start };
    (start_pat, end_pat)
}
/// Returns the `(start, end)` text patterns expected at the edges of a struct field's source.
///
/// The start is `pub` when the field has a visibility qualifier, and the field's name for named
/// fields. Positional (tuple) fields are unconstrained. The end is never checked.
fn field_def_search_pat(def: &FieldDef<'_>) -> (Pat, Pat) {
    let start_pat = if !def.vis_span.is_empty() {
        Pat::Str("pub")
    } else if def.is_positional() {
        Pat::Str("")
    } else {
        Pat::Sym(def.ident.name)
    };
    (start_pat, Pat::Str(""))
}
/// Returns the `(start, end)` text patterns expected at the edges of an enum variant's source.
///
/// A variant always starts with its name. Struct-like variants end with `}`, unit variants with
/// their name, and tuple variants leave the end unchecked (the trailing `)` is trimmed anyway).
/// If the variant has an explicit discriminant (`A = 1`), its source ends with the discriminant
/// expression, so the end is left unchecked.
fn variant_search_pat(v: &Variant<'_>) -> (Pat, Pat) {
    let name = v.ident.name;
    if v.disr_expr.is_some() {
        return (Pat::Sym(name), Pat::Str(""));
    }
    let end_pat = match v.data {
        VariantData::Struct(..) => Pat::Str("}"),
        VariantData::Tuple(..) => Pat::Str(""),
        VariantData::Unit(..) => Pat::Sym(name),
    };
    (Pat::Sym(name), end_pat)
}
/// Returns the `(start, end)` text patterns for a function given as a `FnKind` plus its body.
///
/// Closures only constrain their end (the end of the body expression). For functions and
/// methods, the start is `pub` when the owning item or impl item has a visibility qualifier,
/// and the header-derived pattern otherwise. Trait items always use the header pattern, and any
/// other owning node leaves the start unchecked. The end is never checked for non-closures.
fn fn_kind_pat(tcx: TyCtxt<'_>, kind: &FnKind<'_>, body: &Body<'_>, hir_id: HirId) -> (Pat, Pat) {
    let header = match kind {
        FnKind::ItemFn(.., header) => *header,
        FnKind::Method(.., sig) => sig.header,
        FnKind::Closure => return (Pat::Str(""), expr_search_pat(tcx, body.value).1),
    };
    let start_pat = match tcx.hir().get(hir_id) {
        Node::Item(Item { vis_span, .. }) | Node::ImplItem(ImplItem { vis_span, .. }) if !vis_span.is_empty() => {
            Pat::Str("pub")
        },
        Node::Item(_) | Node::ImplItem(_) | Node::TraitItem(_) => fn_header_search_pat(header),
        _ => Pat::Str(""),
    };
    (start_pat, Pat::Str(""))
}
/// Returns the `(start, end)` text patterns expected at the edges of an attribute's source.
///
/// Normal attributes start with `#[` (outer) or `#![` (inner) and end with `]`. When the
/// attribute has a single identifier, either that identifier or the opener is accepted at the
/// start, and the end is left unchecked. Line doc comments start with `///` or `//!`; block doc
/// comments start with `/**` or `/*!` and end with `*/`.
fn attr_search_pat(attr: &Attribute) -> (Pat, Pat) {
    let is_outer = matches!(attr.style, AttrStyle::Outer);
    match attr.kind {
        AttrKind::Normal(..) => {
            let opener = if is_outer { "#[" } else { "#![" };
            match attr.ident() {
                // NOTE: accepting the bare identifier may produce false positives (e.g. `allow = 1`).
                Some(ident) => (
                    Pat::OwnedMultiStr(vec![ident.to_string(), opener.to_owned()]),
                    Pat::Str(""),
                ),
                None => (Pat::Str(opener), Pat::Str("]")),
            }
        },
        AttrKind::DocComment(CommentKind::Line, ..) => {
            let opener = if is_outer { "///" } else { "//!" };
            (Pat::Str(opener), Pat::Str(""))
        },
        AttrKind::DocComment(CommentKind::Block, ..) => {
            let opener = if is_outer { "/**" } else { "/*!" };
            (Pat::Str(opener), Pat::Str("*/"))
        },
    }
}
fn ty_search_pat(ty: &Ty<'_>) -> (Pat, Pat) {
match ty.kind {
TyKind::Slice(..) | TyKind::Array(..) => (Pat::Str("["), Pat::Str("]")),
TyKind::Ptr(MutTy { mutbl, ty }) => (
if mutbl.is_mut() {
Pat::Str("*const")
} else {
Pat::Str("*mut")
},
ty_search_pat(ty).1,
),
TyKind::Ref(_, MutTy { ty, .. }) => (Pat::Str("&"), ty_search_pat(ty).1),
TyKind::BareFn(bare_fn) => (
Pat::OwnedStr(format!("{}{} fn", bare_fn.unsafety.prefix_str(), bare_fn.abi.name())),
ty_search_pat(ty).1,
),
TyKind::Never => (Pat::Str("!"), Pat::Str("")),
TyKind::Tup(..) => (Pat::Str("("), Pat::Str(")")),
TyKind::OpaqueDef(..) => (Pat::Str("impl"), Pat::Str("")),
TyKind::Path(qpath) => qpath_search_pat(&qpath),
_ => (Pat::Str(""), Pat::Str("")),
}
}
/// Returns patterns requiring the source to start with the identifier's text; the end is
/// left unchecked.
fn ident_search_pat(ident: Ident) -> (Pat, Pat) {
    let text = String::from(ident.name.as_str());
    (Pat::OwnedStr(text), Pat::Str(""))
}
/// A value whose source text can be checked against expected start/end patterns.
///
/// Used by `is_from_proc_macro`.
pub trait WithSearchPat<'cx> {
    /// The lint context passed to `search_pat`.
    type Context: LintContext;
    /// Returns the `(start, end)` patterns this value's source text is expected to match.
    fn search_pat(&self, cx: &Self::Context) -> (Pat, Pat);
    /// Returns the span whose source text is checked.
    fn span(&self) -> Span;
}
// Implements `WithSearchPat` for a node type `$ty<'cx>` with context `$cx<'cx>`, using `$fn`
// to build the search patterns and the node's own `span` field as its span. With the optional
// `($tcx)` suffix, `cx.tcx` is bound to `$tcx` and passed to `$fn` before the node.
macro_rules! impl_with_search_pat {
    ($cx:ident: $ty:ident with $fn:ident $(($tcx:ident))?) => {
        impl<'cx> WithSearchPat<'cx> for $ty<'cx> {
            type Context = $cx<'cx>;
            // `cx` is unused when no `$tcx` binding is requested.
            #[allow(unused_variables)]
            fn search_pat(&self, cx: &Self::Context) -> (Pat, Pat) {
                $(let $tcx = cx.tcx;)?
                $fn($($tcx,)? self)
            }
            fn span(&self) -> Span {
                self.span
            }
        }
    };
}
// `WithSearchPat` impls for HIR nodes that carry their own `span` field. Only expressions need
// the type context (to look up closure bodies).
impl_with_search_pat!(LateContext: Expr with expr_search_pat(tcx));
impl_with_search_pat!(LateContext: Item with item_search_pat);
impl_with_search_pat!(LateContext: TraitItem with trait_item_search_pat);
impl_with_search_pat!(LateContext: ImplItem with impl_item_search_pat);
impl_with_search_pat!(LateContext: FieldDef with field_def_search_pat);
impl_with_search_pat!(LateContext: Variant with variant_search_pat);
impl_with_search_pat!(LateContext: Ty with ty_search_pat);
/// Search patterns for a function described by its kind, body, `HirId` and span.
impl<'cx> WithSearchPat<'cx> for (&FnKind<'cx>, &Body<'cx>, HirId, Span) {
    type Context = LateContext<'cx>;

    fn search_pat(&self, cx: &Self::Context) -> (Pat, Pat) {
        let &(kind, body, hir_id, _) = self;
        fn_kind_pat(cx.tcx, kind, body, hir_id)
    }

    fn span(&self) -> Span {
        let &(_, _, _, span) = self;
        span
    }
}
/// Search patterns for an attribute; the lint context is not needed.
impl<'cx> WithSearchPat<'cx> for &'cx Attribute {
    type Context = LateContext<'cx>;

    fn search_pat(&self, _cx: &Self::Context) -> (Pat, Pat) {
        let attr: &Attribute = self;
        attr_search_pat(attr)
    }

    fn span(&self) -> Span {
        let attr: &Attribute = self;
        attr.span
    }
}
impl<'cx> WithSearchPat<'cx> for Ident {
type Context = LateContext<'cx>;
fn search_pat(&self, _cx: &Self::Context) -> (Pat, Pat) {
ident_search_pat(*self)
}
fn span(&self) -> Span {
self.span
}
}
/// Returns `true` when the source text at `item`'s span does NOT match the start/end patterns
/// expected for `item`.
///
/// As the name indicates, this mismatch is used as a heuristic that the item came from a
/// proc macro. It also returns `true` when the source text is unavailable.
pub fn is_from_proc_macro<'cx, T: WithSearchPat<'cx>>(cx: &T::Context, item: &T) -> bool {
    let (start_pat, end_pat) = item.search_pat(cx);
    let source_matches = span_matches_pat(cx.sess(), item.span(), start_pat, end_pat);
    !source_matches
}
/// Checks whether the source text of `span` starts with `match` and ends with `}`.
pub fn is_span_match(cx: &impl LintContext, span: Span) -> bool {
    let (opening, closing) = (Pat::Str("match"), Pat::Str("}"));
    span_matches_pat(cx.sess(), span, opening, closing)
}
/// Checks whether the source text of `span` starts with `if` and ends with `}`.
pub fn is_span_if(cx: &impl LintContext, span: Span) -> bool {
    let (opening, closing) = (Pat::Str("if"), Pat::Str("}"));
    span_matches_pat(cx.sess(), span, opening, closing)
}