1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
use proc_macro::TokenStream;
use devise::{Spanned, Result, ext::SpanDiagnosticExt};

use crate::syn;

/// The pieces of a `#[database("name")] struct Name(ConnType);` invocation,
/// as extracted by `parse_invocation`.
#[derive(Debug)]
struct DatabaseInvocation {
    /// The attributes on the attributed structure.
    attrs: Vec<syn::Attribute>,
    /// The name of the structure on which `#[database(..)] struct This(..)` was invoked.
    type_name: syn::Ident,
    /// The visibility of the structure on which `#[database(..)] struct This(..)` was invoked.
    visibility: syn::Visibility,
    /// The database name as passed in via `#[database("database name")]`
    /// (the value of the string literal, without the quotes).
    db_name: String,
    /// The type inside the structure: `struct MyDb(ThisType)`.
    connection_type: syn::Type,
}

// Diagnostic text emitted by `parse_invocation` when the attribute is misused.

/// Help text attached to the `ONLY_UNNAMED_FIELDS` error.
const EXAMPLE: &str = "example: `struct MyDatabase(diesel::SqliteConnection);`";
/// Error used when the attributed item is not a struct.
const ONLY_ON_STRUCTS_MSG: &str = "`database` attribute can only be used on structs";
/// Error used when the struct is not a tuple struct with exactly one field.
const ONLY_UNNAMED_FIELDS: &str = "`database` attribute can only be applied to \
    structs with exactly one unnamed field";
/// Error used when the struct declares any generic parameters.
const NO_GENERIC_STRUCTS: &str = "`database` attribute cannot be applied to structs \
    with generics";

/// Parses a `#[database("name")] struct Name(ConnType);` invocation.
///
/// `attr` holds the attribute's arguments and must be a single string literal
/// (the database name). `input` holds the item the attribute was applied to.
///
/// # Errors
///
/// Returns a diagnostic, rather than panicking, when:
///
///   * `attr` is not a string literal;
///   * `input` cannot be parsed as a `syn::DeriveInput`;
///   * the item declares generic parameters;
///   * the item is not a struct;
///   * the struct does not have exactly one unnamed field.
fn parse_invocation(attr: TokenStream, input: TokenStream) -> Result<DatabaseInvocation> {
    let attr_stream2 = crate::proc_macro2::TokenStream::from(attr);
    let string_lit = crate::syn::parse2::<syn::LitStr>(attr_stream2)?;

    // Propagate a parse failure as a diagnostic (same conversion as the
    // `LitStr` parse above) so that applying the attribute to an unsupported
    // item reports an error instead of panicking the macro.
    let input = crate::syn::parse::<syn::DeriveInput>(input)?;
    if !input.generics.params.is_empty() {
        return Err(input.generics.span().error(NO_GENERIC_STRUCTS));
    }

    // Nothing is bound in the fallback arm, so `input` is still whole there
    // and can be used for the error span.
    let structure = match input.data {
        syn::Data::Struct(s) => s,
        _ => return Err(input.span().error(ONLY_ON_STRUCTS_MSG))
    };

    let inner_type = match structure.fields {
        syn::Fields::Unnamed(ref fields) if fields.unnamed.len() == 1 => {
            let first = fields.unnamed.first().expect("checked length");
            first.ty.clone()
        }
        _ => return Err(structure.fields.span().error(ONLY_UNNAMED_FIELDS).help(EXAMPLE))
    };

    Ok(DatabaseInvocation {
        attrs: input.attrs,
        type_name: input.ident,
        visibility: input.vis,
        db_name: string_lit.value(),
        connection_type: inner_type,
    })
}

/// Expands a `#[database("name")] struct Name(ConnType);` invocation.
///
/// Parses the invocation with `parse_invocation` and, on success, emits:
///
///   * the struct re-declared, keeping its attributes and visibility, as a
///     wrapper around `::rocket_sync_db_pools::Connection<Self, ConnType>`;
///   * an inherent `impl` with `fairing()`, `pool()`, `run()` and `get_one()`,
///     each forwarding to the same-named item on the pool or connection type;
///   * `FromRequest` and `Sentinel` impls that forward to the connection type.
///
/// # Errors
///
/// Returns the error from `parse_invocation` unchanged if parsing fails.
// NOTE(review): no non-snake-case identifier is visible in this function, so
// this `allow` may be a leftover; confirm before removing.
#[allow(non_snake_case)]
pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result<TokenStream> {
    let invocation = parse_invocation(attr, input)?;

    // Store everything we're going to need to generate code.
    let conn_type = &invocation.connection_type;
    let name = &invocation.db_name;
    let attrs = &invocation.attrs;
    let guard_type = &invocation.type_name;
    let vis = &invocation.visibility;
    let fairing_name = format!("'{}' Database Pool", name);
    // Span of the user-written connection type; the `quote_spanned!` pieces
    // below carry it (presumably so type errors point at that type).
    let span = conn_type.span();

    // A few useful paths.
    let root = quote_spanned!(span => ::rocket_sync_db_pools);
    let rocket = quote!(#root::rocket);

    // The re-declared struct: same attributes, visibility and name, but the
    // single field is replaced by the `Connection` wrapper type.
    let request_guard_type = quote_spanned! { span =>
        #(#attrs)* #vis struct #guard_type(#root::Connection<Self, #conn_type>);
    };

    // Pool and connection types keyed on the generated struct (`Self`).
    let pool = quote_spanned!(span => #root::ConnectionPool<Self, #conn_type>);
    let conn = quote_spanned!(span => #root::Connection<Self, #conn_type>);

    Ok(quote! {
        #request_guard_type

        impl #guard_type {
            /// Returns a fairing that initializes the database connection pool.
            pub fn fairing() -> impl #rocket::fairing::Fairing {
                <#pool>::fairing(#fairing_name, #name)
            }

            /// Returns an opaque type that represents the connection pool
            /// backing connections of type `Self`.
            pub fn pool<P: #rocket::Phase>(__rocket: &#rocket::Rocket<P>) -> Option<&#pool> {
                <#pool>::pool(&__rocket)
            }

            /// Runs the provided function `__f` in an async-safe blocking
            /// thread.
            pub async fn run<F, R>(&self, __f: F) -> R
                where F: FnOnce(&mut #conn_type) -> R + Send + 'static,
                      R: Send + 'static,
            {
                self.0.run(__f).await
            }

            /// Retrieves a connection of type `Self` from the `rocket` instance.
            pub async fn get_one<P: #rocket::Phase>(__rocket: &#rocket::Rocket<P>) -> Option<Self> {
                <#pool>::get_one(&__rocket).await.map(Self)
            }
        }

        #[#rocket::async_trait]
        impl<'r> #rocket::request::FromRequest<'r> for #guard_type {
            type Error = ();

            async fn from_request(
                __r: &'r #rocket::request::Request<'_>
            ) -> #rocket::request::Outcome<Self, ()> {
                <#conn>::from_request(__r).await.map(Self)
            }
        }

        impl #rocket::Sentinel for #guard_type {
            fn abort(__r: &#rocket::Rocket<#rocket::Ignite>) -> bool {
                <#conn>::abort(__r)
            }
        }
    }.into())
}