diff --git a/src/run_query_dsl/mod.rs b/src/run_query_dsl/mod.rs index 18a9fda..9625b43 100644 --- a/src/run_query_dsl/mod.rs +++ b/src/run_query_dsl/mod.rs @@ -146,7 +146,7 @@ pub mod return_futures { /// /// This is essentially `impl Future>` pub type GetResult<'conn, 'query, Q: LoadQuery<'query, Conn, U>, Conn, U> = - utils::AndThen, utils::LoadNext>>>>; + utils::AndThen, utils::LoadNext>>, U>>; } /// Methods used to execute queries. diff --git a/src/run_query_dsl/utils.rs b/src/run_query_dsl/utils.rs index 22f5891..3d63927 100644 --- a/src/run_query_dsl/utils.rs +++ b/src/run_query_dsl/utils.rs @@ -3,7 +3,8 @@ use std::pin::Pin; use std::task::{Context, Poll}; use diesel::QueryResult; -use futures_core::{ready, TryFuture, TryStream}; +use futures_core::{TryFuture, TryStream}; +use futures_util::stream::TryCollect; use futures_util::{TryFutureExt, TryStreamExt}; // We use a custom future implementation here to erase some lifetimes @@ -80,33 +81,48 @@ where /// Converts a stream into a future, only yielding the first element. /// Based on [`futures_util::stream::StreamFuture`]. -pub struct LoadNext { - stream: Option, +/// +/// Consumes the entire stream to ensure proper cleanup before returning which is +/// required to fix: https://github.com/weiznich/diesel_async/issues/269 +#[repr(transparent)] +pub struct LoadNext +where + F: TryStream, +{ + future: TryCollect>, } -impl LoadNext { - pub(crate) fn new(stream: St) -> Self { +impl LoadNext +where + F: TryStream, +{ + pub(crate) fn new(stream: F) -> Self { Self { - stream: Some(stream), + future: stream.try_collect(), } } } -impl Future for LoadNext +impl Future for LoadNext where - St: TryStream + Unpin, + F: TryStream, + TryCollect>: Future, diesel::result::Error>>, { - type Output = QueryResult; + type Output = QueryResult; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let first = { - let s = self.stream.as_mut().expect("polling LoadNext twice"); - ready!(s.try_poll_next_unpin(cx)) - }; - self.stream = None; - match first { - Some(first) => Poll::Ready(first), - None => Poll::Ready(Err(diesel::result::Error::NotFound)), + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match unsafe { + // SAFETY: This projects pinning to the only inner field + self.map_unchecked_mut(|s| &mut s.future) + } + .poll(cx) + { + Poll::Ready(Ok(results)) => match results.into_iter().next() { + Some(first) => Poll::Ready(Ok(first)), + None => Poll::Ready(Err(diesel::result::Error::NotFound)), + }, + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, } } }