#![allow(clippy::needless_pass_by_value, clippy::if_not_else)]
extern crate proc_macro;
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span, TokenStream as TokenStream2, TokenTree};
use quote::{quote, quote_spanned};
use syn::parse::{Nothing, ParseStream, Parser};
use syn::punctuated::Punctuated;
use syn::{
parenthesized, parse_macro_input, token, Abi, Attribute, Data, DeriveInput, Error, Expr, Field,
Generics, Path, Result, Token, Type, Visibility,
};
#[proc_macro_derive(RefCast, attributes(trivial))]
pub fn derive_ref_cast(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
expand_ref_cast(input)
.unwrap_or_else(Error::into_compile_error)
.into()
}
#[proc_macro_derive(RefCastCustom, attributes(trivial))]
pub fn derive_ref_cast_custom(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
expand_ref_cast_custom(input)
.unwrap_or_else(Error::into_compile_error)
.into()
}
#[proc_macro_attribute]
pub fn ref_cast_custom(args: TokenStream, input: TokenStream) -> TokenStream {
let input = TokenStream2::from(input);
let expanded = match (|input: ParseStream| {
let attrs = input.call(Attribute::parse_outer)?;
let vis: Visibility = input.parse()?;
let constness: Option<Token![const]> = input.parse()?;
let asyncness: Option<Token![async]> = input.parse()?;
let unsafety: Option<Token![unsafe]> = input.parse()?;
let abi: Option<Abi> = input.parse()?;
let fn_token: Token![fn] = input.parse()?;
let ident: Ident = input.parse()?;
let mut generics: Generics = input.parse()?;
let content;
let paren_token = parenthesized!(content in input);
let arg: Ident = content.parse()?;
let colon_token: Token![:] = content.parse()?;
let from_type: Type = content.parse()?;
let _trailing_comma: Option<Token![,]> = content.parse()?;
if !content.is_empty() {
let rest: TokenStream2 = content.parse()?;
return Err(Error::new_spanned(
rest,
"ref_cast_custom function is required to have a single argument",
));
}
let arrow_token: Token![->] = input.parse()?;
let to_type: Type = input.parse()?;
generics.where_clause = input.parse()?;
let semi_token: Token![;] = input.parse()?;
let _: Nothing = syn::parse::<Nothing>(args)?;
Ok(Function {
attrs,
vis,
constness,
asyncness,
unsafety,
abi,
fn_token,
ident,
generics,
paren_token,
arg,
colon_token,
from_type,
arrow_token,
to_type,
semi_token,
})
})
.parse2(input.clone())
{
Ok(function) => expand_function_body(function),
Err(parse_error) => {
let compile_error = parse_error.to_compile_error();
quote!(#compile_error #input)
}
};
TokenStream::from(expanded)
}
/// The pieces of a function signature accepted by `#[ref_cast_custom]`, in
/// source order. Produced by the attribute's parser and consumed by
/// `expand_function_body`.
struct Function {
    /// Outer attributes written above the function.
    attrs: Vec<Attribute>,
    /// Visibility qualifier (possibly inherited/empty).
    vis: Visibility,
    /// Optional `const` qualifier.
    constness: Option<Token![const]>,
    /// Optional `async` qualifier.
    asyncness: Option<Token![async]>,
    /// Optional `unsafe` qualifier.
    unsafety: Option<Token![unsafe]>,
    /// Optional `extern "..."` ABI.
    abi: Option<Abi>,
    /// The `fn` keyword.
    fn_token: Token![fn],
    /// Function name.
    ident: Ident,
    /// Generic parameters; the parser also stores the `where` clause that
    /// follows the return type in `generics.where_clause`.
    generics: Generics,
    /// Parentheses around the argument list.
    paren_token: token::Paren,
    /// Name of the single argument.
    arg: Ident,
    /// The `:` between the argument name and its type.
    colon_token: Token![:],
    /// Type of the single argument (the cast source).
    from_type: Type,
    /// The `->` before the return type.
    arrow_token: Token![->],
    /// Return type (the cast target).
    to_type: Type,
    /// The `;` that terminates the body-less signature.
    semi_token: Token![;],
}
/// Generates the `::ref_cast::RefCast` impl for a `#[derive(RefCast)]` input.
///
/// Fails if the repr check fails, if the input is not a struct with fields,
/// or if the struct does not have exactly one non-trivial field. The type of
/// that field becomes the impl's `From` associated type.
fn expand_ref_cast(input: DeriveInput) -> Result<TokenStream2> {
    check_repr(&input)?;

    let all_fields = fields(&input)?;
    let from = only_field_ty(all_fields)?;
    let trivial = trivial_fields(all_fields)?;

    let name = &input.ident;
    let name_str = name.to_string();
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    // Never executed (`if false`), but still type-checked: instantiates
    // `assert_trivial` once for every field classified as trivial.
    let assert_trivial_fields = if trivial.is_empty() {
        None
    } else {
        Some(quote! {
            if false {
                #(
                    ::ref_cast::__private::assert_trivial::<#trivial>();
                )*
            }
        })
    };

    // Layout assertion compiled only with `debug_assertions`; the same tokens
    // are spliced into both generated methods.
    let debug_layout_check = quote! {
        #[cfg(debug_assertions)]
        {
            #[allow(unused_imports)]
            use ::ref_cast::__private::LayoutUnsized;
            ::ref_cast::__private::assert_layout::<Self, Self::From>(
                #name_str,
                ::ref_cast::__private::Layout::<Self>::SIZE,
                ::ref_cast::__private::Layout::<Self::From>::SIZE,
                ::ref_cast::__private::Layout::<Self>::ALIGN,
                ::ref_cast::__private::Layout::<Self::From>::ALIGN,
            );
        }
    };

    let expanded = quote! {
        impl #impl_generics ::ref_cast::RefCast for #name #ty_generics #where_clause {
            type From = #from;
            #[inline]
            fn ref_cast(_from: &Self::From) -> &Self {
                #assert_trivial_fields
                #debug_layout_check
                unsafe {
                    &*(_from as *const Self::From as *const Self)
                }
            }
            #[inline]
            fn ref_cast_mut(_from: &mut Self::From) -> &mut Self {
                #debug_layout_check
                unsafe {
                    &mut *(_from as *mut Self::From as *mut Self)
                }
            }
        }
    };
    Ok(expanded)
}
/// Generates the `RefCastCustom` impl for a `#[derive(RefCastCustom)]` input.
///
/// The impl is wrapped in an anonymous `const` together with a marker struct
/// `RefCastCurrentCrate`, declared with the same visibility as the input type
/// and used as the impl's `CurrentCrate` associated type.
///
/// Fails under the same conditions as `expand_ref_cast`: repr check failure,
/// non-struct or unit-struct input, or not exactly one non-trivial field.
fn expand_ref_cast_custom(input: DeriveInput) -> Result<TokenStream2> {
    check_repr(&input)?;

    let all_fields = fields(&input)?;
    let from = only_field_ty(all_fields)?;
    let trivial = trivial_fields(all_fields)?;

    let vis = &input.vis;
    let name = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    // Emitted only when there are trivial fields; the `if false` body is
    // type-checked but never run.
    let assert_trivial_fields = if trivial.is_empty() {
        None
    } else {
        Some(quote! {
            fn __static_assert() {
                if false {
                    #(
                        ::ref_cast::__private::assert_trivial::<#trivial>();
                    )*
                }
            }
        })
    };

    let expanded = quote! {
        const _: () = {
            #[non_exhaustive]
            #vis struct RefCastCurrentCrate {}
            unsafe impl #impl_generics ::ref_cast::__private::RefCastCustom<#from> for #name #ty_generics #where_clause {
                type CurrentCrate = RefCastCurrentCrate;
                #assert_trivial_fields
            }
        };
    };
    Ok(expanded)
}
/// Re-emits a parsed `#[ref_cast_custom]` signature with a generated body.
///
/// The body contains, in order:
/// 1. a never-called closure invoking `__private::ref_cast_custom`, present
///    only so the call is type-checked;
/// 2. a `__private::CurrentCrate::<From, To> {}` struct expression;
/// 3. the `__private::transmute::<From, To>` call that produces the result.
///
/// `#[inline]` is added unless the caller already wrote an `inline`
/// attribute. When the function itself is `unsafe`, the inner `unsafe` block
/// is annotated with `#[allow(unused_unsafe)]`.
fn expand_function_body(function: Function) -> TokenStream2 {
    let Function {
        attrs,
        vis,
        constness,
        asyncness,
        unsafety,
        abi,
        fn_token,
        ident,
        generics,
        paren_token,
        arg,
        colon_token,
        from_type,
        arrow_token,
        to_type,
        semi_token,
    } = function;

    // `Generics`' token output covers only the `<...>` parameter list, so the
    // where clause has to be emitted separately after the return type;
    // otherwise any bounds the caller wrote there would be silently dropped.
    let where_clause = &generics.where_clause;

    let args = quote_spanned! {paren_token.span=>
        (#arg #colon_token #from_type)
    };

    let allow_unused_unsafe = if unsafety.is_some() {
        Some(quote!(#[allow(unused_unsafe)]))
    } else {
        None
    };

    // Respect an explicit `#[inline]`/`#[inline(...)]` from the caller.
    let has_inline_attr = attrs.iter().any(|attr| attr.path().is_ident("inline"));
    let inline_attr = if has_inline_attr {
        None
    } else {
        Some(quote!(#[inline]))
    };

    // Give the `unsafe` keyword of the generated block a macro-generated span
    // rather than `semi_token.span` (a span from the caller's own source), so
    // that `forbid(unsafe_code)` in the calling crate does not attribute the
    // block to user-written code.
    let macro_generated_unsafe = quote!(unsafe);

    quote_spanned! {semi_token.span=>
        #(#attrs)*
        #inline_attr
        #vis #constness #asyncness #unsafety #abi
        #fn_token #ident #generics #args #arrow_token #to_type #where_clause {
            // Type-check the conversion without ever running it.
            let _ = || {
                ::ref_cast::__private::ref_cast_custom::<#from_type, #to_type>(#arg);
            };
            let _ = ::ref_cast::__private::CurrentCrate::<#from_type, #to_type> {};
            #allow_unused_unsafe #macro_generated_unsafe {
                ::ref_cast::__private::transmute::<#from_type, #to_type>(#arg)
            }
        }
    }
}
fn check_repr(input: &DeriveInput) -> Result<()> {
let mut has_repr = false;
let mut errors = None;
let mut push_error = |error| match &mut errors {
Some(errors) => Error::combine(errors, error),
None => errors = Some(error),
};
for attr in &input.attrs {
if attr.path().is_ident("repr") {
if let Err(error) = attr.parse_args_with(|input: ParseStream| {
while !input.is_empty() {
let path = input.call(Path::parse_mod_style)?;
if path.is_ident("transparent") || path.is_ident("C") {
has_repr = true;
} else if path.is_ident("packed") {
} else {
let meta_item_span = if input.peek(token::Paren) {
let group: TokenTree = input.parse()?;
quote!(#path #group)
} else if input.peek(Token![=]) {
let eq_token: Token![=] = input.parse()?;
let value: Expr = input.parse()?;
quote!(#path #eq_token #value)
} else {
quote!(#path)
};
let msg = if path.is_ident("align") {
"aligned repr on struct that implements RefCast is not supported"
} else {
"unrecognized repr on struct that implements RefCast"
};
push_error(Error::new_spanned(meta_item_span, msg));
}
if !input.is_empty() {
input.parse::<Token![,]>()?;
}
}
Ok(())
}) {
push_error(error);
}
}
}
if !has_repr {
let mut requires_repr = Error::new(
Span::call_site(),
"RefCast trait requires #[repr(transparent)]",
);
if let Some(errors) = errors {
requires_repr.combine(errors);
}
errors = Some(requires_repr);
}
match errors {
None => Ok(()),
Some(errors) => Err(errors),
}
}
type Fields = Punctuated<Field, Token![,]>;
fn fields(input: &DeriveInput) -> Result<&Fields> {
use syn::Fields;
match &input.data {
Data::Struct(data) => match &data.fields {
Fields::Named(fields) => Ok(&fields.named),
Fields::Unnamed(fields) => Ok(&fields.unnamed),
Fields::Unit => Err(Error::new(
Span::call_site(),
"RefCast does not support unit structs",
)),
},
Data::Enum(_) => Err(Error::new(
Span::call_site(),
"RefCast does not support enums",
)),
Data::Union(_) => Err(Error::new(
Span::call_site(),
"RefCast does not support unions",
)),
}
}
fn only_field_ty(fields: &Fields) -> Result<&Type> {
let is_trivial = decide_trivial(fields)?;
let mut only_field = None;
for field in fields {
if !is_trivial(field)? {
if only_field.take().is_some() {
break;
}
only_field = Some(&field.ty);
}
}
only_field.ok_or_else(|| {
Error::new(
Span::call_site(),
"RefCast requires a struct with a single field",
)
})
}
/// Collects, in declaration order, the types of all fields classified as
/// trivial by the predicate that `decide_trivial` selects.
///
/// Stops at and returns the first error raised by that predicate.
fn trivial_fields(fields: &Fields) -> Result<Vec<&Type>> {
    let is_trivial = decide_trivial(fields)?;
    fields
        .iter()
        .filter_map(|field| match is_trivial(field) {
            Ok(true) => Some(Ok(&field.ty)),
            Ok(false) => None,
            Err(error) => Some(Err(error)),
        })
        .collect()
}
/// Chooses the predicate used to classify fields as trivial.
///
/// If any field carries an explicit `#[trivial]` attribute, only explicitly
/// marked fields count as trivial; otherwise the type-based heuristic in
/// `is_implicit_trivial` is used.
fn decide_trivial(fields: &Fields) -> Result<fn(&Field) -> Result<bool>> {
    let mut any_explicit = false;
    for field in fields {
        if is_explicit_trivial(field)? {
            any_explicit = true;
            break;
        }
    }

    let chosen: fn(&Field) -> Result<bool> = if any_explicit {
        is_explicit_trivial
    } else {
        is_implicit_trivial
    };
    Ok(chosen)
}
/// Type-based triviality heuristic: a field is trivial if its type is the
/// empty tuple `()`, or a path whose last segment is named `PhantomData` or
/// `PhantomPinned`.
///
/// Never fails; it returns `Result` only so that it has the same signature as
/// `is_explicit_trivial`.
#[allow(clippy::unnecessary_wraps)]
fn is_implicit_trivial(field: &Field) -> Result<bool> {
    let trivial = match &field.ty {
        Type::Tuple(tuple) => tuple.elems.is_empty(),
        Type::Path(type_path) => {
            let last = &type_path.path.segments.last().unwrap().ident;
            last == "PhantomData" || last == "PhantomPinned"
        }
        _ => false,
    };
    Ok(trivial)
}
/// Reports whether a field carries a `#[trivial]` attribute.
///
/// The first such attribute must be a bare path; a form with arguments or a
/// value (e.g. `#[trivial(...)]`) is an error.
fn is_explicit_trivial(field: &Field) -> Result<bool> {
    let marker = field
        .attrs
        .iter()
        .find(|attr| attr.path().is_ident("trivial"));
    match marker {
        Some(attr) => {
            attr.meta.require_path_only()?;
            Ok(true)
        }
        None => Ok(false),
    }
}