80 changes: 71 additions & 9 deletions nativelink-store/src/fast_slow_store.rs
@@ -17,6 +17,7 @@ use core::cmp::{max, min};
use core::ops::Range;
use core::pin::Pin;
use core::sync::atomic::{AtomicU64, Ordering};
use core::time::Duration;
use std::collections::HashMap;
use std::ffi::OsString;
use std::sync::{Arc, Weak};
@@ -63,6 +64,9 @@ pub struct FastSlowStore {
// actually it's faster because we're not downloading the file multiple
// times or doing loads of duplicate IO.
populating_digests: Mutex<HashMap<StoreKey<'static>, Loader>>,
// The amount of time to allow stores to start before determining that they
// have deadlocked and retrying.
deadlock_timeout: Duration,
}

// This guard ensures that the populating_digests is cleared even if the future
@@ -116,6 +120,15 @@ impl Drop for LoaderGuard<'_> {

impl FastSlowStore {
pub fn new(spec: &FastSlowSpec, fast_store: Store, slow_store: Store) -> Arc<Self> {
Self::new_with_deadlock_timeout(spec, fast_store, slow_store, Duration::from_secs(5))
}

pub fn new_with_deadlock_timeout(
spec: &FastSlowSpec,
fast_store: Store,
slow_store: Store,
deadlock_timeout: Duration,
) -> Arc<Self> {
Arc::new_cyclic(|weak_self| Self {
fast_store,
fast_direction: spec.fast_direction,
@@ -124,6 +137,7 @@ impl FastSlowStore {
weak_self: weak_self.clone(),
metrics: FastSlowStoreMetrics::default(),
populating_digests: Mutex::new(HashMap::new()),
deadlock_timeout,
})
}

@@ -189,8 +203,62 @@ impl FastSlowStore {
let send_range = offset..length.map_or(u64::MAX, |length| length + offset);
let mut bytes_received: u64 = 0;

let (mut fast_tx, fast_rx) = make_buf_channel_pair();
let (slow_tx, mut slow_rx) = make_buf_channel_pair();
[Review comment from Collaborator: Important fix.]

        // There's a strong possibility of a deadlock here as we're working with multiple
        // stores. We need to be careful that we don't hold a get semaphore if we can't
        // open the update. This doesn't know anything about the downstream implementations,
        // so it simply uses a timeout to check that the reader and writers are set up.
let (stores_fut, mut slow_rx, mut fast_tx) = loop {
let (mut fast_tx, fast_rx) = make_buf_channel_pair();
let (slow_tx, mut slow_rx) = make_buf_channel_pair();

let slow_store_fut = self.slow_store.get(key.borrow(), slow_tx);
let fast_store_fut =
self.fast_store
.update(key.borrow(), fast_rx, UploadSizeInfo::ExactSize(sz));
let mut stores_fut = futures::future::join(slow_store_fut, fast_store_fut);
let has_semaphores_fut = tokio::time::timeout(
self.deadlock_timeout,
futures::future::join(slow_rx.peek(), fast_tx.is_waiting()),
);
tokio::select! {
result = &mut stores_fut => {
match result {
(Ok(()), Ok(())) => {
                            // Both stores completed without the writers, probably a zero-byte blob.
return Ok(());
}
(Ok(()), Err(err)) | (Err(err), Ok(())) => {
return Err(err);
}
(Err(err1), Err(err2)) => {
return Err(err1.merge(err2));
}
}
}
result = has_semaphores_fut => {
match result {
Ok((Ok(_), Ok(()))) => {
// Both sides have started reading/writing, we assume they hold
// all the permits they require and it's safe to continue.
break (stores_fut, slow_rx, fast_tx);
}
Ok((Ok(_), Err(err)) | (Err(err), Ok(()))) => {
return Err(err);
}
Ok((Err(err1), Err(err2))) => {
return Err(err1.merge(err2));
}
Err(_timeout) => {
// There was probably a deadlock... we need to drop and try again.
drop(stores_fut);
tracing::warn!("Possible deadlock in fast-slow, retrying.");
tokio::time::sleep(Duration::from_millis(10)).await;
}
}
}
};
};

let data_stream_fut = async move {
let mut maybe_writer_pin = maybe_writer.map(Pin::new);
loop {
@@ -229,13 +297,7 @@ impl FastSlowStore {
}
};

let slow_store_fut = self.slow_store.get(key.borrow(), slow_tx);
let fast_store_fut =
self.fast_store
.update(key.borrow(), fast_rx, UploadSizeInfo::ExactSize(sz));

let (data_stream_res, slow_res, fast_res) =
join!(data_stream_fut, slow_store_fut, fast_store_fut);
let (data_stream_res, (slow_res, fast_res)) = join!(data_stream_fut, stores_fut);
match data_stream_res {
Ok((fast_eof_res, maybe_writer_pin)) =>
// Sending the EOF will drop us almost immediately in bytestream_server
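The retry loop above is easier to follow in isolation. Below is a minimal, self-contained sketch of the same timeout-and-retry idea; the names (`get_with_retry`, the oneshot "setup" signal) are hypothetical stand-ins, not the NativeLink API, and the real code probes `slow_rx.peek()` and `fast_tx.is_waiting()` rather than a oneshot channel.

```rust
use std::time::Duration;

use tokio::sync::oneshot;

// Hypothetical stand-in for the setup-retry loop in FastSlowStore::get_part.
async fn get_with_retry(deadlock_timeout: Duration) -> Result<(), String> {
    loop {
        // Fresh channels on every attempt, mirroring make_buf_channel_pair().
        let (setup_tx, setup_rx) = oneshot::channel::<()>();

        // Stand-in for join(slow_store_fut, fast_store_fut).
        let mut stores_fut = Box::pin(async move {
            let _ = setup_tx.send(()); // Readers/writers are now wired up.
            // ... stream bytes until EOF ...
            Ok::<(), String>(())
        });

        // Stand-in for has_semaphores_fut: did setup start within the timeout?
        let probe = tokio::time::timeout(deadlock_timeout, setup_rx);

        tokio::select! {
            // The stores finished before setup was even observed (e.g. a
            // zero-byte blob): nothing left to stream.
            result = &mut stores_fut => return result,
            probed = probe => match probed {
                // Setup confirmed: safe to drive the stores to completion
                // alongside the data stream.
                Ok(_) => return stores_fut.await,
                // Probable deadlock: drop the futures (releasing whatever
                // permits they hold) and retry with fresh channels.
                Err(_elapsed) => {
                    drop(stores_fut);
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }
            }
        }
    }
}
```

Dropping `stores_fut` before sleeping is the load-bearing step: it cancels both store futures, so any semaphore permits they hold are released before the next attempt.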
98 changes: 58 additions & 40 deletions nativelink-store/src/gcs_client/client.rs
@@ -41,6 +41,12 @@ use crate::gcs_client::types::{
SIMPLE_UPLOAD_THRESHOLD, Timestamp,
};

#[derive(Debug)]
pub struct UploadRef {
pub upload_ref: String,
pub(crate) _permit: OwnedSemaphorePermit,
}

[Review comment from Collaborator: This is key because it means that connections are only held during active uploads, I think.]

/// A trait that defines the required GCS operations.
/// This abstraction allows for easier testing by mocking GCS responses.
pub trait GcsOperations: Send + Sync + Debug {
@@ -71,12 +77,12 @@ pub trait GcsOperations: Send + Sync + Debug {
fn start_resumable_write(
&self,
object_path: &ObjectPath,
) -> impl Future<Output = Result<String, Error>> + Send;
) -> impl Future<Output = Result<UploadRef, Error>> + Send;

/// Upload a chunk of data in a resumable upload session
fn upload_chunk(
&self,
upload_url: &str,
upload_url: &UploadRef,
object_path: &ObjectPath,
data: Bytes,
offset: u64,
@@ -336,12 +342,21 @@ impl GcsClient {
}

// Check if the object exists
match self.read_object_metadata(object_path).await? {
Some(_) => Ok(()),
None => Err(make_err!(
Code::Internal,
"Upload completed but object not found"
)),
let request = GetObjectRequest {
bucket: object_path.bucket.clone(),
object: object_path.path.clone(),
..Default::default()
};

match self.client.get_object(&request).await {
Ok(_) => Ok(()),
Err(GcsError::Response(resp)) if resp.code == 404 => {
return Err(make_err!(
Code::Internal,
"Upload completed but object not found"
));
}
Err(err) => Err(Self::handle_gcs_error(&err)),
}
})
.await
@@ -470,55 +485,58 @@ impl GcsOperations for GcsClient {
.await
}

async fn start_resumable_write(&self, object_path: &ObjectPath) -> Result<String, Error> {
self.with_connection(|| async {
let request = UploadObjectRequest {
bucket: object_path.bucket.clone(),
..Default::default()
};
async fn start_resumable_write(&self, object_path: &ObjectPath) -> Result<UploadRef, Error> {
let permit =
self.semaphore.clone().acquire_owned().await.map_err(|e| {
make_err!(Code::Internal, "Failed to acquire connection permit: {}", e)
})?;
let request = UploadObjectRequest {
bucket: object_path.bucket.clone(),
..Default::default()
};

let upload_type = UploadType::Multipart(Box::new(Object {
name: object_path.path.clone(),
content_type: Some(DEFAULT_CONTENT_TYPE.to_string()),
..Default::default()
}));
let upload_type = UploadType::Multipart(Box::new(Object {
name: object_path.path.clone(),
content_type: Some(DEFAULT_CONTENT_TYPE.to_string()),
..Default::default()
}));

// Start resumable upload session
let uploader = self
.client
.prepare_resumable_upload(&request, &upload_type)
.await
.map_err(|e| Self::handle_gcs_error(&e))?;
// Start resumable upload session
let uploader = self
.client
.prepare_resumable_upload(&request, &upload_type)
.await
.map_err(|e| Self::handle_gcs_error(&e))?;

Ok(uploader.url().to_string())
Ok(UploadRef {
upload_ref: uploader.url().to_string(),
_permit: permit,
})
.await
}

async fn upload_chunk(
&self,
upload_url: &str,
upload_url: &UploadRef,
_object_path: &ObjectPath,
data: Bytes,
offset: u64,
end_offset: u64,
total_size: Option<u64>,
) -> Result<(), Error> {
self.with_connection(|| async {
let uploader = self.client.get_resumable_upload(upload_url.to_string());
let uploader = self
.client
.get_resumable_upload(upload_url.upload_ref.clone());

let last_byte = if end_offset == 0 { 0 } else { end_offset - 1 };
let chunk_def = ChunkSize::new(offset, last_byte, total_size);
let last_byte = if end_offset == 0 { 0 } else { end_offset - 1 };
let chunk_def = ChunkSize::new(offset, last_byte, total_size);

// Upload chunk
uploader
.upload_multiple_chunk(data, &chunk_def)
.await
.map_err(|e| Self::handle_gcs_error(&e))?;
// Upload chunk
uploader
.upload_multiple_chunk(data, &chunk_def)
.await
.map_err(|e| Self::handle_gcs_error(&e))?;

Ok(())
})
.await
Ok(())
}

async fn upload_from_reader(
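The `UploadRef` handle is the heart of the connection-accounting change: the semaphore permit acquired in `start_resumable_write` now lives inside the value returned to the caller, so the connection slot is held for exactly as long as the upload session exists (which is presumably why `upload_chunk` no longer wraps its body in `with_connection`). A minimal sketch of that RAII pattern, using hypothetical names (`UploadHandle`, `start_upload`) rather than the actual NativeLink types:

```rust
use std::sync::Arc;

use tokio::sync::{OwnedSemaphorePermit, Semaphore};

// Hypothetical stand-in for UploadRef: the permit rides along with the
// handle and is released only when the handle is dropped.
struct UploadHandle {
    url: String,
    _permit: OwnedSemaphorePermit,
}

async fn start_upload(limiter: Arc<Semaphore>, url: String) -> Result<UploadHandle, String> {
    // Claim a connection slot up front; it is held for the whole upload.
    let permit = limiter
        .acquire_owned()
        .await
        .map_err(|e| format!("failed to acquire connection permit: {e}"))?;
    Ok(UploadHandle { url, _permit: permit })
}

#[tokio::main]
async fn main() {
    // Allow at most two resumable uploads in flight at once.
    let limiter = Arc::new(Semaphore::new(2));
    let a = start_upload(limiter.clone(), "upload-a".into()).await.unwrap();
    let b = start_upload(limiter.clone(), "upload-b".into()).await.unwrap();
    println!("active: {} and {}", a.url, b.url);
    assert_eq!(limiter.available_permits(), 0);
    drop(a); // Dropping the handle frees its connection slot.
    assert_eq!(limiter.available_permits(), 1);
    drop(b);
}
```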
16 changes: 10 additions & 6 deletions nativelink-store/src/gcs_client/mocks.rs
@@ -15,15 +15,16 @@
use core::fmt::Debug;
use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};

use bytes::Bytes;
use futures::Stream;
use nativelink_error::{Code, Error, make_err};
use nativelink_util::buf_channel::DropCloserReadHalf;
use tokio::sync::RwLock;
use tokio::sync::{RwLock, Semaphore};

use crate::gcs_client::client::GcsOperations;
use crate::gcs_client::client::{GcsOperations, UploadRef};
use crate::gcs_client::types::{DEFAULT_CONTENT_TYPE, GcsObject, ObjectPath, Timestamp};

/// A mock implementation of `GcsOperations` for testing
@@ -379,7 +380,7 @@ impl GcsOperations for MockGcsOperations {
Ok(())
}

async fn start_resumable_write(&self, object_path: &ObjectPath) -> Result<String, Error> {
async fn start_resumable_write(&self, object_path: &ObjectPath) -> Result<UploadRef, Error> {
self.call_counts
.start_resumable_calls
.fetch_add(1, Ordering::Relaxed);
@@ -392,12 +393,15 @@

self.handle_failure().await?;
let upload_id = format!("mock-upload-{}-{}", object_path.bucket, object_path.path);
Ok(upload_id)
Ok(UploadRef {
upload_ref: upload_id,
_permit: Arc::new(Semaphore::new(1)).acquire_owned().await.unwrap(),
})
}

async fn upload_chunk(
&self,
upload_url: &str,
upload_url: &UploadRef,
object_path: &ObjectPath,
data: Bytes,
offset: u64,
@@ -408,7 +412,7 @@
.upload_chunk_calls
.fetch_add(1, Ordering::Relaxed);
self.requests.write().await.push(MockRequest::UploadChunk {
upload_url: upload_url.to_string(),
upload_url: upload_url.upload_ref.clone(),
object_path: object_path.clone(),
data_len: data.len(),
offset,
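The mock satisfies the new `_permit` field with a throwaway permit: a freshly created single-permit semaphore always has a permit free, so `acquire_owned()` resolves immediately and the `unwrap()` can only panic if the semaphore were closed, which it never is here. The same trick in isolation (hypothetical helper name):

```rust
use std::sync::Arc;

use tokio::sync::{OwnedSemaphorePermit, Semaphore};

// Fabricate a permit tied to no shared limiter: the fresh semaphore is
// owned solely by its permit, so the acquire can never block or fail.
async fn dummy_permit() -> OwnedSemaphorePermit {
    Arc::new(Semaphore::new(1))
        .acquire_owned()
        .await
        .expect("a freshly created semaphore is never closed")
}
```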