Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 52 additions & 22 deletions nativelink-worker/src/local_worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,35 @@ struct LocalWorkerImpl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsM
metrics: Arc<Metrics>,
}

#[derive(Debug)]
struct ActionsInTransitGuard {
counter: Arc<AtomicU64>,
active: bool,
}

impl ActionsInTransitGuard {
fn new(counter: Arc<AtomicU64>) -> Self {
counter.fetch_add(1, Ordering::Release);
Self {
counter,
active: true,
}
}

fn done(&mut self) {
if self.active {
self.counter.fetch_sub(1, Ordering::Release);
self.active = false;
}
}
}

impl Drop for ActionsInTransitGuard {
fn drop(&mut self) {
self.done();
}
}

async fn preconditions_met(precondition_script: Option<String>) -> Result<(), Error> {
let Some(precondition_script) = &precondition_script else {
// No script means we are always ok to proceed.
Expand Down Expand Up @@ -254,7 +283,8 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke

let start_action_fut = {
let precondition_script_cfg = self.config.experimental_precondition_script.clone();
let actions_in_transit = self.actions_in_transit.clone();
let mut actions_in_transit_guard =
ActionsInTransitGuard::new(self.actions_in_transit.clone());
let worker_id = self.worker_id.clone();
let running_actions_manager = self.running_actions_manager.clone();
let mut grpc_client = self.grpc_client.clone();
Expand All @@ -265,9 +295,7 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke
metrics.preconditions.wrap(preconditions_met(precondition_script_cfg))
.and_then(|()| running_actions_manager.create_and_add_action(worker_id, start_execute))
.map(move |r| {
// Now that we either failed or registered our action, we can
// consider the action to no longer be in transit.
actions_in_transit.fetch_sub(1, Ordering::Release);
actions_in_transit_guard.done();
r
})
.and_then(|action| {
Expand Down Expand Up @@ -339,8 +367,6 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke
}
};

self.actions_in_transit.fetch_add(1, Ordering::Release);

let add_future_channel = add_future_channel.clone();

info_span!(
Expand Down Expand Up @@ -683,24 +709,28 @@ impl<T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorker<T,
);

// Now listen for connections and run all other services.
if let Err(err) = inner.run(update_for_worker_stream, &mut shutdown_rx).await {
'no_more_actions: {
// Ensure there are no actions in transit before we try to kill
// all our actions.
const ITERATIONS: usize = 1_000;

const ERROR_MSG: &str = "Actions in transit did not reach zero before we disconnected from the scheduler";

let sleep_duration = ACTIONS_IN_TRANSIT_TIMEOUT_S / ITERATIONS as f32;
for _ in 0..ITERATIONS {
if inner.actions_in_transit.load(Ordering::Acquire) == 0 {
break 'no_more_actions;
}
(sleep_fn_pin)(Duration::from_secs_f32(sleep_duration)).await;
if let Err(mut err) = inner.run(update_for_worker_stream, &mut shutdown_rx).await {
// Ensure there are no actions in transit before we try to kill
// all our actions. If they refuse to drain, forcibly reset the
// counter so we can keep cleaning up.
const ITERATIONS: usize = 1_000;
const ERROR_MSG: &str = "Actions in transit did not reach zero before we disconnected from the scheduler";

let sleep_duration = ACTIONS_IN_TRANSIT_TIMEOUT_S / ITERATIONS as f32;
for _ in 0..ITERATIONS {
if inner.actions_in_transit.load(Ordering::Acquire) == 0 {
break;
}
error!(ERROR_MSG);
return Err(err.append(ERROR_MSG));
(sleep_fn_pin)(Duration::from_secs_f32(sleep_duration)).await;
}

let actions_in_transit = inner.actions_in_transit.load(Ordering::Acquire);
if actions_in_transit != 0 {
error!(actions_in_transit, ERROR_MSG);
inner.actions_in_transit.store(0, Ordering::Release);
err = err.append(ERROR_MSG);
}

error!(?err, "Worker disconnected from scheduler");
// Kill off any existing actions because if we re-connect, we'll
// get some more and it might resource lock us.
Expand Down
52 changes: 36 additions & 16 deletions nativelink-worker/src/running_actions_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -633,18 +633,25 @@ async fn do_cleanup(

debug!("Worker cleaning up");
// Note: We need to be careful to keep trying to cleanup even if one of the steps fails.
let remove_dir_result = fs::remove_dir_all(action_directory)
.await
.err_tip(|| format!("Could not remove working directory {action_directory}"));

if let Err(err) = running_actions_manager.cleanup_action(operation_id) {
error!(%operation_id, ?err, "Error cleaning up action");
Result::<(), Error>::Err(err).merge(remove_dir_result)
} else if let Err(err) = remove_dir_result {
error!(%operation_id, ?err, "Error removing working directory");
Err(err)
} else {
Ok(())
let remove_dir_result = match fs::remove_dir_all(action_directory).await {
Ok(()) => Ok(()),
Err(err) if err.code == Code::NotFound => Ok(()),
Err(err) => {
Err(err).err_tip(|| format!("Could not remove working directory {action_directory}"))
}
};

let cleanup_result = running_actions_manager.cleanup_action(operation_id);
match (cleanup_result, remove_dir_result) {
(Err(err), remove_dir_result) => {
error!(%operation_id, ?err, "Error cleaning up action");
Err::<(), Error>(err).merge(remove_dir_result)
}
(Ok(()), Err(err)) => {
error!(%operation_id, ?err, "Error removing working directory");
Err(err)
}
_ => Ok(()),
}
}

Expand Down Expand Up @@ -2053,14 +2060,25 @@ impl RunningActionsManagerImpl {
})
}

#[allow(
clippy::unnecessary_wraps,
reason = "We keep a Result here to preserve the existing API and future-proof error handling."
)]
fn cleanup_action(&self, operation_id: &OperationId) -> Result<(), Error> {
let mut running_actions = self.running_actions.lock();
let result = running_actions.remove(operation_id).err_tip(|| {
format!("Expected operation id '{operation_id}' to exist in RunningActionsManagerImpl")
});
if running_actions.remove(operation_id).is_none() {
warn!(
%operation_id,
"Cleanup requested for operation that was not tracked"
);
self.metrics.cleanup_missing_action.inc();
// No need to copy anything, we just are telling the receivers an event happened.
self.action_done_tx.send_modify(|()| {});
return Ok(());
}
// No need to copy anything, we just are telling the receivers an event happened.
self.action_done_tx.send_modify(|()| {});
result.map(|_| ())
Ok(())
}

// Note: We do not capture metrics on this call, only `.kill_all()`.
Expand Down Expand Up @@ -2260,6 +2278,8 @@ pub struct Metrics {
get_finished_result: AsyncCounterWrapper,
#[metric(help = "Number of times an action waited for cleanup to complete.")]
cleanup_waits: CounterWithTime,
#[metric(help = "Number of cleanup calls where the action was already missing.")]
cleanup_missing_action: CounterWithTime,
#[metric(help = "Number of stale directories removed during action retries.")]
stale_removals: CounterWithTime,
#[metric(help = "Number of timeouts while waiting for cleanup to complete.")]
Expand Down
76 changes: 76 additions & 0 deletions nativelink-worker/tests/local_worker_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ use pretty_assertions::assert_eq;
use prost::Message;
use rand::Rng;
use tokio::io::AsyncWriteExt;
use tokio::time::timeout;
use utils::local_worker_test_utils::{
setup_grpc_stream, setup_local_worker, setup_local_worker_with_config,
};
Expand Down Expand Up @@ -290,6 +291,81 @@ async fn blake3_digest_function_registered_properly() -> Result<(), Error> {
Ok(())
}

#[nativelink_test]
async fn disconnect_with_action_in_transit_does_not_panic() -> Result<(), Error> {
let mut test_context = setup_local_worker(HashMap::new()).await;
let streaming_response = test_context.maybe_streaming_response.take().unwrap();

{
let props = test_context
.client
.expect_connect_worker(Ok(streaming_response))
.await;
assert_eq!(props, ConnectWorkerRequest::default());
}

let expected_worker_id = "foobar".to_string();
let tx_stream = test_context.maybe_tx_stream.take().unwrap();
{
tx_stream
.send(Frame::data(
encode_stream_proto(&UpdateForWorker {
update: Some(Update::ConnectionResult(ConnectionResult {
worker_id: expected_worker_id.clone(),
})),
})
.unwrap(),
))
.await
.map_err(|e| make_input_err!("Could not send : {:?}", e))?;
}

{
tx_stream
.send(Frame::data(
encode_stream_proto(&UpdateForWorker {
update: Some(Update::StartAction(StartExecute {
execute_request: Some(nativelink_proto::build::bazel::remote::execution::v2::ExecuteRequest {
action_digest: None,
digest_function: nativelink_proto::build::bazel::remote::execution::v2::digest_function::Value::Sha256 as i32,
..Default::default()
}),
operation_id: "pending-op".to_string(),
queued_timestamp: None,
platform: None,
worker_id: expected_worker_id,
})),
})
.unwrap(),
))
.await
.map_err(|e| make_input_err!("Could not send : {:?}", e))?;
}

// Ensure the start action is in-flight but do not respond so it stays pending.
let (_worker_id, pending_start_execute) = test_context
.actions_manager
.wait_for_create_and_add_action_call()
.await;
assert_eq!(pending_start_execute.operation_id, "pending-op");

drop(tx_stream);

timeout(
Duration::from_secs(2),
test_context.actions_manager.expect_kill_all(),
)
.await
.expect("kill_all should be called when disconnecting with pending actions");

// Unblock any pending create_and_add_action future so the worker can settle.
test_context
.actions_manager
.send_create_and_add_action_response(Err(make_input_err!("Disconnected")));

Ok(())
}

#[nativelink_test]
async fn simple_worker_start_action_test() -> Result<(), Error> {
let mut test_context = setup_local_worker(HashMap::new()).await;
Expand Down
91 changes: 91 additions & 0 deletions nativelink-worker/tests/running_actions_manager_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1387,6 +1387,97 @@ mod tests {
Ok(())
}

#[nativelink_test]
async fn cleanup_is_idempotent_after_first_call() -> Result<(), Box<dyn core::error::Error>> {
const WORKER_ID: &str = "foo_worker_id";

fn test_monotonic_clock() -> SystemTime {
static CLOCK: AtomicU64 = AtomicU64::new(0);
monotonic_clock(&CLOCK)
}

let (_, _, cas_store, ac_store) = setup_stores().await?;
let root_action_directory = make_temp_path("root_action_directory");
fs::create_dir_all(&root_action_directory).await?;

let running_actions_manager = Arc::new(RunningActionsManagerImpl::new_with_callbacks(
RunningActionsManagerArgs {
root_action_directory: root_action_directory.clone(),
execution_configuration: ExecutionConfiguration::default(),
cas_store: cas_store.clone(),
ac_store: Some(Store::new(ac_store.clone())),
historical_store: Store::new(cas_store.clone()),
upload_action_result_config:
&nativelink_config::cas_server::UploadActionResultConfig {
upload_ac_results_strategy:
nativelink_config::cas_server::UploadCacheResultsStrategy::Never,
..Default::default()
},
max_action_timeout: Duration::MAX,
timeout_handled_externally: false,
directory_cache: None,
},
Callbacks {
now_fn: test_monotonic_clock,
sleep_fn: |_duration| Box::pin(future::pending()),
},
)?);
let command = Command {
arguments: vec!["echo".to_string(), "hello".to_string()],
output_paths: vec![],
..Default::default()
};
let command_digest = serialize_and_upload_message(
&command,
cas_store.as_pin(),
&mut DigestHasherFunc::Sha256.hasher(),
)
.await?;
let input_root_digest = serialize_and_upload_message(
&Directory::default(),
cas_store.as_pin(),
&mut DigestHasherFunc::Sha256.hasher(),
)
.await?;
let action = Action {
command_digest: Some(command_digest.into()),
input_root_digest: Some(input_root_digest.into()),
..Default::default()
};
let action_digest = serialize_and_upload_message(
&action,
cas_store.as_pin(),
&mut DigestHasherFunc::Sha256.hasher(),
)
.await?;

let queued_timestamp = make_system_time(1000);
let operation_id = OperationId::default().to_string();

let running_action = running_actions_manager
.create_and_add_action(
WORKER_ID.to_string(),
StartExecute {
execute_request: Some(ExecuteRequest {
action_digest: Some(action_digest.into()),
digest_function: ProtoDigestFunction::Sha256 as i32,
..Default::default()
}),
operation_id,
queued_timestamp: Some(queued_timestamp.into()),
platform: action.platform.clone(),
worker_id: WORKER_ID.to_string(),
},
)
.await?;

running_action.clone().cleanup().await?;
running_action.clone().cleanup().await?;

running_actions_manager.kill_all().await;
Ok(())
}

#[nativelink_test]
async fn kill_ends_action() -> Result<(), Box<dyn core::error::Error>> {
const WORKER_ID: &str = "foo_worker_id";
Expand Down
Loading
Loading