1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207
use std::io;
use std::future::Future;
use std::{pin::Pin, task::{Context, Poll}};
use tokio::io::{AsyncRead, ReadBuf};
use crate::http::CookieJar;
use crate::{Request, Response};
/// An `async` response from a dispatched [`LocalRequest`](super::LocalRequest).
///
/// This `LocalResponse` implements [`tokio::io::AsyncRead`]. As such, if
/// [`into_string()`](LocalResponse::into_string()) and
/// [`into_bytes()`](LocalResponse::into_bytes()) do not suffice, the response's
/// body can be read directly:
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// use std::io;
///
/// use rocket::local::asynchronous::Client;
/// use rocket::tokio::io::AsyncReadExt;
/// use rocket::http::Status;
///
/// #[get("/")]
/// fn hello_world() -> &'static str {
/// "Hello, world!"
/// }
///
/// #[launch]
/// fn rocket() -> _ {
/// rocket::build().mount("/", routes![hello_world])
/// # .configure(rocket::Config::debug_default())
/// }
///
/// # async fn read_body_manually() -> io::Result<()> {
/// // Dispatch a `GET /` request.
/// let client = Client::tracked(rocket()).await.expect("valid rocket");
/// let mut response = client.get("/").dispatch().await;
///
/// // Check metadata validity.
/// assert_eq!(response.status(), Status::Ok);
/// assert_eq!(response.body().preset_size(), Some(13));
///
/// // Read 10 bytes of the body. Note: in reality, we'd use `into_string()`.
/// let mut buffer = [0; 10];
/// response.read(&mut buffer).await?;
/// assert_eq!(buffer, "Hello, wor".as_bytes());
/// # Ok(())
/// # }
/// # rocket::async_test(read_body_manually()).expect("read okay");
/// ```
///
/// For more, see [the top-level documentation](../index.html#localresponse).
pub struct LocalResponse<'c> {
    // SAFETY: Struct fields are dropped in declaration order. `response` may
    // hold `'c` references into the boxed `Request` stored in `_request`, so
    // it MUST be declared (and therefore dropped) before `_request`. Do not
    // reorder these fields.
    /// The response produced for `_request`; may borrow from it.
    response: Response<'c>,
    /// Cookies collected from `response` when it was produced.
    cookies: CookieJar<'c>,
    /// Owner of the request that `response` may refer to. Boxed so that its
    /// address stays stable when `LocalResponse` is moved. Dropped last.
    _request: Box<Request<'c>>,
}
impl<'c> LocalResponse<'c> {
    /// Produces a future that runs `f` against a boxed copy of `req` and
    /// bundles the resulting `Response`, the cookies it sets, and the boxed
    /// request itself into a `LocalResponse`.
    ///
    /// `f` receives a `&'c Request<'c>` pointing into the box; the box is
    /// stored in the returned value so that the reference stays valid.
    pub(crate) fn new<F, O>(req: Request<'c>, f: F) -> impl Future<Output = LocalResponse<'c>>
        where F: FnOnce(&'c Request<'c>) -> O + Send,
              O: Future<Output = Response<'c>> + Send
    {
        // `LocalResponse` is self-referential: `response` may refer to
        // `_request` and its contents. Two things must therefore hold.
        //
        //   1) The `Request` has a stable address.
        //
        //      It is moved into a `Box` here and only ever accessed through
        //      that heap allocation afterwards.
        //
        //   2) No reference to the `Request` or its contents escapes with a
        //      lifetime longer than that of `&self`.
        //
        //      No method returns an `&Request`, and the `Response` is never
        //      exposed directly; otherwise, methods like `.headers()` could,
        //      together with a suitably crafted `Responder`, hand out a
        //      reference into the `Request`. Every accessor instead returns
        //      references bounded by `self`. This can be checked by noting
        //      that 1) the fields of `LocalResponse` are private, and 2) every
        //      `impl` of `LocalResponse` other than this one names the
        //      lifetime as `'_`, so it cannot appear in any output value.
        let owned: Box<Request<'c>> = Box::new(req);
        let raw: *const Request<'c> = &*owned;

        // SAFETY: `raw` points into the heap allocation owned by `owned`,
        // which is moved into the returned `LocalResponse` and so outlives
        // the `Response` built from this reference. See the notes above for
        // why the extended lifetime cannot leak to callers.
        let request: &'c Request<'c> = unsafe { &*raw };

        async move {
            let response: Response<'c> = f(request).await;

            // Snapshot every cookie the response sets as an "original" entry.
            let mut jar = CookieJar::new(request.rocket().config());
            response.cookies().into_iter().for_each(|cookie| {
                jar.add_original(cookie.into_owned());
            });

            LocalResponse { _request: owned, response, cookies: jar }
        }
    }
}
impl LocalResponse<'_> {
    /// Borrows the wrapped `Response`; the borrow is bounded by `&self`.
    pub(crate) fn _response(&self) -> &Response<'_> {
        &self.response
    }

    /// Borrows the cookie jar stored alongside the response; the borrow is
    /// bounded by `&self`.
    pub(crate) fn _cookies(&self) -> &CookieJar<'_> {
        &self.cookies
    }

    /// Consumes `self` and returns the result of the body's `to_string()`.
    pub(crate) async fn _into_string(mut self) -> io::Result<String> {
        self.response.body_mut().to_string().await
    }

    /// Consumes `self` and returns the result of the body's `to_bytes()`.
    pub(crate) async fn _into_bytes(mut self) -> io::Result<Vec<u8>> {
        self.response.body_mut().to_bytes().await
    }

    /// Consumes `self` and deserializes the body as JSON on a blocking task.
    ///
    /// Returns `None` if the blocking task could not be joined or if
    /// `serde_json::from_reader` reports an error.
    #[cfg(feature = "json")]
    async fn _into_json<T: Send + 'static>(self) -> Option<T>
        where T: serde::de::DeserializeOwned
    {
        self.blocking_read(|r| serde_json::from_reader(r)).await?.ok()
    }

    /// Consumes `self` and deserializes the body as MessagePack on a blocking
    /// task.
    ///
    /// Returns `None` if the blocking task could not be joined or if
    /// `rmp_serde::from_read` reports an error.
    #[cfg(feature = "msgpack")]
    async fn _into_msgpack<T: Send + 'static>(self) -> Option<T>
        where T: serde::de::DeserializeOwned
    {
        self.blocking_read(|r| rmp_serde::from_read(r)).await?.ok()
    }

    /// Runs the synchronous reader `f` over the response body.
    ///
    /// `f` is executed via `spawn_blocking` and is handed an `io::Read` that
    /// is fed, through a bounded channel, with chunks read asynchronously
    /// from `self`. A body read error is forwarded to `f` as the error of the
    /// corresponding `read()` call.
    ///
    /// If `f` returns before the body is exhausted, the rest of the body is
    /// discarded and the value `f` produced is still returned.
    ///
    /// Returns `None` only when the blocking task cannot be joined.
    #[cfg(any(feature = "json", feature = "msgpack"))]
    async fn blocking_read<T, F>(mut self, f: F) -> Option<T>
        where T: Send + 'static,
              F: FnOnce(&mut dyn io::Read) -> T + Send + 'static
    {
        use tokio::sync::mpsc;
        use tokio::io::AsyncReadExt;

        /// Blocking `io::Read` adapter over a channel of body chunks.
        struct ChanReader {
            /// The chunk currently being drained, if any.
            last: Option<io::Cursor<Vec<u8>>>,
            /// Source of further chunks (or of a body read error).
            rx: mpsc::Receiver<io::Result<Vec<u8>>>,
        }

        impl std::io::Read for ChanReader {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                loop {
                    // Serve from the current chunk while it has unread bytes.
                    if let Some(ref mut cursor) = self.last {
                        if cursor.position() < cursor.get_ref().len() as u64 {
                            return std::io::Read::read(cursor, buf);
                        }
                    }

                    // Otherwise block for the next chunk. A closed channel
                    // means the sender is gone: report end-of-stream.
                    if let Some(chunk) = self.rx.blocking_recv() {
                        self.last = Some(io::Cursor::new(chunk?));
                    } else {
                        return Ok(0);
                    }
                }
            }
        }

        let (tx, rx) = mpsc::channel(2);
        let reader = tokio::task::spawn_blocking(move || {
            let mut reader = ChanReader { last: None, rx };
            f(&mut reader)
        });

        loop {
            // TODO: Try to fill as much as the buffer before send it off?
            let mut buf = Vec::with_capacity(1024);
            match self.read_buf(&mut buf).await {
                Ok(0) => break,
                Ok(_) => {
                    // A failed send means the receiver was dropped, i.e. `f`
                    // has already returned. Stop producing, but fall through
                    // to collect its result instead of discarding it.
                    if tx.send(Ok(buf)).await.is_err() {
                        break;
                    }
                }
                Err(e) => {
                    // Best effort: if `f` is already done there is nobody
                    // left to notify, and its result is collected below.
                    let _ = tx.send(Err(e)).await;
                    break;
                }
            }
        }

        // NOTE: We _must_ drop tx now to prevent a deadlock!
        drop(tx);
        reader.await.ok()
    }

    // Generates the public API methods, which call the private methods above.
    pub_response_impl!("# use rocket::local::asynchronous::Client;\n\
        use rocket::local::asynchronous::LocalResponse;" async await);
}
impl AsyncRead for LocalResponse<'_> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(self.response.body_mut()).poll_read(cx, buf)
}
}
impl std::fmt::Debug for LocalResponse<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self._response().fmt(f)
}
}