diff --git a/nativelink-store/src/fast_slow_store.rs b/nativelink-store/src/fast_slow_store.rs index 459710683..ca8af2c00 100644 --- a/nativelink-store/src/fast_slow_store.rs +++ b/nativelink-store/src/fast_slow_store.rs @@ -17,6 +17,7 @@ use core::cmp::{max, min}; use core::ops::Range; use core::pin::Pin; use core::sync::atomic::{AtomicU64, Ordering}; +use core::time::Duration; use std::collections::HashMap; use std::ffi::OsString; use std::sync::{Arc, Weak}; @@ -63,6 +64,9 @@ pub struct FastSlowStore { // actually it's faster because we're not downloading the file multiple // times are doing loads of duplicate IO. populating_digests: Mutex, Loader>>, + // The amount of time to allow stores to start before determining that they + // have deadlocked and retrying. + deadlock_timeout: Duration, } // This guard ensures that the populating_digests is cleared even if the future @@ -116,6 +120,15 @@ impl Drop for LoaderGuard<'_> { impl FastSlowStore { pub fn new(spec: &FastSlowSpec, fast_store: Store, slow_store: Store) -> Arc { + Self::new_with_deadlock_timeout(spec, fast_store, slow_store, Duration::from_secs(5)) + } + + pub fn new_with_deadlock_timeout( + spec: &FastSlowSpec, + fast_store: Store, + slow_store: Store, + deadlock_timeout: Duration, + ) -> Arc { Arc::new_cyclic(|weak_self| Self { fast_store, fast_direction: spec.fast_direction, @@ -124,6 +137,7 @@ impl FastSlowStore { weak_self: weak_self.clone(), metrics: FastSlowStoreMetrics::default(), populating_digests: Mutex::new(HashMap::new()), + deadlock_timeout, }) } @@ -189,8 +203,62 @@ impl FastSlowStore { let send_range = offset..length.map_or(u64::MAX, |length| length + offset); let mut bytes_received: u64 = 0; - let (mut fast_tx, fast_rx) = make_buf_channel_pair(); - let (slow_tx, mut slow_rx) = make_buf_channel_pair(); + // There's a strong possibility of a deadlock here as we're working with multiple + // stores. We need to be careful that we don't hold a get semaphore if we can't + // open the update. This doesn't know anything about the downstream implementations, + // so simply makes use of a timeout to check that the reader and writers are set up. + let (stores_fut, mut slow_rx, mut fast_tx) = loop { + let (mut fast_tx, fast_rx) = make_buf_channel_pair(); + let (slow_tx, mut slow_rx) = make_buf_channel_pair(); + + let slow_store_fut = self.slow_store.get(key.borrow(), slow_tx); + let fast_store_fut = + self.fast_store + .update(key.borrow(), fast_rx, UploadSizeInfo::ExactSize(sz)); + let mut stores_fut = futures::future::join(slow_store_fut, fast_store_fut); + let has_semaphores_fut = tokio::time::timeout( + self.deadlock_timeout, + futures::future::join(slow_rx.peek(), fast_tx.is_waiting()), + ); + tokio::select! { + result = &mut stores_fut => { + match result { + (Ok(()), Ok(())) => { + // Both stores completed without the writers, probably zero byte. + return Ok(()); + } + (Ok(()), Err(err)) | (Err(err), Ok(())) => { + return Err(err); + } + (Err(err1), Err(err2)) => { + return Err(err1.merge(err2)); + } + } + } + result = has_semaphores_fut => { + match result { + Ok((Ok(_), Ok(()))) => { + // Both sides have started reading/writing, we assume they hold + // all the permits they require and it's safe to continue. + break (stores_fut, slow_rx, fast_tx); + } + Ok((Ok(_), Err(err)) | (Err(err), Ok(()))) => { + return Err(err); + } + Ok((Err(err1), Err(err2))) => { + return Err(err1.merge(err2)); + } + Err(_timeout) => { + // There was probably a deadlock... we need to drop and try again. + drop(stores_fut); + tracing::warn!("Possible deadlock in fast-slow, retrying."); + tokio::time::sleep(Duration::from_millis(10)).await; + } + } + } + }; + }; + let data_stream_fut = async move { let mut maybe_writer_pin = maybe_writer.map(Pin::new); loop { @@ -229,13 +297,7 @@ impl FastSlowStore { } }; - let slow_store_fut = self.slow_store.get(key.borrow(), slow_tx); - let fast_store_fut = - self.fast_store - .update(key.borrow(), fast_rx, UploadSizeInfo::ExactSize(sz)); - - let (data_stream_res, slow_res, fast_res) = - join!(data_stream_fut, slow_store_fut, fast_store_fut); + let (data_stream_res, (slow_res, fast_res)) = join!(data_stream_fut, stores_fut); match data_stream_res { Ok((fast_eof_res, maybe_writer_pin)) => // Sending the EOF will drop us almost immediately in bytestream_server diff --git a/nativelink-store/src/gcs_client/client.rs b/nativelink-store/src/gcs_client/client.rs index 664ec2114..d6cc4247d 100644 --- a/nativelink-store/src/gcs_client/client.rs +++ b/nativelink-store/src/gcs_client/client.rs @@ -41,6 +41,12 @@ use crate::gcs_client::types::{ SIMPLE_UPLOAD_THRESHOLD, Timestamp, }; +#[derive(Debug)] +pub struct UploadRef { + pub upload_ref: String, + pub(crate) _permit: OwnedSemaphorePermit, +} + /// A trait that defines the required GCS operations. /// This abstraction allows for easier testing by mocking GCS responses. pub trait GcsOperations: Send + Sync + Debug { @@ -71,12 +77,12 @@ pub trait GcsOperations: Send + Sync + Debug { fn start_resumable_write( &self, object_path: &ObjectPath, - ) -> impl Future> + Send; + ) -> impl Future> + Send; /// Upload a chunk of data in a resumable upload session fn upload_chunk( &self, - upload_url: &str, + upload_url: &UploadRef, object_path: &ObjectPath, data: Bytes, offset: u64, @@ -336,12 +342,21 @@ impl GcsClient { } // Check if the object exists - match self.read_object_metadata(object_path).await? { - Some(_) => Ok(()), - None => Err(make_err!( - Code::Internal, - "Upload completed but object not found" - )), + let request = GetObjectRequest { + bucket: object_path.bucket.clone(), + object: object_path.path.clone(), + ..Default::default() + }; + + match self.client.get_object(&request).await { + Ok(_) => Ok(()), + Err(GcsError::Response(resp)) if resp.code == 404 => { + return Err(make_err!( + Code::Internal, + "Upload completed but object not found" + )); + } + Err(err) => Err(Self::handle_gcs_error(&err)), } }) .await @@ -470,55 +485,58 @@ impl GcsOperations for GcsClient { .await } - async fn start_resumable_write(&self, object_path: &ObjectPath) -> Result { - self.with_connection(|| async { - let request = UploadObjectRequest { - bucket: object_path.bucket.clone(), - ..Default::default() - }; + async fn start_resumable_write(&self, object_path: &ObjectPath) -> Result { + let permit = + self.semaphore.clone().acquire_owned().await.map_err(|e| { + make_err!(Code::Internal, "Failed to acquire connection permit: {}", e) + })?; + let request = UploadObjectRequest { + bucket: object_path.bucket.clone(), + ..Default::default() + }; - let upload_type = UploadType::Multipart(Box::new(Object { - name: object_path.path.clone(), - content_type: Some(DEFAULT_CONTENT_TYPE.to_string()), - ..Default::default() - })); + let upload_type = UploadType::Multipart(Box::new(Object { + name: object_path.path.clone(), + content_type: Some(DEFAULT_CONTENT_TYPE.to_string()), + ..Default::default() + })); - // Start resumable upload session - let uploader = self - .client - .prepare_resumable_upload(&request, &upload_type) - .await - .map_err(|e| Self::handle_gcs_error(&e))?; + // Start resumable upload session + let uploader = self + .client + .prepare_resumable_upload(&request, &upload_type) + .await + .map_err(|e| Self::handle_gcs_error(&e))?; - Ok(uploader.url().to_string()) + Ok(UploadRef { + upload_ref: uploader.url().to_string(), + _permit: permit, }) - .await } async fn upload_chunk( &self, - upload_url: &str, + upload_url: &UploadRef, _object_path: &ObjectPath, data: Bytes, offset: u64, end_offset: u64, total_size: Option, ) -> Result<(), Error> { - self.with_connection(|| async { - let uploader = self.client.get_resumable_upload(upload_url.to_string()); + let uploader = self + .client + .get_resumable_upload(upload_url.upload_ref.clone()); - let last_byte = if end_offset == 0 { 0 } else { end_offset - 1 }; - let chunk_def = ChunkSize::new(offset, last_byte, total_size); + let last_byte = if end_offset == 0 { 0 } else { end_offset - 1 }; + let chunk_def = ChunkSize::new(offset, last_byte, total_size); - // Upload chunk - uploader - .upload_multiple_chunk(data, &chunk_def) - .await - .map_err(|e| Self::handle_gcs_error(&e))?; + // Upload chunk + uploader + .upload_multiple_chunk(data, &chunk_def) + .await + .map_err(|e| Self::handle_gcs_error(&e))?; - Ok(()) - }) - .await + Ok(()) } async fn upload_from_reader( diff --git a/nativelink-store/src/gcs_client/mocks.rs b/nativelink-store/src/gcs_client/mocks.rs index d35dac758..f53837ff0 100644 --- a/nativelink-store/src/gcs_client/mocks.rs +++ b/nativelink-store/src/gcs_client/mocks.rs @@ -15,15 +15,16 @@ use core::fmt::Debug; use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::collections::HashMap; +use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; use bytes::Bytes; use futures::Stream; use nativelink_error::{Code, Error, make_err}; use nativelink_util::buf_channel::DropCloserReadHalf; -use tokio::sync::RwLock; +use tokio::sync::{RwLock, Semaphore}; -use crate::gcs_client::client::GcsOperations; +use crate::gcs_client::client::{GcsOperations, UploadRef}; use crate::gcs_client::types::{DEFAULT_CONTENT_TYPE, GcsObject, ObjectPath, Timestamp}; /// A mock implementation of `GcsOperations` for testing @@ -379,7 +380,7 @@ impl GcsOperations for MockGcsOperations { Ok(()) } - async fn start_resumable_write(&self, object_path: &ObjectPath) -> Result { + async fn start_resumable_write(&self, object_path: &ObjectPath) -> Result { self.call_counts .start_resumable_calls .fetch_add(1, Ordering::Relaxed); @@ -392,12 +393,15 @@ impl GcsOperations for MockGcsOperations { self.handle_failure().await?; let upload_id = format!("mock-upload-{}-{}", object_path.bucket, object_path.path); - Ok(upload_id) + Ok(UploadRef { + upload_ref: upload_id, + _permit: Arc::new(Semaphore::new(1)).acquire_owned().await.unwrap(), + }) } async fn upload_chunk( &self, - upload_url: &str, + upload_url: &UploadRef, object_path: &ObjectPath, data: Bytes, offset: u64, @@ -408,7 +412,7 @@ impl GcsOperations for MockGcsOperations { .upload_chunk_calls .fetch_add(1, Ordering::Relaxed); self.requests.write().await.push(MockRequest::UploadChunk { - upload_url: upload_url.to_string(), + upload_url: upload_url.upload_ref.clone(), object_path: object_path.clone(), data_len: data.len(), offset, diff --git a/nativelink-store/src/gcs_store.rs b/nativelink-store/src/gcs_store.rs index 898aa8b09..2a971e6df 100644 --- a/nativelink-store/src/gcs_store.rs +++ b/nativelink-store/src/gcs_store.rs @@ -270,8 +270,23 @@ where } else { None }; - let mut upload_id: Option = None; let client = &self.client; + let upload_id = self + .retrier + .retry(unfold((), |()| async { + match client.start_resumable_write(&object_path).await { + Ok(id) => Some((RetryResult::Ok(id), ())), + Err(e) => Some(( + RetryResult::Retry(make_err!( + Code::Aborted, + "Failed to start resumable upload: {:?}", + e + )), + (), + )), + } + })) + .await?; loop { let chunk = reader.consume(Some(self.max_chunk_size)).await?; @@ -283,35 +298,12 @@ where total_size = Some(offset + chunk.len() as u64); } - let upload_id_ref = if let Some(upload_id_ref) = &upload_id { - upload_id_ref - } else { - // Initiate the upload session on the first non-empty chunk. - upload_id = Some( - self.retrier - .retry(unfold((), |()| async { - match client.start_resumable_write(&object_path).await { - Ok(id) => Some((RetryResult::Ok(id), ())), - Err(e) => Some(( - RetryResult::Retry(make_err!( - Code::Aborted, - "Failed to start resumable upload: {:?}", - e - )), - (), - )), - } - })) - .await?, - ); - upload_id.as_deref().unwrap() - }; - let current_offset = offset; offset += chunk.len() as u64; // Uploading the chunk with a retry let object_path_ref = &object_path; + let upload_id_ref = &upload_id; self.retrier .retry(unfold(chunk, |chunk| async move { match client @@ -334,40 +326,30 @@ where // Handle the case that the stream was of unknown length and // happened to be an exact multiple of chunk size. - if let Some(upload_id_ref) = &upload_id { - if total_size.is_none() { - let object_path_ref = &object_path; - self.retrier - .retry(unfold((), |()| async move { - match client - .upload_chunk( - upload_id_ref, - object_path_ref, - Bytes::new(), - offset, - offset, - Some(offset), - ) - .await - { - Ok(()) => Some((RetryResult::Ok(()), ())), - Err(e) => Some((RetryResult::Retry(e), ())), - } - })) - .await?; - } - } else { - // Handle streamed empty file. - return self - .retrier - .retry(unfold((), |()| async { - match client.write_object(&object_path, Vec::new()).await { + if total_size.is_none() { + let object_path_ref = &object_path; + let upload_id_ref = &upload_id; + self.retrier + .retry(unfold((), |()| async move { + match client + .upload_chunk( + upload_id_ref, + object_path_ref, + Bytes::new(), + offset, + offset, + Some(offset), + ) + .await + { Ok(()) => Some((RetryResult::Ok(()), ())), Err(e) => Some((RetryResult::Retry(e), ())), } })) - .await; + .await?; } + // Ensure we drop the permit before verifying. + drop(upload_id); // Verifying if the upload was successful self.retrier diff --git a/nativelink-store/tests/fast_slow_store_test.rs b/nativelink-store/tests/fast_slow_store_test.rs index 73894cc59..652f74676 100644 --- a/nativelink-store/tests/fast_slow_store_test.rs +++ b/nativelink-store/tests/fast_slow_store_test.rs @@ -570,3 +570,188 @@ async fn fast_readonly_only_not_updated_on_get() -> Result<(), Error> { ); Ok(()) } + +#[derive(MetricsComponent)] +struct SemaphoreStore { + sem: Arc, + inner: Arc, +} + +impl SemaphoreStore { + fn new(sem: Arc) -> Arc { + Arc::new(Self { + sem, + inner: MemoryStore::new(&MemorySpec::default()), + }) + } + + async fn get_permit(&self) -> Result, Error> { + self.sem + .acquire() + .await + .map_err(|e| make_err!(Code::Internal, "Failed to acquire permit: {e:?}")) + } +} + +#[async_trait] +impl StoreDriver for SemaphoreStore { + async fn get_part( + self: Pin<&Self>, + key: StoreKey<'_>, + writer: &mut nativelink_util::buf_channel::DropCloserWriteHalf, + offset: u64, + length: Option, + ) -> Result<(), Error> { + let _guard = self.get_permit().await?; + // Ensure this isn't returned in two or fewer writes as that is the buffer size. + let (second_writer, mut second_reader) = make_buf_channel_pair(); + let write_fut = async move { + let data = second_reader.recv().await?; + if data.len() > 6 { + writer.send(data.slice(0..1)).await?; + writer.send(data.slice(1..2)).await?; + writer.send(data.slice(2..3)).await?; + writer.send(data.slice(3..4)).await?; + writer.send(data.slice(4..5)).await?; + writer.send(data.slice(5..)).await?; + } else { + writer.send(data).await?; + } + loop { + let data = second_reader.recv().await?; + if data.is_empty() { + break; + } + writer.send(data).await?; + } + writer.send_eof() + }; + let (res1, res2) = tokio::join!( + write_fut, + self.inner.get_part(key, second_writer, offset, length) + ); + res1.merge(res2) + } + + async fn has_with_results( + self: Pin<&Self>, + digests: &[StoreKey<'_>], + results: &mut [Option], + ) -> Result<(), Error> { + let _guard = self.get_permit().await?; + self.inner.has_with_results(digests, results).await + } + + async fn update( + self: Pin<&Self>, + key: StoreKey<'_>, + mut reader: nativelink_util::buf_channel::DropCloserReadHalf, + upload_size: nativelink_util::store_trait::UploadSizeInfo, + ) -> Result<(), Error> { + let _guard = self.get_permit().await?; + let (mut second_writer, second_reader) = make_buf_channel_pair(); + let write_fut = async move { + let data = reader.recv().await?; + if data.len() > 6 { + // We have two buffers each with two in so we have to chunk to cause a lock up. + second_writer.send(data.slice(0..1)).await?; + second_writer.send(data.slice(1..2)).await?; + second_writer.send(data.slice(2..3)).await?; + second_writer.send(data.slice(3..4)).await?; + second_writer.send(data.slice(4..5)).await?; + second_writer.send(data.slice(5..)).await?; + } else { + second_writer.send(data).await?; + } + loop { + let data = reader.recv().await?; + if data.is_empty() { + break; + } + second_writer.send(data).await?; + } + second_writer.send_eof() + }; + let (res1, res2) = tokio::join!( + write_fut, + self.inner.update(key, second_reader, upload_size) + ); + res1.merge(res2) + } + + fn inner_store(&self, _digest: Option>) -> &dyn StoreDriver { + self + } + + fn as_any(&self) -> &(dyn core::any::Any + Sync + Send + 'static) { + self + } + + fn as_any_arc(self: Arc) -> Arc { + self + } + + fn register_remove_callback( + self: Arc, + callback: Arc, + ) -> Result<(), Error> { + self.inner.clone().register_remove_callback(callback) + } +} + +default_health_status_indicator!(SemaphoreStore); + +#[nativelink_test] +async fn semaphore_deadlocks_handled() -> Result<(), Error> { + // Just enough semaphores for the action to function, one for each store. + let semaphore = Arc::new(tokio::sync::Semaphore::new(2)); + let fast_store = Store::new(SemaphoreStore::new(semaphore.clone())); + let slow_store = Store::new(SemaphoreStore::new(semaphore.clone())); + let fast_slow_store_config = FastSlowSpec { + fast: StoreSpec::Memory(MemorySpec::default()), + slow: StoreSpec::Noop(NoopSpec::default()), + fast_direction: StoreDirection::default(), + slow_direction: StoreDirection::default(), + }; + let fast_slow_store = Arc::new(FastSlowStore::new_with_deadlock_timeout( + &fast_slow_store_config, + fast_store.clone(), + slow_store.clone(), + core::time::Duration::from_secs(1), + )); + + let data = make_random_data(100); + let digest = DigestInfo::try_new(VALID_HASH, data.len()).unwrap(); + + // Upload some dummy data to the slow store. + slow_store + .update_oneshot(digest, data.clone().into()) + .await?; + + // Now try to get it back without a permit, this should deadlock. We release the + // semaphore when it's released from the other store. + let guard = semaphore.clone().acquire_owned().await.unwrap(); + let release_fut = async move { + // Wait for the store to get the last permit. + while semaphore.available_permits() > 0 { + tokio::time::sleep(core::time::Duration::from_millis(10)).await; + } + // Now wait for it to be released. + let _second_guard = semaphore.acquire().await.unwrap(); + // Now release all the permits. + drop(guard); + }; + let (_, result) = tokio::join!( + release_fut, + tokio::time::timeout( + core::time::Duration::from_secs(10), + fast_slow_store.get_part_unchunked(digest, 0, None) + ) + ); + assert_eq!( + result.map_err(|_| make_err!(Code::Internal, "Semaphore deadlock"))?, + Ok(data.into()) + ); + + Ok(()) +} diff --git a/nativelink-store/tests/gcs_client_test.rs b/nativelink-store/tests/gcs_client_test.rs index 22b4dd30e..901c6b74d 100644 --- a/nativelink-store/tests/gcs_client_test.rs +++ b/nativelink-store/tests/gcs_client_test.rs @@ -171,7 +171,10 @@ async fn test_resumable_upload() -> Result<(), Error> { // Start a resumable upload let upload_id = mock_ops.start_resumable_write(&object_path).await?; - assert!(!upload_id.is_empty(), "Expected non-empty upload ID"); + assert!( + !upload_id.upload_ref.is_empty(), + "Expected non-empty upload ID" + ); // Upload chunks let chunk1 = Bytes::from_static(b"first chunk "); diff --git a/nativelink-util/src/buf_channel.rs b/nativelink-util/src/buf_channel.rs index ad3b8c288..9cf13677c 100644 --- a/nativelink-util/src/buf_channel.rs +++ b/nativelink-util/src/buf_channel.rs @@ -39,12 +39,14 @@ pub fn make_buf_channel_pair() -> (DropCloserWriteHalf, DropCloserReadHalf) { // a little time for another thread to wake up and consume data if another // thread is pumping large amounts of data into the channel. let (tx, rx) = mpsc::channel(2); + let (recv_tx, recv_rx) = tokio::sync::oneshot::channel(); let eof_sent = Arc::new(AtomicBool::new(false)); ( DropCloserWriteHalf { tx: Some(tx), bytes_written: 0, eof_sent: eof_sent.clone(), + recv_rx: Some(recv_rx), }, DropCloserReadHalf { rx, @@ -54,6 +56,7 @@ pub fn make_buf_channel_pair() -> (DropCloserWriteHalf, DropCloserReadHalf) { bytes_received: 0, recent_data: Vec::new(), max_recent_data_size: 0, + recv_tx: Some(recv_tx), }, ) } @@ -64,6 +67,7 @@ pub struct DropCloserWriteHalf { tx: Option>, bytes_written: u64, eof_sent: Arc, + recv_rx: Option>, } impl DropCloserWriteHalf { @@ -72,6 +76,18 @@ impl DropCloserWriteHalf { self.send_get_bytes_on_error(buf).map_err(|err| err.0) } + /// Returns when the DropCloserReadHalf has called recv() for the first time. + pub async fn is_waiting(&mut self) -> Result<(), Error> { + let Some(recv_rx) = self.recv_rx.take() else { + // Once it's None then it's already been successful. + return Ok(()); + }; + match recv_rx.await { + Ok(()) => Ok(()), + Err(_err) => Err(make_err!(Code::Internal, "Dropped before recv")), + } + } + /// Sends data over the channel to the receiver. #[inline] async fn send_get_bytes_on_error(&mut self, buf: Bytes) -> Result<(), (Error, Bytes)> { @@ -207,6 +223,8 @@ pub struct DropCloserReadHalf { /// Amount of data to keep in the `recent_data` buffer before clearing it /// and no longer populating it. max_recent_data_size: u64, + /// A one shot that's sent when the first call to recv() is called. + recv_tx: Option>, } impl DropCloserReadHalf { @@ -238,6 +256,9 @@ impl DropCloserReadHalf { /// Try to receive a chunk of data, returning `None` if none is available. pub fn try_recv(&mut self) -> Option> { + if let Some(recv_tx) = self.recv_tx.take() { + let _ = recv_tx.send(()); + } if let Some(err) = &self.last_err { return Some(Err(err.clone())); }