Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 26 additions & 16 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ impl<C: Connector> EppConnection<C> {
cx: &mut Context<'_>,
) -> Result<Transition, Error> {
match &mut state {
RequestState::Writing { mut start, buf } => {
let wrote = match Pin::new(&mut self.stream).poll_write(cx, &buf[start..]) {
RequestState::Writing { start, buf } => {
let wrote = match Pin::new(&mut self.stream).poll_write(cx, &buf[*start..]) {
Poll::Ready(Ok(wrote)) => wrote,
Poll::Ready(Err(err)) => return Err(err.into()),
Poll::Pending => return Ok(Transition::Pending(state)),
Expand All @@ -121,7 +121,7 @@ impl<C: Connector> EppConnection<C> {
.into());
}

start += wrote;
let start = *start + wrote;
debug!(
"{}: Wrote {} bytes, {} out of {} done",
self.registry,
Expand All @@ -133,16 +133,19 @@ impl<C: Connector> EppConnection<C> {
// Transition to reading the response's frame header once
// we've written the entire request
if start < buf.len() {
return Ok(Transition::Next(state));
return Ok(Transition::Next(RequestState::Writing {
start,
buf: mem::take(buf),
}));
}

Ok(Transition::Next(RequestState::ReadLength {
read: 0,
buf: vec![0; 256],
}))
}
RequestState::ReadLength { mut read, buf } => {
let mut read_buf = ReadBuf::new(&mut buf[read..]);
RequestState::ReadLength { read, buf } => {
let mut read_buf = ReadBuf::new(&mut buf[*read..]);
match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) {
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(err)) => return Err(err.into()),
Expand All @@ -162,26 +165,29 @@ impl<C: Connector> EppConnection<C> {
// The frame header is a 32-bit (4-byte) big-endian unsigned integer. If we don't
// have 4 bytes yet, stay in the `ReadLength` state, otherwise we transition to `Reading`.

read += filled.len();
if read < 4 {
return Ok(Transition::Next(state));
let new_read = *read + filled.len();
if new_read < 4 {
return Ok(Transition::Next(RequestState::ReadLength {
read: new_read,
buf: mem::take(buf),
}));
}

let expected = u32::from_be_bytes(filled[..4].try_into()?) as usize;
debug!("{}: Expected response length: {}", self.registry, expected);
buf.resize(expected, 0);
Ok(Transition::Next(RequestState::Reading {
read,
read: new_read,
buf: mem::take(buf),
expected,
}))
}
RequestState::Reading {
mut read,
read,
buf,
expected,
} => {
let mut read_buf = ReadBuf::new(&mut buf[read..]);
let mut read_buf = ReadBuf::new(&mut buf[*read..]);
match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) {
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(err)) => return Err(err.into()),
Expand All @@ -197,20 +203,24 @@ impl<C: Connector> EppConnection<C> {
.into());
}

read += filled.len();
let new_read = *read + filled.len();
debug!(
"{}: Read {} bytes, {} out of {} done",
self.registry,
filled.len(),
read,
new_read,
expected
);

//

Ok(if read < *expected {
Ok(if new_read < *expected {
// If we haven't received the entire response yet, stick to the `Reading` state.
Transition::Next(state)
Transition::Next(RequestState::Reading {
read: new_read,
buf: mem::take(buf),
expected: *expected,
})
} else if let Some(next) = self.next.take() {
// Otherwise, if we were just pushing through this request because it was already
// in flight when we started a new one, ignore this response and move to the
Expand Down
74 changes: 74 additions & 0 deletions tests/chunked_read.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
use std::time::Duration;

use async_trait::async_trait;
use tokio_test::io::Builder;

use instant_epp::client::{Connector, EppClient};
use instant_epp::Error;

fn len_bytes(bytes: &[u8]) -> [u8; 4] {
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, some minor remaining nits: please order these in top-down fashion; first the #[test] functions, then connect_with_chunks(). Suggest inlining len_bytes() and changing greeting to be a const (this should be at the bottom).

((bytes.len() as u32) + 4).to_be_bytes()
}

fn greeting() -> &'static [u8] {
br#"<?xml version="1.0" encoding="UTF-8"?>
<epp xmlns="urn:ietf:params:xml:ns:epp-1.0">
<greeting>
<svID>Test EPP Server</svID>
<svDate>2024-01-01T00:00:00Z</svDate>
<svcMenu>
<version>1.0</version>
<lang>en</lang>
<objURI>urn:ietf:params:xml:ns:domain-1.0</objURI>
</svcMenu>
</greeting>
</epp>"#
}

async fn connect_with_chunks(num_chunks: usize) -> Result<EppClient<impl Connector>, Error> {
struct FakeConnector {
num_chunks: usize,
}

#[async_trait]
impl Connector for FakeConnector {
type Connection = tokio_test::io::Mock;

async fn connect(&self, _: Duration) -> Result<Self::Connection, Error> {
let mut builder = Builder::new();
let buf = greeting();

builder.read(&len_bytes(buf));

let chunk_size = buf.len() / self.num_chunks;
for i in 0..self.num_chunks {
let start = i * chunk_size;
let end = if i == self.num_chunks - 1 {
buf.len()
} else {
start + chunk_size
};
builder.read(&buf[start..end]);
}

Ok(builder.build())
}
}

EppClient::new(
FakeConnector { num_chunks },
"test".into(),
Duration::from_secs(5),
)
.await
}

#[tokio::test]
async fn greeting_single_chunk() {
assert!(connect_with_chunks(1).await.is_ok());
}

#[tokio::test]
async fn greeting_two_chunks() {
assert!(connect_with_chunks(2).await.is_ok());
}