From cfe58872f0b193c7576c3571ca99a8b47d761030 Mon Sep 17 00:00:00 2001 From: Jack Grigg Date: Wed, 4 Nov 2020 14:29:02 +0000 Subject: [PATCH 1/4] age: Rename *chunk -> *chunks in primitives::stream --- age/src/primitives/stream.rs | 132 +++++++++++++++++++---------------- 1 file changed, 70 insertions(+), 62 deletions(-) diff --git a/age/src/primitives/stream.rs b/age/src/primitives/stream.rs index 54cdf98d..9cb69795 100644 --- a/age/src/primitives/stream.rs +++ b/age/src/primitives/stream.rs @@ -75,7 +75,7 @@ impl Nonce { #[cfg(feature = "async")] #[cfg_attr(docsrs, doc(cfg(feature = "async")))] -struct EncryptedChunk { +struct EncryptedChunks { bytes: Vec, offset: usize, } @@ -111,9 +111,9 @@ impl Stream { StreamWriter { stream: Self::new(key), inner, - chunk: Vec::with_capacity(CHUNK_SIZE), + chunks: Vec::with_capacity(CHUNK_SIZE), #[cfg(feature = "async")] - encrypted_chunk: None, + encrypted_chunks: None, } } @@ -130,8 +130,8 @@ impl Stream { StreamWriter { stream: Self::new(key), inner, - chunk: Vec::with_capacity(CHUNK_SIZE), - encrypted_chunk: None, + chunks: Vec::with_capacity(CHUNK_SIZE), + encrypted_chunks: None, } } @@ -146,12 +146,12 @@ impl Stream { StreamReader { stream: Self::new(key), inner, - encrypted_chunk: vec![0; ENCRYPTED_CHUNK_SIZE], + encrypted_chunks: vec![0; ENCRYPTED_CHUNK_SIZE], encrypted_pos: 0, start: StartPos::Implicit(0), plaintext_len: None, cur_plaintext_pos: 0, - chunk: None, + chunks: None, } } @@ -168,16 +168,18 @@ impl Stream { StreamReader { stream: Self::new(key), inner, - encrypted_chunk: vec![0; ENCRYPTED_CHUNK_SIZE], + encrypted_chunks: vec![0; ENCRYPTED_CHUNK_SIZE], encrypted_pos: 0, start: StartPos::Implicit(0), plaintext_len: None, cur_plaintext_pos: 0, - chunk: None, + chunks: None, } } - fn encrypt_chunk(&mut self, chunk: &[u8], last: bool) -> io::Result> { + fn encrypt_chunks(&mut self, chunks: &[u8], last: bool) -> io::Result> { + // TODO: Generalise to multiple chunks. + let chunk = chunks; assert!(chunk.len() <= CHUNK_SIZE); self.nonce.set_last(last).map_err(|_| { @@ -193,7 +195,9 @@ impl Stream { Ok(encrypted) } - fn decrypt_chunk(&mut self, chunk: &[u8], last: bool) -> io::Result> { + fn decrypt_chunks(&mut self, chunks: &[u8], last: bool) -> io::Result> { + // TODO: Generalise to multiple chunks. + let chunk = chunks; assert!(chunk.len() <= ENCRYPTED_CHUNK_SIZE); self.nonce.set_last(last).map_err(|_| { @@ -224,10 +228,10 @@ pub struct StreamWriter { stream: Stream, #[pin] inner: W, - chunk: Vec, + chunks: Vec, #[cfg(feature = "async")] #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - encrypted_chunk: Option, + encrypted_chunks: Option, } impl StreamWriter { @@ -237,7 +241,7 @@ impl StreamWriter { /// encryption process. Failing to call `finish` will result in a truncated file that /// that will fail to decrypt. pub fn finish(mut self) -> io::Result { - let encrypted = self.stream.encrypt_chunk(&self.chunk, true)?; + let encrypted = self.stream.encrypt_chunks(&self.chunks, true)?; self.inner.write_all(&encrypted)?; Ok(self.inner) } @@ -248,20 +252,20 @@ impl Write for StreamWriter { let mut bytes_written = 0; while !buf.is_empty() { - let to_write = cmp::min(CHUNK_SIZE - self.chunk.len(), buf.len()); - self.chunk.extend_from_slice(&buf[..to_write]); + let to_write = cmp::min(CHUNK_SIZE - self.chunks.len(), buf.len()); + self.chunks.extend_from_slice(&buf[..to_write]); bytes_written += to_write; buf = &buf[to_write..]; // At this point, either buf is empty, or we have a full chunk. - assert!(buf.is_empty() || self.chunk.len() == CHUNK_SIZE); + assert!(buf.is_empty() || self.chunks.len() == CHUNK_SIZE); // Only encrypt the chunk if we have more data to write, as the last // chunk must be written in finish(). if !buf.is_empty() { - let encrypted = self.stream.encrypt_chunk(&self.chunk, false)?; + let encrypted = self.stream.encrypt_chunks(&self.chunks, false)?; self.inner.write_all(&encrypted)?; - self.chunk.clear(); + self.chunks.clear(); } } @@ -279,11 +283,11 @@ impl StreamWriter { fn poll_flush_chunk(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let StreamWriterProj { mut inner, - encrypted_chunk, + encrypted_chunks, .. } = self.project(); - if let Some(chunk) = encrypted_chunk { + if let Some(chunk) = encrypted_chunks { loop { chunk.offset += ready!(inner.as_mut().poll_write(cx, &chunk.bytes[chunk.offset..]))?; @@ -292,7 +296,7 @@ impl StreamWriter { } } } - *encrypted_chunk = None; + *encrypted_chunks = None; Poll::Ready(Ok(())) } @@ -308,26 +312,26 @@ impl AsyncWrite for StreamWriter { ) -> Poll> { ready!(self.as_mut().poll_flush_chunk(cx))?; - let to_write = cmp::min(CHUNK_SIZE - self.chunk.len(), buf.len()); + let to_write = cmp::min(CHUNK_SIZE - self.chunks.len(), buf.len()); self.as_mut() .project() - .chunk + .chunks .extend_from_slice(&buf[..to_write]); buf = &buf[to_write..]; // At this point, either buf is empty, or we have a full chunk. - assert!(buf.is_empty() || self.chunk.len() == CHUNK_SIZE); + assert!(buf.is_empty() || self.chunks.len() == CHUNK_SIZE); // Only encrypt the chunk if we have more data to write, as the last // chunk must be written in poll_close(). if !buf.is_empty() { let this = self.as_mut().project(); - *this.encrypted_chunk = Some(EncryptedChunk { - bytes: this.stream.encrypt_chunk(&this.chunk, false)?, + *this.encrypted_chunks = Some(EncryptedChunks { + bytes: this.stream.encrypt_chunks(&this.chunks, false)?, offset: 0, }); - this.chunk.clear(); + this.chunks.clear(); } Poll::Ready(Ok(to_write)) @@ -345,8 +349,8 @@ impl AsyncWrite for StreamWriter { if !self.stream.is_complete() { // Finish the stream. let this = self.as_mut().project(); - *this.encrypted_chunk = Some(EncryptedChunk { - bytes: this.stream.encrypt_chunk(&this.chunk, true)?, + *this.encrypted_chunks = Some(EncryptedChunks { + bytes: this.stream.encrypt_chunks(&this.chunks, true)?, offset: 0, }); } @@ -378,12 +382,12 @@ pub struct StreamReader { stream: Stream, #[pin] inner: R, - encrypted_chunk: Vec, + encrypted_chunks: Vec, encrypted_pos: usize, start: StartPos, plaintext_len: Option, cur_plaintext_pos: u64, - chunk: Option>, + chunks: Option>, } impl StreamReader { @@ -394,11 +398,11 @@ impl StreamReader { } } - fn decrypt_chunk(&mut self) -> io::Result<()> { + fn decrypt_chunks(&mut self) -> io::Result<()> { self.count_bytes(self.encrypted_pos); - let chunk = &self.encrypted_chunk[..self.encrypted_pos]; + let chunks = &self.encrypted_chunks[..self.encrypted_pos]; - if chunk.is_empty() { + if chunks.is_empty() { if !self.stream.is_complete() { // Stream has ended before seeing the last chunk. return Err(io::Error::new( @@ -410,27 +414,29 @@ impl StreamReader { // This check works for all cases except when the age file is an integer // multiple of the chunk size. In that case, we try decrypting twice on a // decryption failure. - let last = chunk.len() < ENCRYPTED_CHUNK_SIZE; + // TODO: Generalise to multiple chunks. + let last = chunks.len() < ENCRYPTED_CHUNK_SIZE; - self.chunk = match (self.stream.decrypt_chunk(chunk, last), last) { + self.chunks = match (self.stream.decrypt_chunks(chunks, last), last) { (Ok(chunk), _) => Some(chunk), - (Err(_), false) => Some(self.stream.decrypt_chunk(chunk, true)?), + (Err(_), false) => Some(self.stream.decrypt_chunks(chunks, true)?), (Err(e), true) => return Err(e), }; } - // We've finished with this encrypted chunk. + // We've finished with these encrypted chunks. self.encrypted_pos = 0; Ok(()) } - fn read_from_chunk(&mut self, buf: &mut [u8]) -> usize { - if self.chunk.is_none() { + fn read_from_chunks(&mut self, buf: &mut [u8]) -> usize { + if self.chunks.is_none() { return 0; } - let chunk = self.chunk.as_ref().unwrap(); + // TODO: Generalise to multiple chunks. + let chunk = self.chunks.as_ref().unwrap(); let cur_chunk_offset = self.cur_plaintext_pos as usize % CHUNK_SIZE; let to_read = cmp::min(chunk.expose_secret().len() - cur_chunk_offset, buf.len()); @@ -439,8 +445,8 @@ impl StreamReader { .copy_from_slice(&chunk.expose_secret()[cur_chunk_offset..cur_chunk_offset + to_read]); self.cur_plaintext_pos += to_read as u64; if self.cur_plaintext_pos % CHUNK_SIZE as u64 == 0 { - // We've finished with the current chunk. - self.chunk = None; + // We've finished with the current chunks. + self.chunks = None; } to_read @@ -449,11 +455,11 @@ impl StreamReader { impl Read for StreamReader { fn read(&mut self, buf: &mut [u8]) -> io::Result { - if self.chunk.is_none() { + if self.chunks.is_none() { while self.encrypted_pos < ENCRYPTED_CHUNK_SIZE { match self .inner - .read(&mut self.encrypted_chunk[self.encrypted_pos..]) + .read(&mut self.encrypted_chunks[self.encrypted_pos..]) { Ok(0) => break, Ok(n) => self.encrypted_pos += n, @@ -463,10 +469,10 @@ impl Read for StreamReader { }, } } - self.decrypt_chunk()?; + self.decrypt_chunks()?; } - Ok(self.read_from_chunk(buf)) + Ok(self.read_from_chunks(buf)) } } @@ -478,12 +484,12 @@ impl AsyncRead for StreamReader { cx: &mut Context, buf: &mut [u8], ) -> Poll> { - if self.chunk.is_none() { + if self.chunks.is_none() { while self.encrypted_pos < ENCRYPTED_CHUNK_SIZE { let this = self.as_mut().project(); match ready!(this .inner - .poll_read(cx, &mut this.encrypted_chunk[*this.encrypted_pos..])) + .poll_read(cx, &mut this.encrypted_chunks[*this.encrypted_pos..])) { Ok(0) => break, Ok(n) => self.encrypted_pos += n, @@ -493,10 +499,10 @@ impl AsyncRead for StreamReader { }, } } - self.decrypt_chunk()?; + self.decrypt_chunks()?; } - Poll::Ready(Ok(self.read_from_chunk(buf))) + Poll::Ready(Ok(self.read_from_chunks(buf))) } } @@ -539,7 +545,7 @@ impl StreamReader { self.inner.seek(SeekFrom::Start(last_chunk_start))?; self.inner.read_to_end(&mut last_chunk)?; self.stream.nonce.set_counter(num_chunks - 1); - self.stream.decrypt_chunk(&last_chunk, true).map_err(|_| { + self.stream.decrypt_chunks(&last_chunk, true).map_err(|_| { io::Error::new( io::ErrorKind::InvalidData, "Last chunk is invalid, stream might be truncated", @@ -595,6 +601,8 @@ impl Seek for StreamReader { } }; + // TODO: Generalise to multiple chunks. + let cur_chunk_index = self.cur_plaintext_pos / CHUNK_SIZE as u64; let target_chunk_index = target_pos / CHUNK_SIZE as u64; @@ -605,7 +613,7 @@ impl Seek for StreamReader { self.cur_plaintext_pos = target_pos; } else { // Clear the current chunk - self.chunk = None; + self.chunks = None; // Seek to the beginning of the target chunk self.inner.seek(SeekFrom::Start( @@ -663,12 +671,12 @@ mod tests { let encrypted = { let mut s = Stream::new(PayloadKey([7; 32].into())); - s.encrypt_chunk(&data, false).unwrap() + s.encrypt_chunks(&data, false).unwrap() }; let decrypted = { let mut s = Stream::new(PayloadKey([7; 32].into())); - s.decrypt_chunk(&encrypted, false).unwrap() + s.decrypt_chunks(&encrypted, false).unwrap() }; assert_eq!(decrypted.expose_secret(), &data); @@ -680,15 +688,15 @@ mod tests { let encrypted = { let mut s = Stream::new(PayloadKey([7; 32].into())); - let res = s.encrypt_chunk(&data, true).unwrap(); + let res = s.encrypt_chunks(&data, true).unwrap(); // Further calls return an error assert_eq!( - s.encrypt_chunk(&data, false).unwrap_err().kind(), + s.encrypt_chunks(&data, false).unwrap_err().kind(), io::ErrorKind::WriteZero ); assert_eq!( - s.encrypt_chunk(&data, true).unwrap_err().kind(), + s.encrypt_chunks(&data, true).unwrap_err().kind(), io::ErrorKind::WriteZero ); @@ -697,14 +705,14 @@ mod tests { let decrypted = { let mut s = Stream::new(PayloadKey([7; 32].into())); - let res = s.decrypt_chunk(&encrypted, true).unwrap(); + let res = s.decrypt_chunks(&encrypted, true).unwrap(); // Further calls return an error - match s.decrypt_chunk(&encrypted, false) { + match s.decrypt_chunks(&encrypted, false) { Err(e) => assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof), _ => panic!("Expected error"), } - match s.decrypt_chunk(&encrypted, true) { + match s.decrypt_chunks(&encrypted, true) { Err(e) => assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof), _ => panic!("Expected error"), } From 2175081dee5b5f6bd20fb46eec9aa5fff553b116 Mon Sep 17 00:00:00 2001 From: Jack Grigg Date: Wed, 4 Nov 2020 18:25:38 +0000 Subject: [PATCH 2/4] age: Modify Stream::{decrypt, encrypt}_chunks to support multiple chunks --- age/src/primitives/stream.rs | 88 +++++++++++++++++++++++++----------- 1 file changed, 61 insertions(+), 27 deletions(-) diff --git a/age/src/primitives/stream.rs b/age/src/primitives/stream.rs index 9cb69795..a2bebbff 100644 --- a/age/src/primitives/stream.rs +++ b/age/src/primitives/stream.rs @@ -2,7 +2,7 @@ use age_core::secrecy::{ExposeSecret, SecretVec}; use chacha20poly1305::{ - aead::{generic_array::GenericArray, Aead, NewAead}, + aead::{generic_array::GenericArray, AeadInPlace, NewAead}, ChaCha20Poly1305, }; use pin_project::pin_project; @@ -178,43 +178,77 @@ impl Stream { } fn encrypt_chunks(&mut self, chunks: &[u8], last: bool) -> io::Result> { - // TODO: Generalise to multiple chunks. - let chunk = chunks; - assert!(chunk.len() <= CHUNK_SIZE); + if self.nonce.is_last() { + Err(io::Error::new( + io::ErrorKind::WriteZero, + "last chunk has been processed", + ))?; + }; + + // Allocate an output buffer of the correct length. + let chunks_len = chunks.len(); + let chunks = chunks.chunks(CHUNK_SIZE); + let num_chunks = chunks.len(); + let mut encrypted = vec![0; chunks_len + TAG_SIZE * num_chunks]; - self.nonce.set_last(last).map_err(|_| { - io::Error::new(io::ErrorKind::WriteZero, "last chunk has been processed") - })?; + for (i, (encrypted, chunk)) in encrypted + .chunks_mut(ENCRYPTED_CHUNK_SIZE) + .zip(chunks) + .enumerate() + { + if i + 1 == num_chunks { + self.nonce.set_last(last).unwrap(); + } + + let (buffer, tag) = encrypted.split_at_mut(chunk.len()); + buffer.copy_from_slice(chunk); + tag.copy_from_slice( + self.aead + .encrypt_in_place_detached(&self.nonce.to_bytes().into(), &[], buffer) + .expect("we will never hit chacha20::MAX_BLOCKS because of the chunk size") + .as_slice(), + ); - let encrypted = self - .aead - .encrypt(&self.nonce.to_bytes().into(), chunk) - .expect("we will never hit chacha20::MAX_BLOCKS because of the chunk size"); - self.nonce.increment_counter(); + self.nonce.increment_counter(); + } Ok(encrypted) } fn decrypt_chunks(&mut self, chunks: &[u8], last: bool) -> io::Result> { - // TODO: Generalise to multiple chunks. - let chunk = chunks; - assert!(chunk.len() <= ENCRYPTED_CHUNK_SIZE); - - self.nonce.set_last(last).map_err(|_| { - io::Error::new( + if self.nonce.is_last() { + Err(io::Error::new( io::ErrorKind::UnexpectedEof, "last chunk has been processed", - ) - })?; + ))?; + }; - let decrypted = self - .aead - .decrypt(&self.nonce.to_bytes().into(), chunk) - .map(SecretVec::new) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "decryption error"))?; - self.nonce.increment_counter(); + // Allocate an output buffer of the correct length. + let chunks_len = chunks.len(); + let chunks = chunks.chunks(ENCRYPTED_CHUNK_SIZE); + let num_chunks = chunks.len(); + let mut decrypted = vec![0; chunks_len - TAG_SIZE * num_chunks]; + + for (i, (decrypted, chunk)) in decrypted.chunks_mut(CHUNK_SIZE).zip(chunks).enumerate() { + if i + 1 == num_chunks { + self.nonce.set_last(last).unwrap(); + } + + let (chunk, tag) = chunk.split_at(decrypted.len()); + decrypted.copy_from_slice(chunk); + self.aead + .decrypt_in_place_detached( + &self.nonce.to_bytes().into(), + &[], + decrypted, + tag.into(), + ) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "decryption error"))?; + + self.nonce.increment_counter(); + } - Ok(decrypted) + Ok(SecretVec::new(decrypted)) } fn is_complete(&self) -> bool { From eea64b8cd2010e3240f742352c512c87913c8e73 Mon Sep 17 00:00:00 2001 From: Jack Grigg Date: Wed, 4 Nov 2020 20:10:48 +0000 Subject: [PATCH 3/4] age: Buffer as many STREAM chunks as we have logical CPUs --- Cargo.lock | 1 + age/Cargo.toml | 3 ++ age/src/primitives/stream.rs | 58 +++++++++++++++++++----------------- 3 files changed, 35 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1dd77239..ae664853 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -68,6 +68,7 @@ dependencies = [ "lazy_static", "nom", "num-traits", + "num_cpus", "pin-project", "pinentry", "pprof", diff --git a/age/Cargo.toml b/age/Cargo.toml index 26371234..a1eee2ba 100644 --- a/age/Cargo.toml +++ b/age/Cargo.toml @@ -73,6 +73,9 @@ i18n-embed-fl = "0.6" lazy_static = "1" rust-embed = "6" +# Performance +num_cpus = "1.0" + # Common CLI dependencies console = { version = "0.15", optional = true, default-features = false } pinentry = { version = "0.5", optional = true } diff --git a/age/src/primitives/stream.rs b/age/src/primitives/stream.rs index a2bebbff..19f4ec48 100644 --- a/age/src/primitives/stream.rs +++ b/age/src/primitives/stream.rs @@ -5,6 +5,7 @@ use chacha20poly1305::{ aead::{generic_array::GenericArray, AeadInPlace, NewAead}, ChaCha20Poly1305, }; +use lazy_static::lazy_static; use pin_project::pin_project; use std::cmp; use std::convert::TryInto; @@ -24,6 +25,11 @@ const CHUNK_SIZE: usize = 64 * 1024; const TAG_SIZE: usize = 16; const ENCRYPTED_CHUNK_SIZE: usize = CHUNK_SIZE + TAG_SIZE; +lazy_static! { + static ref CHUNKS_SIZE: usize = num_cpus::get() * CHUNK_SIZE; + static ref ENCRYPTED_CHUNKS_SIZE: usize = num_cpus::get() * ENCRYPTED_CHUNK_SIZE; +} + pub(crate) struct PayloadKey(pub(crate) GenericArray::KeySize>); impl Drop for PayloadKey { @@ -111,7 +117,7 @@ impl Stream { StreamWriter { stream: Self::new(key), inner, - chunks: Vec::with_capacity(CHUNK_SIZE), + chunks: Vec::with_capacity(*CHUNKS_SIZE), #[cfg(feature = "async")] encrypted_chunks: None, } @@ -130,7 +136,7 @@ impl Stream { StreamWriter { stream: Self::new(key), inner, - chunks: Vec::with_capacity(CHUNK_SIZE), + chunks: Vec::with_capacity(*CHUNKS_SIZE), encrypted_chunks: None, } } @@ -146,7 +152,7 @@ impl Stream { StreamReader { stream: Self::new(key), inner, - encrypted_chunks: vec![0; ENCRYPTED_CHUNK_SIZE], + encrypted_chunks: vec![0; *ENCRYPTED_CHUNKS_SIZE], encrypted_pos: 0, start: StartPos::Implicit(0), plaintext_len: None, @@ -168,7 +174,7 @@ impl Stream { StreamReader { stream: Self::new(key), inner, - encrypted_chunks: vec![0; ENCRYPTED_CHUNK_SIZE], + encrypted_chunks: vec![0; *ENCRYPTED_CHUNKS_SIZE], encrypted_pos: 0, start: StartPos::Implicit(0), plaintext_len: None, @@ -286,13 +292,13 @@ impl Write for StreamWriter { let mut bytes_written = 0; while !buf.is_empty() { - let to_write = cmp::min(CHUNK_SIZE - self.chunks.len(), buf.len()); + let to_write = cmp::min(*CHUNKS_SIZE - self.chunks.len(), buf.len()); self.chunks.extend_from_slice(&buf[..to_write]); bytes_written += to_write; buf = &buf[to_write..]; - // At this point, either buf is empty, or we have a full chunk. - assert!(buf.is_empty() || self.chunks.len() == CHUNK_SIZE); + // At this point, either buf is empty, or we have a full set of chunks. + assert!(buf.is_empty() || self.chunks.len() == *CHUNKS_SIZE); // Only encrypt the chunk if we have more data to write, as the last // chunk must be written in finish(). @@ -346,7 +352,7 @@ impl AsyncWrite for StreamWriter { ) -> Poll> { ready!(self.as_mut().poll_flush_chunk(cx))?; - let to_write = cmp::min(CHUNK_SIZE - self.chunks.len(), buf.len()); + let to_write = cmp::min(*CHUNKS_SIZE - self.chunks.len(), buf.len()); self.as_mut() .project() @@ -354,8 +360,8 @@ impl AsyncWrite for StreamWriter { .extend_from_slice(&buf[..to_write]); buf = &buf[to_write..]; - // At this point, either buf is empty, or we have a full chunk. - assert!(buf.is_empty() || self.chunks.len() == CHUNK_SIZE); + // At this point, either buf is empty, or we have a full set of chunks. + assert!(buf.is_empty() || self.chunks.len() == *CHUNKS_SIZE); // Only encrypt the chunk if we have more data to write, as the last // chunk must be written in poll_close(). @@ -449,7 +455,7 @@ impl StreamReader { // multiple of the chunk size. In that case, we try decrypting twice on a // decryption failure. // TODO: Generalise to multiple chunks. - let last = chunks.len() < ENCRYPTED_CHUNK_SIZE; + let last = chunks.len() < *ENCRYPTED_CHUNKS_SIZE; self.chunks = match (self.stream.decrypt_chunks(chunks, last), last) { (Ok(chunk), _) => Some(chunk), @@ -469,16 +475,16 @@ impl StreamReader { return 0; } - // TODO: Generalise to multiple chunks. - let chunk = self.chunks.as_ref().unwrap(); - let cur_chunk_offset = self.cur_plaintext_pos as usize % CHUNK_SIZE; + let chunks = self.chunks.as_ref().unwrap(); + let cur_chunks_offset = self.cur_plaintext_pos as usize % *CHUNKS_SIZE; - let to_read = cmp::min(chunk.expose_secret().len() - cur_chunk_offset, buf.len()); + let to_read = cmp::min(chunks.expose_secret().len() - cur_chunks_offset, buf.len()); - buf[..to_read] - .copy_from_slice(&chunk.expose_secret()[cur_chunk_offset..cur_chunk_offset + to_read]); + buf[..to_read].copy_from_slice( + &chunks.expose_secret()[cur_chunks_offset..cur_chunks_offset + to_read], + ); self.cur_plaintext_pos += to_read as u64; - if self.cur_plaintext_pos % CHUNK_SIZE as u64 == 0 { + if self.cur_plaintext_pos % *CHUNKS_SIZE as u64 == 0 { // We've finished with the current chunks. self.chunks = None; } @@ -490,7 +496,7 @@ impl StreamReader { impl Read for StreamReader { fn read(&mut self, buf: &mut [u8]) -> io::Result { if self.chunks.is_none() { - while self.encrypted_pos < ENCRYPTED_CHUNK_SIZE { + while self.encrypted_pos < *ENCRYPTED_CHUNKS_SIZE { match self .inner .read(&mut self.encrypted_chunks[self.encrypted_pos..]) @@ -519,7 +525,7 @@ impl AsyncRead for StreamReader { buf: &mut [u8], ) -> Poll> { if self.chunks.is_none() { - while self.encrypted_pos < ENCRYPTED_CHUNK_SIZE { + while self.encrypted_pos < *ENCRYPTED_CHUNKS_SIZE { let this = self.as_mut().project(); match ready!(this .inner @@ -635,12 +641,10 @@ impl Seek for StreamReader { } }; - // TODO: Generalise to multiple chunks. - - let cur_chunk_index = self.cur_plaintext_pos / CHUNK_SIZE as u64; + let cur_chunk_index = self.cur_plaintext_pos / *CHUNKS_SIZE as u64; - let target_chunk_index = target_pos / CHUNK_SIZE as u64; - let target_chunk_offset = target_pos % CHUNK_SIZE as u64; + let target_chunk_index = target_pos / *CHUNKS_SIZE as u64; + let target_chunk_offset = target_pos % *CHUNKS_SIZE as u64; if target_chunk_index == cur_chunk_index { // We just need to reposition ourselves within the current chunk. @@ -651,10 +655,10 @@ impl Seek for StreamReader { // Seek to the beginning of the target chunk self.inner.seek(SeekFrom::Start( - start + (target_chunk_index * ENCRYPTED_CHUNK_SIZE as u64), + start + (target_chunk_index * *ENCRYPTED_CHUNKS_SIZE as u64), ))?; self.stream.nonce.set_counter(target_chunk_index); - self.cur_plaintext_pos = target_chunk_index * CHUNK_SIZE as u64; + self.cur_plaintext_pos = target_chunk_index * *CHUNKS_SIZE as u64; // Read and drop bytes from the chunk to reach the target position. if target_chunk_offset > 0 { From 4f72ab0eca9617f0df64190535e515bd275e7ffa Mon Sep 17 00:00:00 2001 From: Jack Grigg Date: Wed, 4 Nov 2020 21:08:45 +0000 Subject: [PATCH 4/4] age: Use rayon for processing STREAM chunks in parallel --- Cargo.lock | 1 + age/Cargo.toml | 1 + age/src/primitives/stream.rs | 40 ++++++++++++++++++++---------------- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ae664853..fde403f8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -76,6 +76,7 @@ dependencies = [ "quickcheck_macros", "rand 0.7.3", "rand 0.8.4", + "rayon", "rpassword", "rsa", "rust-embed", diff --git a/age/Cargo.toml b/age/Cargo.toml index a1eee2ba..b7dcf889 100644 --- a/age/Cargo.toml +++ b/age/Cargo.toml @@ -75,6 +75,7 @@ rust-embed = "6" # Performance num_cpus = "1.0" +rayon = "1.5" # Common CLI dependencies console = { version = "0.15", optional = true, default-features = false } diff --git a/age/src/primitives/stream.rs b/age/src/primitives/stream.rs index 19f4ec48..d70ae714 100644 --- a/age/src/primitives/stream.rs +++ b/age/src/primitives/stream.rs @@ -7,6 +7,7 @@ use chacha20poly1305::{ }; use lazy_static::lazy_static; use pin_project::pin_project; +use rayon::prelude::*; use std::cmp; use std::convert::TryInto; use std::io::{self, Read, Seek, SeekFrom, Write}; @@ -51,9 +52,9 @@ impl Nonce { self.0 = u128::from(val) << 8; } - fn increment_counter(&mut self) { + fn increment_counter(&mut self, by: usize) { // Increment the 11-byte counter - self.0 += 1 << 8; + self.0 += (by as u128) << 8; if self.0 >> (8 * 12) != 0 { panic!("We overflowed the nonce!"); } @@ -197,26 +198,29 @@ impl Stream { let num_chunks = chunks.len(); let mut encrypted = vec![0; chunks_len + TAG_SIZE * num_chunks]; - for (i, (encrypted, chunk)) in encrypted + encrypted .chunks_mut(ENCRYPTED_CHUNK_SIZE) .zip(chunks) .enumerate() - { - if i + 1 == num_chunks { - self.nonce.set_last(last).unwrap(); - } + .par_bridge() + .for_each_with(self.nonce, |nonce, (i, (encrypted, chunk))| { + nonce.increment_counter(i); + if i + 1 == num_chunks { + nonce.set_last(last).unwrap(); + } - let (buffer, tag) = encrypted.split_at_mut(chunk.len()); - buffer.copy_from_slice(chunk); - tag.copy_from_slice( - self.aead - .encrypt_in_place_detached(&self.nonce.to_bytes().into(), &[], buffer) - .expect("we will never hit chacha20::MAX_BLOCKS because of the chunk size") - .as_slice(), - ); + let (buffer, tag) = encrypted.split_at_mut(chunk.len()); + buffer.copy_from_slice(chunk); + tag.copy_from_slice( + self.aead + .encrypt_in_place_detached(&nonce.to_bytes().into(), &[], buffer) + .expect("we will never hit chacha20::MAX_BLOCKS because of the chunk size") + .as_slice(), + ); + }); - self.nonce.increment_counter(); - } + self.nonce.increment_counter(num_chunks); + self.nonce.set_last(last).unwrap(); Ok(encrypted) } @@ -251,7 +255,7 @@ impl Stream { ) .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "decryption error"))?; - self.nonce.increment_counter(); + self.nonce.increment_counter(1); } Ok(SecretVec::new(decrypted))