diff --git a/Cargo.lock b/Cargo.lock index 0d4cc94..eaf571d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,12 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "async-iterator" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "742b2f12ff517f144b6181d24f3f2481b503e05650ee79feec1f090048089f88" + [[package]] name = "atomic-waker" version = "1.1.2" @@ -99,6 +105,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" +[[package]] +name = "constcat" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "136d3e02915a2cea4d74caa8681e2d44b1c3254bdbf17d11d41d587ff858832c" + [[package]] name = "core-foundation" version = "0.9.4" @@ -664,9 +676,11 @@ dependencies = [ [[package]] name = "multipart_async_stream" -version = "0.1.1" +version = "0.2.2" dependencies = [ + "async-iterator", "bytes 1.10.1", + "constcat", "futures-util", "http", "httparse", @@ -1160,18 +1174,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.15" +version = "2.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80d76d3f064b981389ecb4b6b7f45a0bf9fdac1d5b9204c7bd6714fecc302850" +checksum = "3467d614147380f2e4e374161426ff399c91084acd2363eaf549172b3d5e60c0" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "2.0.15" +version = "2.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44d29feb33e986b6ea906bd9c3559a856983f92371b3eaa5e83782a351623de0" +checksum = "6c5e1be1c48b9172ee610da68fd9cd2770e7a4056cb3fc98710ee6906f0c7960" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 37e7edd..94383ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "multipart_async_stream" -version = "0.1.1" +version = "0.2.2" edition = "2024" license = "MIT OR Apache-2.0" description = "An easy-to-use, efficient, and asynchronous multipart stream parser." @@ -14,12 +14,14 @@ categories = [ ] [dependencies] +async-iterator = "2.3.0" bytes = "1.10.1" +constcat = "0.6.1" futures-util = { version = "0.3.31", features = ["tokio-io"] } http = "1.3.1" httparse = "1.10.1" memchr = "2.7.5" -thiserror = "2.0.15" +thiserror = "2.0.16" [dev-dependencies] reqwest = { version = "0.12.23", features = ["stream"] } diff --git a/README.md b/README.md index 0acf9c1..26b1b4c 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,6 @@ # Multipart Stream -![alt text](https://img.shields.io/crates/v/multipart_async_stream.svg) - - -![alt text](https://docs.rs/multipart_async_stream/badge.svg) - - -![alt text](https://github.com/OpenTritium/multipart_stream/actions/workflows/ci.yaml/badge.svg) +![alt text](https://img.shields.io/crates/v/multipart_async_stream.svg) ![alt text](https://docs.rs/multipart_async_stream/badge.svg) ![alt text](https://github.com/OpenTritium/multipart_stream/actions/workflows/ci.yaml/badge.svg) This library is designed as an adapter for `futures_util::TryStream`, allowing for easy parsing of an incoming byte stream (such as from an HTTP response) and splitting it into multiple parts (`Part`). It is especially useful for handling `multipart/byteranges` HTTP responses. @@ -14,17 +8,13 @@ A common use case is sending an HTTP Range request to a server and then parsing The example below demonstrates how to use reqwest to download multiple ranges of a file and parse the individual parts using multipart_stream. ```rust -use http::header::CONTENT_TYPE; +use multipart_async_stream::{LendingIterator, MultipartStream, TryStreamExt, header::CONTENT_TYPE}; #[tokio::main] async fn main() { const URL: &str = "https://mat1.gtimg.com/pingjs/ext2020/newom/build/static/images/new_logo.png"; let client = reqwest::Client::new(); - let response = client - .get(URL) - .header("Range", "bytes=0-32,64-128") - .send() - .await.unwrap(); + let response = client.get(URL).header("Range", "bytes=0-31,64-127").send().await.unwrap(); let boundary = response .headers() .get(CONTENT_TYPE) @@ -33,9 +23,23 @@ async fn main() { .and_then(|s| s.split("boundary=").nth(1)) .map(|s| s.trim().as_bytes().to_vec().into_boxed_slice()); let s = response.bytes_stream(); - let mut m = multipart_async_stream::MultipartStream::new(s, &boundary.unwrap()); - while let Ok(x) = m.try_next().await { - println!("Part: {x:?}"); + let mut m = MultipartStream::new(s, &boundary.unwrap()); + + while let Some(Ok(part)) = m.next().await { + println!("{:?}", part.headers()); + let mut body = part.body(); + while let Ok(Some(b)) = body.try_next().await { + println!("{:?}", b); + } } } +``` + +The output of the program above is: + +```bash +{"content-type": "image/png", "content-range": "bytes 0-31/10845"} +b"\x89PNG\r\n\x1a\n\0\0\0\rIHDR\0\0\0\xf4\0\0\0B\x08\x06\0\0\0`\xbc\xfb" +{"content-type": "image/png", "content-range": "bytes 64-127/10845"} +b"L:com.adobe.xmp\0\0\0\0\0), #[error("parse error: {0}")] ParseError(#[from] ParseError), + #[error("body stream is not consumed")] + BodyNotConsumed, } /// 表示解析器当前所处的状态 #[derive(Debug)] enum ParserState { - Preamble(usize), // 找到头的边界,移动缓冲区指针至 hdr 初始位置 - ReadingHeaders(usize), // 正在读取头的内容 - ReadingBody { headers: Box, scan: usize }, // 正在读取 body + Preamble(usize), // 找到头的边界,移动缓冲区指针至 hdr 初始位置 + ReadingHeaders(usize), // 正在读取头的内容 + StreamingBody(usize), /* 移动最后一个窗口的内容,下次拼接头以后再判断, + * 但是这样还是要拷贝 */ Finished, } #[derive(Error, Debug)] pub enum ParseError { #[error(transparent)] - Try(#[from] httparse::Error), + Other(#[from] httparse::Error), #[error("buffer no cahnge")] BufferNoChange, - #[error("")] + #[error("incomplete headers content")] TryParsePartial, } -enum ParseResult { - Partial, - Full(Part), - Failed(ParseError), - Completed, -} - -impl From for ParseResult { - fn from(err: httparse::Error) -> Self { ParseResult::Failed(ParseError::Try(err)) } -} - +const CRLF: &[u8] = b"\r\n"; +const DOUBLE_HYPEN: &[u8] = b"--"; pub struct MultipartStream where S: TryStream + Unpin, @@ -56,8 +56,7 @@ where rx: S, terminated: bool, state: ParserState, - boundary_pattern: Box<[u8]>, - boundary_finder: Finder<'static>, + pattern: Box<[u8]>, // `\r\n -- boundary \r\n` header_body_splitter_finder: Finder<'static>, header_body_splitter_len: usize, buf: BytesMut, @@ -69,191 +68,235 @@ where S::Error: std::error::Error + Send + Sync + 'static, { pub fn new(stream: S, boundary: &[u8]) -> Self { - let mut buf = Vec::with_capacity(boundary.len() + 2); - buf.extend_from_slice(b"--"); - buf.extend_from_slice(boundary); - let boundary_pattern = buf.into_boxed_slice(); - let boundary_finder = unsafe { - // 我们知道 'pattern' 会和 'Self' 实例活得一样久, - // 所以将它的生命周期转换为 'static' 在这里是安全的。 - let static_pattern: &'static [u8] = mem::transmute(&*boundary_pattern); - Finder::new(static_pattern) - }; - const HEADER_BODY_SPLITTER: &[u8] = b"\r\n\r\n"; + let mut pattern = Vec::with_capacity(boundary.len() + 2 * CRLF.len() + 2 * DOUBLE_HYPEN.len()); + pattern.extend_from_slice(CRLF); + pattern.extend_from_slice(DOUBLE_HYPEN); + pattern.extend_from_slice(boundary); + pattern.extend_from_slice(CRLF); + const HEADER_BODY_SPLITTER: &[u8] = concat_bytes!(CRLF, CRLF); Self { rx: stream, terminated: false, state: ParserState::Preamble(0), buf: BytesMut::new(), - boundary_finder, header_body_splitter_finder: Finder::new(HEADER_BODY_SPLITTER), header_body_splitter_len: HEADER_BODY_SPLITTER.len(), - boundary_pattern, + pattern: pattern.into(), } } - /// 从缓冲区中解析 Part,副作用是会消耗 buf,更改内部 state - fn parse_buf(&mut self) -> ParseResult { - use ParseResult::*; + fn update_scan(&mut self, new_scan: usize) { use ParserState::*; - let state = &mut self.state; - let buf = &mut self.buf; - match state { - &mut Preamble(scan) => { - if buf.len() < self.boundary_pattern.len() + scan { - // 还没有足够的字节来匹配边界 - return Partial; - } - if let Some(pos) = self.boundary_finder.find(&buf[scan..]) { - let total_advance_len = scan + pos + self.boundary_pattern.len() + 2; // +2 是因为边界和 headers 间有一个 `\r\n` - if buf.len() < total_advance_len { - // 找到了 boundary,但是还需要判断后续接收是否还有两个字节 - return Partial; - } - buf.advance(pos + self.boundary_pattern.len() + 2); - *state = ReadingHeaders(0); - } else { - // 扫描只会进行到最后一个满足窗口大小的窗口,所以将 scan 指定到最后满足最后一个窗口的位置之后 - let new_pos = buf.len() - self.boundary_pattern.len() + 1; - if new_pos == scan { - return Failed(ParseError::BufferNoChange); - } - *state = Preamble(new_pos); - }; - Partial + match &mut self.state { + Preamble(scan) | ReadingHeaders(scan) | StreamingBody(scan) => { + debug_assert!(new_scan > *scan); + *scan = new_scan } - &mut ReadingHeaders(scan) => { - // RFC 2046 规定了使用 CRLF - if buf.len() < self.header_body_splitter_len + scan { - // 还没有足够的字节来匹配边界 - return Partial; - } - if let Some(pos) = self.header_body_splitter_finder.find(&buf[scan..]) { - let offset = scan + pos + self.header_body_splitter_len; - let hdrs_contnet = buf.split_to(offset).freeze(); - let mut hdrs_buf = [EMPTY_HEADER; 64]; - match parse_headers(&hdrs_contnet, &mut hdrs_buf) { - Ok(Status::Complete(_)) => {} - Ok(Status::Partial) => return Failed(ParseError::TryParsePartial), - Err(err) => return Failed(err.into()), - } - let headers = hdrs_buf - .iter() - .take_while(|hdr| hdr.name.is_empty().not()) - .filter_map(|hdr| { - let name = HeaderName::from_str(hdr.name); - let value = HeaderValue::from_bytes(hdr.value); - name.ok().zip(value.ok()) - }) - .collect::(); - *state = ReadingBody { headers: Box::new(headers), scan: 0 }; - } else { - // 指定新的待扫描位置,依然是刚好最后一个窗口之后 - let new_pos = buf.len() - self.header_body_splitter_len + 1; - if new_pos == scan { - return Failed(ParseError::BufferNoChange); - } - *state = ReadingHeaders(new_pos); - }; - Partial - } - &mut ReadingBody { ref mut headers, scan } => { - if buf.len() < self.boundary_pattern.len() + scan { - // 还没有足够的字节来匹配边界 - return Partial; - } - if let Some(pos) = self.boundary_finder.find(&buf[scan..]) { - let body_end = scan + pos; - let tail = { - // 匹配 `-- boundary --` 最后的两根 `-` - let pos = body_end + self.boundary_pattern.len(); - buf.get(pos..pos + 2) - }; - let mut split_part = |buf: &mut BytesMut| { - println!("{body_end}"); - let body = buf.split_to(body_end.saturating_sub(2)).freeze(); // forget CRLF and handle on empty body - Part::new(mem::take(headers), body) + Finished => unreachable!("cannot invoke add_scan on finished state"), /* 几乎不可能会在完成状态继续update + * scan */ + } + } + + // 在处于 非 Streaming body 状态下均返回 none + fn poll_next_body_chunk(&mut self, cx: &mut Context<'_>) -> Poll>> { + use ParserState::*; + use Poll::*; + let pattern_len = self.pattern.len(); + let sub_pattern_len = pattern_len - 2; + loop { + let prev_buf_len = self.buf.len(); + let scan = match self.state { + Preamble(_) | ReadingHeaders(_) | Finished => return Ready(None), + StreamingBody(scan) => scan, + }; + if prev_buf_len >= pattern_len + scan { + // \r\n--boundary (\r\n | --) + if let Some(pos) = Finder::new(&self.pattern[..sub_pattern_len]).find(&self.buf[scan..]) { + let pattern_start = scan + pos; + let pattern_tail = { + let pos = pattern_start + sub_pattern_len; + self.buf.get(pos..pos + 2) }; - match tail { - Some(b"--") => { - let part = split_part(buf); - // 匹配到结束边界以后就可以将状态设置为完成了,下次调用此函数会返回 Completed - *state = Finished; - Full(part) + match pattern_tail { + Some(CRLF) => { + // multipart 的流没有结束,开始下一个 part headers 的解析, + // 此时立刻调用此函数只会返回 none + self.state = Preamble(0); + let chunk = self.buf.split_to(pattern_start).freeze(); + return Ready(Some(Ok(chunk))); + } + Some(DOUBLE_HYPEN) => { + // multipart 流已经结束,同时意味着也不会有 body 流了 + // 下次调用此函数只会返回 none + self.state = Finished; + let chunk = self.buf.split_to(pattern_start).freeze(); + self.buf.clear(); // 跳过 `-- boundary --` + return Ready(Some(Ok(chunk))); } - // 没有到真正的结尾 Some(_) => { - let part = split_part(buf); - *state = Preamble(0); - Full(part) + // 恰好有和模式一样的内容在 body 中 + let new_scan = self.buf.len() - sub_pattern_len + 1; + if new_scan == scan { + return Ready(Some(Err(ParseError::BufferNoChange.into()))); + } + self.update_scan(new_scan); } - // 把后面的两字节接收了再来判断 - None => Partial, + // 继续接收来判断后两个字节 + None => {} } } else { - let new_pos = buf.len() - self.boundary_pattern.len() + 1; - if new_pos == scan { - return Failed(ParseError::BufferNoChange); + let new_scan = self.buf.len() - sub_pattern_len + 1; + if new_scan == scan { + return Ready(Some(Err(ParseError::BufferNoChange.into()))); } - *state = ReadingBody { headers: mem::take(headers), scan: new_pos }; - Partial + self.update_scan(new_scan); } } - Finished => Completed, + + // streaming body 状态下,终止则返回早终止 + if self.terminated && self.buf.len() == prev_buf_len { + return Ready(Some(Err(Error::EarlyTerminate))); + } + return match self.rx.try_poll_next_unpin(cx) { + Ready(Some(Ok(chunk))) => { + self.buf.extend_from_slice(&chunk); + continue; + } + Ready(Some(Err(err))) => Ready(Some(Err(Error::StreamError(Box::new(err))))), + Ready(None) => { + self.terminated = true; + continue; + } + Pending => Pending, + }; } } - pub async fn try_next(&mut self) -> Result { + fn poll_next_part(&'_ mut self, cx: &mut Context<'_>) -> Poll, Error>>> { loop { - use ParseResult::*; - // 当流没有终止时才接收,不然会阻塞 - if self.terminated.not() { - // 尝试填充缓冲区 - match self.rx.try_next().await { - Ok(Some(payload)) => { - self.buf.extend_from_slice(&payload); - } - Ok(None) => { - // 流终止 - self.terminated = true; + use ParserState::*; + use Poll::*; + let prev_buf_len = self.buf.len(); + let pattern_no_crlf_len = self.pattern.len() - 2; + let pattern_no_start_crlf = &self.pattern[2..]; + match self.state { + // --boundary\r\n + Preamble(scan) if prev_buf_len >= pattern_no_crlf_len + scan => { + if let Some(pos) = Finder::new(pattern_no_start_crlf).find(&self.buf[scan..]) { + let total_advance_len = scan + pos + pattern_no_crlf_len; + // 如果 advance 长度大于 当前缓冲区长度就继续接收 + if self.buf.len() >= total_advance_len { + self.buf.advance(total_advance_len); + self.state = ReadingHeaders(0); + } + } else { + // 扫描只会进行到最后一个满足窗口大小的窗口,所以将 scan 指定到最后满足最后一个窗口的位置之后 + let new_scan = prev_buf_len - pattern_no_crlf_len + 1; + if new_scan == scan { + return Ready(Some(Err(ParseError::BufferNoChange.into()))); + } + self.update_scan(new_scan); } - Err(err) => return Err(Error::StreamError(Box::new(err))), } + // CRLFCRLF + ReadingHeaders(scan) if prev_buf_len >= self.header_body_splitter_len + scan => { + if let Some(pos) = self.header_body_splitter_finder.find(&self.buf[scan..]) { + let hdrs_end = scan + pos + self.header_body_splitter_len; + let hdrs_contnet = &self.buf[..hdrs_end]; // 两个 CRLF 也要纳入解析 + let mut hdrs_buf = [EMPTY_HEADER; 64]; + match parse_headers(hdrs_contnet, &mut hdrs_buf) { + Ok(Status::Complete(_)) => {} + Ok(Status::Partial) => return Ready(Some(Err(ParseError::TryParsePartial.into()))), + Err(err) => return Ready(Some(Err(ParseError::Other(err).into()))), + } + let headers = hdrs_buf + .iter() + .take_while(|hdr| hdr.name.is_empty().not()) + .filter_map(|hdr| { + let name = HeaderName::from_str(hdr.name); + let value = HeaderValue::from_bytes(hdr.value); + name.ok().zip(value.ok()) + }) + .collect::(); + self.buf.advance(hdrs_end); + self.state = StreamingBody(0); + return Ready(Some(Ok(Part::new(self, headers.into())))); + } else { + // 指定新的待扫描位置,依然是刚好最后一个窗口之后 + let new_scan = self.buf.len() - self.header_body_splitter_len + 1; + if new_scan == scan { + return Ready(Some(Err(ParseError::BufferNoChange.into()))); + } + self.update_scan(new_scan); + }; + } + Finished => return Ready(None), + StreamingBody(_) => return Ready(Some(const { Err(Error::BodyNotConsumed) })), + _ => {} } - let prev_buf_len = self.buf.len(); - match self.parse_buf() { - Full(part) => return Ok(part), - Completed => { - return Err(Error::Eof); + if self.terminated && self.buf.len() == prev_buf_len { + return Ready(Some(Err(Error::EarlyTerminate))); + } + return match self.rx.try_poll_next_unpin(cx) { + Ready(Some(Ok(chunk))) => { + self.buf.extend_from_slice(&chunk); + continue; } - Failed(err) => { - return Err(err.into()); + Ready(Some(Err(err))) => Ready(Some(Err(Error::StreamError(Box::new(err))))), + Ready(None) => { + self.terminated = true; + continue; } - Partial => { - if self.terminated && self.buf.len() == prev_buf_len { - return Err(Error::EarlyTerminate); - } - } // 只知道它没解析完,但是后面可能再循环几次就处理完了 - } + Pending => Pending, + }; } } } -#[derive(Debug, Clone)] -pub struct Part { +impl LendingIterator for MultipartStream +where + S: TryStream + Unpin, + S::Error: std::error::Error + Send + Sync + 'static, +{ + type Item<'a> + = Result, Error> + where + S: 'a; + + async fn next(&mut self) -> Option> { + let this = self as *mut Self; + let result = futures_util::future::poll_fn(move |cx| { + let this = unsafe { &mut *this }; + this.poll_next_part(cx) + }) + .await; + unsafe { mem::transmute::, Error>>, Option, Error>>>(result) } + } +} + +pub struct Part<'a, S> +where + S: TryStream + Unpin, + S::Error: std::error::Error + Send + Sync + 'static, +{ + body: &'a mut MultipartStream, headers: Box, - body: Bytes, } -impl Part { - #[inline(always)] - fn new(headers: Box, body: Bytes) -> Self { Self { headers, body } } +impl<'a, S> Part<'a, S> +where + S: TryStream + Unpin, + S::Error: std::error::Error + Send + Sync + 'static, +{ + fn new(stream: &'a mut MultipartStream, headers: Box) -> Self { Self { body: stream, headers } } - #[inline(always)] pub fn headers(&self) -> &HeaderMap { &self.headers } - #[inline(always)] - pub fn body(&self) -> &Bytes { &self.body } + pub fn into_headers(self) -> HeaderMap { *self.headers } + + pub fn body(self) -> impl TryStream + 'a { + futures_util::stream::poll_fn(move |cx| self.body.poll_next_body_chunk(cx)) + } } #[cfg(test)] @@ -270,6 +313,15 @@ mod tests { stream::iter(chunks) } + async fn concat_body(s: impl TryStream) -> Vec { + s.try_fold(vec![], |mut acc, chunk| async move { + acc.extend_from_slice(&chunk); + Ok(acc) + }) + .await + .unwrap() + } + #[tokio::test] async fn test_single_part_full_chunk() { const BOUNDARY: &str = "boundary"; @@ -279,97 +331,92 @@ Content-Disposition: form-data; name=\"field1\"\r\n\ \r\n\ value1\r\n\ --boundary--\r\n"; - let stream = create_stream_from_chunks(CONTENT, CONTENT.len()); - let mut multipart_stream = MultipartStream::new(stream, BOUNDARY.as_bytes()); - // 解析第一个部分 - // let x = multipart_stream.try_next().await; - let part = multipart_stream.try_next().await.unwrap(); - assert_eq!(part.headers().get("content-disposition").unwrap(), "form-data; name=\"field1\""); - assert_eq!(part.body(), &Bytes::from_static(b"value1")); - - // 应该已经到达流的末尾 - let result = multipart_stream.try_next().await; - assert!(matches!(result, Err(Error::Eof))); + let mut m = MultipartStream::new(stream, BOUNDARY.as_bytes()); + while let Some(Ok(part)) = m.next().await { + assert_eq!(part.headers().get("content-disposition").unwrap(), "form-data; name=\"field1\""); + assert_eq!(&concat_body(part.body()).await, b"value1") + } } #[tokio::test] async fn test_multiple_parts_small_chunks() { const BOUNDARY: &str = "X-BOUNDARY"; const BODY: &[u8] = b"\ ---X-BOUNDARY\r\n\ -Content-Disposition: form-data; name=\"field1\"\r\n\ -\r\n\ -value1\r\n\ ---X-BOUNDARY\r\n\ -Content-Disposition: form-data; name=\"field2\"\r\n\ -Content-Type: text/plain\r\n\ -\r\n\ -value2 with CRLF\r\n\r\n\ ---X-BOUNDARY--\r\n"; + --X-BOUNDARY\r\n\ + Content-Disposition: form-data; name=\"field1\"\r\n\ + \r\n\ + value1\r\n\ + --X-BOUNDARY\r\n\ + Content-Disposition: form-data; name=\"field2\"\r\n\ + Content-Type: text/plain\r\n\ + \r\n\ + value2 with CRLF\r\n\r\n\ + --X-BOUNDARY--\r\n"; // 使用一个很小的块大小来强制测试缓冲逻辑 let stream = create_stream_from_chunks(BODY, 5); let mut multipart_stream = MultipartStream::new(stream, BOUNDARY.as_bytes()); // 解析第一部分 - let part1 = multipart_stream.try_next().await.unwrap(); + let part1 = multipart_stream.next().await.unwrap().unwrap(); assert_eq!(part1.headers().get("content-disposition").unwrap(), "form-data; name=\"field1\""); assert!(!part1.headers().contains_key("content-type")); - assert_eq!(part1.body(), &Bytes::from_static(b"value1")); + assert_eq!(&concat_body(part1.body()).await, b"value1"); // 解析第二部分 - let part2 = multipart_stream.try_next().await.unwrap(); + let part2 = multipart_stream.next().await.unwrap().unwrap(); assert_eq!(part2.headers().get("content-disposition").unwrap(), "form-data; name=\"field2\""); assert_eq!(part2.headers().get("content-type").unwrap(), "text/plain"); - assert_eq!(part2.body(), &Bytes::from_static(b"value2 with CRLF\r\n")); - + let body = concat_body(part2.body()).await; + assert_eq!(&body, b"value2 with CRLF\r\n"); // 应该已经到达流的末尾 - let result = multipart_stream.try_next().await; - assert!(matches!(result, Err(Error::Eof))); + let result = multipart_stream.next().await; + assert!(result.is_none()); } #[tokio::test] async fn test_with_preamble_and_no_final_crlf() { const BOUNDARY: &str = "boundary"; const BODY: &[u8] = b"\ -This is a preamble and should be ignored.\r\n\ ---boundary\r\n\ -Content-Disposition: form-data; name=\"field1\"\r\n\ -\r\n\ -value1\r\n\ ---boundary--"; // 注意:末尾没有 `\r\n` + This is a preamble and should be ignored.\r\n\ + --boundary\r\n\ + Content-Disposition: form-data; name=\"field1\"\r\n\ + \r\n\ + value1\r\n\ + --boundary--"; // 注意:末尾没有 `\r\n` let stream = create_stream_from_chunks(BODY, BODY.len()); let mut multipart_stream = MultipartStream::new(stream, BOUNDARY.as_bytes()); // 解析第一个部分 - // let _ = multipart_stream.try_next().await; - let part = multipart_stream.try_next().await.unwrap(); + let part = multipart_stream.next().await.unwrap().unwrap(); assert_eq!(part.headers().get("content-disposition").unwrap(), "form-data; name=\"field1\""); - assert_eq!(part.body(), &Bytes::from_static(b"value1")); - + let body = concat_body(part.body()).await; + assert_eq!(&body, b"value1"); // 应该已经到达流的末尾 - let result = multipart_stream.try_next().await; - assert!(matches!(result, Err(Error::Eof))); + let result = multipart_stream.next().await; + assert!(result.is_none()); } #[tokio::test] + #[should_panic(expected = "EarlyTerminate")] async fn test_early_terminate_in_body() { const BOUNDARY: &str = "boundary"; // 消息在 body 中被截断,没有结束边界 const BODY: &[u8] = b"\ ---boundary\r\n\ -Content-Disposition: form-data; name=\"field1\"\r\n\ -\r\n\ -value1 is not complete"; + --boundary\r\n\ + Content-Disposition: form-data; name=\"field1\"\r\n\ + \r\n\ + value1 is not complete"; let stream = create_stream_from_chunks(BODY, BODY.len()); let mut multipart_stream = MultipartStream::new(stream, BOUNDARY.as_bytes()); // 解析应该会失败,因为流在找到下一个边界前就终止了 - let result = multipart_stream.try_next().await; - assert!(matches!(result, Err(Error::EarlyTerminate))); + let part = multipart_stream.next().await.unwrap().unwrap(); + + let _ = concat_body(part.body()).await; } #[tokio::test] @@ -377,15 +424,15 @@ value1 is not complete"; const BOUNDARY: &str = "boundary"; // 消息在 headers 中被截断 const BODY: &[u8] = b"\ ---boundary\r\n\ -Content-Disposition: form-data; na"; + --boundary\r\n\ + Content-Disposition: form-data; na"; let stream = create_stream_from_chunks(BODY, BODY.len()); let mut multipart_stream = MultipartStream::new(stream, BOUNDARY.as_bytes()); // 解析应该会失败,因为流在 headers 结束前就终止了 - let result = multipart_stream.try_next().await; - assert!(matches!(result, Err(Error::EarlyTerminate))); + let result = multipart_stream.next().await; + assert!(matches!(result, Some(Err(Error::EarlyTerminate)))); } #[tokio::test] @@ -397,41 +444,112 @@ Content-Disposition: form-data; na"; let mut multipart_stream = MultipartStream::new(stream, BOUNDARY.as_bytes()); // 对于空流,应该提前终止 - let result = multipart_stream.try_next().await; - assert!(matches!(result, Err(Error::EarlyTerminate))); + let result = multipart_stream.next().await; + assert!(matches!(result, Some(Err(Error::EarlyTerminate)))); } + #[tokio::test] async fn test_part_with_empty_body() { const BOUNDARY: &str = "boundary"; const BODY: &[u8] = b"\ + --boundary\r\n\ + Content-Disposition: form-data; name=\"field1\"\r\n\ + \r\n\ + value1\r\n\ + --boundary\r\n\ + Content-Disposition: form-data; name=\"empty_field\"\r\n\ + \r\n\ + \r\n\ + --boundary\r\n\ + Content-Disposition: form-data; name=\"field2\"\r\n\ + \r\n\ + value2\r\n\ + --boundary--\r\n"; + + let stream = create_stream_from_chunks(BODY, 15); // Usar chunks pequeños + let mut multipart_stream = MultipartStream::new(stream, BOUNDARY.as_bytes()); + let part1 = multipart_stream.next().await.unwrap().unwrap(); + assert_eq!(part1.headers().get("content-disposition").unwrap(), "form-data; name=\"field1\""); + assert_eq!(&concat_body(part1.body()).await, b"value1"); + + let part2 = multipart_stream.next().await.unwrap().unwrap(); + assert_eq!(part2.headers().get("content-disposition").unwrap(), "form-data; name=\"empty_field\""); + let body = concat_body(part2.body()).await; + assert!(body.is_empty()); + + let part3 = multipart_stream.next().await.unwrap().unwrap(); + assert_eq!(part3.headers().get("content-disposition").unwrap(), "form-data; name=\"field2\""); + assert_eq!(&concat_body(part3.body()).await, b"value2"); + + let result = multipart_stream.next().await; + assert!(result.is_none()); + } + + #[tokio::test] + async fn test_body_not_consumed_error() { + const BOUNDARY: &str = "boundary"; + const BODY: &[u8] = b"\ --boundary\r\n\ Content-Disposition: form-data; name=\"field1\"\r\n\ \r\n\ value1\r\n\ --boundary\r\n\ -Content-Disposition: form-data; name=\"empty_field\"\r\n\ -\r\n\ ---boundary\r\n\ Content-Disposition: form-data; name=\"field2\"\r\n\ \r\n\ value2\r\n\ --boundary--\r\n"; - let stream = create_stream_from_chunks(BODY, 15); // Usar chunks pequeños - let mut multipart_stream = MultipartStream::new(stream, BOUNDARY.as_bytes()); - let part1 = multipart_stream.try_next().await.unwrap(); - assert_eq!(part1.headers().get("content-disposition").unwrap(), "form-data; name=\"field1\""); - assert_eq!(part1.body(), &Bytes::from_static(b"value1")); + let stream = create_stream_from_chunks(BODY, BODY.len()); + let mut m = MultipartStream::new(stream, BOUNDARY.as_bytes()); - let part2 = multipart_stream.try_next().await.unwrap(); - assert_eq!(part2.headers().get("content-disposition").unwrap(), "form-data; name=\"empty_field\""); - assert!(part2.body().is_empty()); + // Obtener la primera parte, pero no consumir su cuerpo + let _part1 = m.next().await.unwrap().unwrap(); - let part3 = multipart_stream.try_next().await.unwrap(); - assert_eq!(part3.headers().get("content-disposition").unwrap(), "form-data; name=\"field2\""); - assert_eq!(part3.body(), &Bytes::from_static(b"value2")); + // Intentar obtener la siguiente parte inmediatamente debe fallar + let result = m.next().await; + assert!(matches!(result, Some(Err(Error::BodyNotConsumed)))); + } - let result = multipart_stream.try_next().await; - assert!(matches!(result, Err(Error::Eof))); + #[tokio::test] + async fn test_boundary_like_string_in_body() { + const BOUNDARY: &str = "boundary"; + const BODY: &[u8] = b"\ +--boundary\r\n\ +Content-Disposition: form-data; name=\"field1\"\r\n\ +\r\n\ +value1 contains --boundary text\r\n\ +--boundary--\r\n"; + + let stream = create_stream_from_chunks(BODY, 20); + let mut m = MultipartStream::new(stream, BOUNDARY.as_bytes()); + + let part = m.next().await.unwrap().unwrap(); + let body = concat_body(part.body()).await; + assert_eq!(&body, b"value1 contains --boundary text"); + + assert!(m.next().await.is_none()); + } + + #[tokio::test] + async fn test_malformed_headers() { + const BOUNDARY: &str = "boundary"; + // "Invalid Header" contiene un espacio, lo cual es ilegal para un nombre de cabecera. + const BODY: &[u8] = b"\ +--boundary\r\n\ +Invalid Header: value\r\n\ +\r\n\ +body\r\n\ +--boundary--\r\n"; + + let stream = create_stream_from_chunks(BODY, BODY.len()); + let mut m = MultipartStream::new(stream, BOUNDARY.as_bytes()); + + let result = m.next().await.unwrap(); + assert!(matches!(result, Err(Error::ParseError(_)))); + if let Err(Error::ParseError(ParseError::Other(e))) = result { + assert_eq!(e, httparse::Error::HeaderName); + } else { + panic!("Expected a ParseError::Other with InvalidHeaderName"); + } } }