Add graceful shutdown to worker instances
SchahinRohani committed Oct 6, 2024
1 parent e6eb5f7 commit 3eb360b
Showing 6 changed files with 180 additions and 36 deletions.
2 changes: 1 addition & 1 deletion nativelink-service/src/worker_api_server.rs
@@ -13,10 +13,10 @@
// limitations under the License.

use std::collections::HashMap;
use std::convert::Into;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use std::convert::Into;

use futures::stream::unfold;
use futures::Stream;
117 changes: 97 additions & 20 deletions nativelink-worker/src/local_worker.rs
@@ -16,7 +16,7 @@ use std::pin::Pin;
use std::process::Stdio;
use std::str;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Weak};
use std::sync::{Arc, Mutex, Weak};
use std::time::Duration;

use futures::future::BoxFuture;
@@ -28,7 +28,7 @@ use nativelink_metric::{MetricsComponent, RootMetricsComponent};
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::update_for_worker::Update;
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::worker_api_client::WorkerApiClient;
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
execute_result, ExecuteResult, KeepAliveRequest, UpdateForWorker,
execute_result, ExecuteResult, GoingAwayRequest, KeepAliveRequest, UpdateForWorker,
};
use nativelink_store::fast_slow_store::FastSlowStore;
use nativelink_util::action_messages::{ActionResult, ActionStage, OperationId};
@@ -356,13 +356,16 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
}

type ConnectionFactory<T> = Box<dyn Fn() -> BoxFuture<'static, Result<T, Error>> + Send + Sync>;
type SleepFnType = Box<dyn Fn(Duration) -> BoxFuture<'static, ()> + Send + Sync>;

pub struct LocalWorker<T: WorkerApiClientTrait, U: RunningActionsManager> {
config: Arc<LocalWorkerConfig>,
running_actions_manager: Arc<U>,
connection_factory: ConnectionFactory<T>,
sleep_fn: Option<Box<dyn Fn(Duration) -> BoxFuture<'static, ()> + Send + Sync>>,
sleep_fn: Mutex<Option<SleepFnType>>,
metrics: Arc<Metrics>,
grpc_client: Mutex<Option<T>>,
worker_id: Mutex<Option<String>>,
}

/// Creates a new `LocalWorker`. The `cas_store` must be an instance of
@@ -468,17 +471,77 @@ impl<T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorker<T, U> {
config,
running_actions_manager,
connection_factory,
sleep_fn: Some(sleep_fn),
sleep_fn: Some(sleep_fn).into(),
metrics,
grpc_client: Mutex::new(None),
worker_id: Mutex::new(None),
}
}

pub fn name(&self) -> &String {
&self.config.name
}

pub async fn shutdown(&self) {
println!("Shutting down worker: {}", self.name());

let max_wait_duration = Duration::from_secs(60);

// Attempt to complete actions within the timeout
if (tokio::time::timeout(
max_wait_duration,
self.running_actions_manager.complete_actions(),
)
.await)
.is_ok()
{
println!("All actions completed before timeout.");
} else {
println!(
"Timeout of {} seconds reached: Some actions are still running during shutdown.",
max_wait_duration.as_secs()
);
// Forcefully terminate remaining actions
self.running_actions_manager.kill_all().await;
}

// Extract grpc_client and worker_id while holding the locks
let mut grpc_client = {
let mut grpc_client_lock = self.grpc_client.lock().unwrap();
if let Some(client) = grpc_client_lock.as_mut() {
client.clone()
} else {
println!("No grpc_client available; cannot notify scheduler.");
return;
}
};

let worker_id = {
let worker_id_lock = self.worker_id.lock().unwrap();
if let Some(id) = worker_id_lock.as_ref() {
id.clone()
} else {
println!("No worker_id available; cannot notify scheduler.");
return;
}
};

// Notify the scheduler with the going_away request
println!("Notify Scheduler: Local Worker is going away...");
if let Err(e) = grpc_client.going_away(GoingAwayRequest { worker_id }).await {
println!("Failed to send going_away to scheduler: {e:?}");
} else {
println!("Successfully notified scheduler that worker is going away.");
}

{
let mut grpc_client_lock = self.grpc_client.lock().unwrap();
*grpc_client_lock = None;
}
}

async fn register_worker(
&mut self,
&self,
client: &mut T,
) -> Result<(String, Streaming<UpdateForWorker>), Error> {
let supported_properties =
@@ -509,11 +572,13 @@ impl<T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorker<T, U> {
}

#[instrument(skip(self), level = Level::INFO)]
pub async fn run(mut self) -> Result<(), Error> {
let sleep_fn = self
.sleep_fn
.take()
.err_tip(|| "Could not unwrap sleep_fn in LocalWorker::run")?;
pub async fn run(self: Arc<Self>) -> Result<(), Error> {
let sleep_fn = self.sleep_fn.lock().unwrap().take().ok_or_else(|| {
make_err!(
Code::Internal,
"Could not unwrap sleep_fn in LocalWorker::run"
)
})?;
let sleep_fn_pin = Pin::new(&sleep_fn);
let error_handler = Box::pin(move |err| async move {
event!(Level::ERROR, ?err, "Error");
@@ -530,23 +595,35 @@ impl<T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorker<T, U> {
}
};

{
let mut grpc_client_lock = self.grpc_client.lock().unwrap();
*grpc_client_lock = Some(client.clone());
}

// Next register our worker with the scheduler.
let (mut inner, update_for_worker_stream) =
match self.register_worker(&mut client).await {
Err(e) => {
(error_handler)(e).await;
continue; // Try to connect again.
}
Ok((worker_id, update_for_worker_stream)) => (
LocalWorkerImpl::new(
&self.config,
client,
worker_id,
self.running_actions_manager.clone(),
self.metrics.clone(),
),
update_for_worker_stream,
),
Ok((worker_id, update_for_worker_stream)) => {
// Store the worker_id
{
let mut worker_id_lock = self.worker_id.lock().unwrap();
*worker_id_lock = Some(worker_id.clone());
}
(
LocalWorkerImpl::new(
&self.config,
client,
worker_id,
self.running_actions_manager.clone(),
self.metrics.clone(),
),
update_for_worker_stream,
)
}
};
event!(
Level::WARN,
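
The shutdown path above clones the gRPC client and worker id out of their std::sync::Mutex slots before doing any async work. A minimal sketch of that pattern, not part of the commit, with a hypothetical Client standing in for the tonic-generated WorkerApiClient:

use std::sync::Mutex;
use std::time::Duration;

// Stand-in for the generated WorkerApiClient; tonic clients are cheaply cloneable.
#[derive(Clone)]
struct Client;

impl Client {
    async fn going_away(&mut self) {
        // Pretend RPC; the real code sends GoingAwayRequest { worker_id }.
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
}

struct Worker {
    grpc_client: Mutex<Option<Client>>,
}

impl Worker {
    async fn notify_scheduler(&self) {
        // A std MutexGuard should not be held across an .await, so take the lock
        // briefly, clone the client, and drop the guard before the async call.
        let mut client = {
            let guard = self.grpc_client.lock().unwrap();
            match guard.as_ref() {
                Some(client) => client.clone(),
                None => return, // nothing to notify
            }
        };
        client.going_away().await;
    }
}

#[tokio::main]
async fn main() {
    let worker = Worker {
        grpc_client: Mutex::new(Some(Client)),
    };
    worker.notify_scheduler().await;
}
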
19 changes: 18 additions & 1 deletion nativelink-worker/src/running_actions_manager.rs
@@ -21,7 +21,6 @@ use std::ffi::{OsStr, OsString};
use std::fmt::Debug;
#[cfg(target_family = "unix")]
use std::fs::Permissions;
#[cfg(target_family = "unix")]
use std::os::unix::fs::{MetadataExt, PermissionsExt};
use std::path::Path;
use std::pin::Pin;
@@ -1349,6 +1348,8 @@ pub trait RunningActionsManager: Sync + Send + Sized + Unpin + 'static {
hasher: DigestHasherFunc,
) -> impl Future<Output = Result<(), Error>> + Send;

fn complete_actions(&self) -> impl Future<Output = ()> + Send;

fn kill_all(&self) -> impl Future<Output = ()> + Send;

fn kill_operation(
@@ -1879,6 +1880,22 @@ impl RunningActionsManager for RunningActionsManagerImpl {
Ok(())
}

async fn complete_actions(&self) {
let mut receiver = self.action_done_tx.subscribe();
loop {
{
let running_actions = self.running_actions.lock();
if running_actions.is_empty() {
break;
}
}
// Wait for a change in the action_done_tx
if receiver.changed().await.is_err() {
break;
}
}
}

// Note: When the future returns the process should be fully killed and cleaned up.
async fn kill_all(&self) {
self.metrics
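
complete_actions() above drains the running-action set by looping on a tokio watch channel that is signalled whenever an action finishes. A self-contained sketch of that wait-until-idle pattern, using simplified, hypothetical types rather than the real manager:

use std::collections::HashSet;
use std::sync::{Arc, Mutex};

use tokio::sync::watch;

struct Manager {
    running_actions: Mutex<HashSet<u64>>,
    action_done_tx: watch::Sender<()>,
}

impl Manager {
    async fn complete_actions(&self) {
        let mut receiver = self.action_done_tx.subscribe();
        loop {
            if self.running_actions.lock().unwrap().is_empty() {
                break;
            }
            // Wait for the next "an action finished" notification.
            if receiver.changed().await.is_err() {
                break; // sender dropped; nothing left to wait on
            }
        }
    }

    fn finish(&self, id: u64) {
        self.running_actions.lock().unwrap().remove(&id);
        let _ = self.action_done_tx.send(()); // wake any waiters
    }
}

#[tokio::main]
async fn main() {
    let mgr = Arc::new(Manager {
        running_actions: Mutex::new(HashSet::from([1])),
        action_done_tx: watch::channel(()).0,
    });
    let waiter = tokio::spawn({
        let mgr = Arc::clone(&mgr);
        async move { mgr.complete_actions().await }
    });
    mgr.finish(1);
    waiter.await.unwrap();
}
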
5 changes: 4 additions & 1 deletion nativelink-worker/tests/utils/local_worker_test_utils.rs
@@ -194,7 +194,10 @@ pub async fn setup_local_worker_with_config(local_worker_config: LocalWorkerConf
}),
Box::new(move |_| Box::pin(async move { /* No sleep */ })),
);
let drop_guard = spawn!("local_worker_spawn", async move { worker.run().await });
let drop_guard = spawn!(
"local_worker_spawn",
async move { Arc::new(worker).run().await }
);

let (tx_stream, streaming_response) = setup_grpc_stream();
TestContext {
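
The test helper now wraps the worker in an Arc because run() takes self: Arc<Self>; that receiver lets one handle drive run() in a spawned task while another handle (the one Instance keeps in the main binary) can still call shutdown(). A tiny sketch of the pattern with placeholder types:

use std::sync::Arc;

struct Worker;

impl Worker {
    async fn run(self: Arc<Self>) {
        // Long-lived loop in the real worker; a no-op here.
    }

    async fn shutdown(&self) {
        println!("shutting down");
    }
}

#[tokio::main]
async fn main() {
    let worker = Arc::new(Worker);
    let handle = tokio::spawn(worker.clone().run()); // one Arc drives run()
    worker.shutdown().await; // the other Arc can still request shutdown
    handle.await.unwrap();
}
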
6 changes: 6 additions & 0 deletions nativelink-worker/tests/utils/mock_running_actions_manager.rs
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use futures::Future;
use std::future;
use std::sync::Arc;

use async_lock::Mutex;
@@ -163,6 +165,10 @@ impl RunningActionsManager for MockRunningActionsManager {
Ok(())
}

fn complete_actions(&self) -> impl Future<Output = ()> + Send {
future::ready(())
}

async fn kill_operation(&self, operation_id: &OperationId) -> Result<(), Error> {
self.tx_kill_operation
.send(operation_id.clone())
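
The mock satisfies the new complete_actions trait method with future::ready(()) instead of an async body. A small sketch of that return-position impl Trait in trait pattern (requires Rust 1.75+); the trait and type names here are illustrative stand-ins:

use std::future::{self, Future};

trait ActionsManager {
    fn complete_actions(&self) -> impl Future<Output = ()> + Send;
}

struct NoopManager;

impl ActionsManager for NoopManager {
    fn complete_actions(&self) -> impl Future<Output = ()> + Send {
        // future::ready resolves immediately, without an async block.
        future::ready(())
    }
}

#[tokio::main]
async fn main() {
    NoopManager.complete_actions().await;
}
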
67 changes: 54 additions & 13 deletions src/bin/nativelink.rs
@@ -58,7 +58,9 @@ use nativelink_util::store_trait::{
};
use nativelink_util::task::TaskExecutor;
use nativelink_util::{background_spawn, init_tracing, spawn, spawn_blocking};
use nativelink_worker::local_worker::new_local_worker;
use nativelink_worker::local_worker::{new_local_worker, LocalWorker};
use nativelink_worker::running_actions_manager::RunningActionsManagerImpl;
use nativelink_worker::worker_api_client_wrapper::WorkerApiClientWrapper;
use opentelemetry::metrics::MeterProvider;
use opentelemetry_sdk::metrics::SdkMeterProvider;
use parking_lot::{Mutex, RwLock};
@@ -160,6 +162,7 @@ impl RootMetricsComponent for ConnectedClientsMetrics {}
async fn inner_main(
cfg: CasConfig,
server_start_timestamp: u64,
instance: Arc<AsyncMutex<Instance>>,
) -> Result<(), Box<dyn std::error::Error>> {
let health_registry_builder = Arc::new(AsyncMutex::new(HealthRegistryBuilder::new(
"nativelink".into(),
@@ -916,11 +919,19 @@ async fn inner_main(
)
.await
.err_tip(|| "Could not make LocalWorker")?;

let local_worker = Arc::new(local_worker);
{
let mut instance_lock = instance.lock().await;
instance_lock.workers.push(local_worker.clone());
}

let name = if local_worker.name().is_empty() {
format!("worker_{i}")
} else {
local_worker.name().clone()
};

if worker_names.contains(&name) {
Err(Box::new(make_err!(
Code::InvalidArgument,
@@ -960,6 +971,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {

let mut cfg = futures::executor::block_on(get_config())?;

// Initialize the Instance empty
let instance = Arc::new(AsyncMutex::new(Instance::new(vec![])));

let (mut metrics_enabled, max_blocking_threads) = {
// Note: If the default changes make sure you update the documentation in
// `config/cas_server.rs`.
Expand Down Expand Up @@ -1022,27 +1036,54 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.enable_all()
.on_thread_start(move || set_metrics_enabled_for_this_thread(metrics_enabled))
.build()?;
let instance_for_ctrl_c = instance.clone();
runtime.spawn(async move {
tokio::signal::ctrl_c()
.await
.expect("Failed to listen to SIGINT");
let instance_guard = instance_for_ctrl_c.lock().await;
instance_guard.graceful_shutdown().await;
eprintln!("User terminated process via SIGINT");
std::process::exit(130);
});

#[cfg(target_family = "unix")]
runtime.spawn(async move {
signal(SignalKind::terminate())
.expect("Failed to listen to SIGTERM")
.recv()
.await;
eprintln!("Process terminated via SIGTERM");
std::process::exit(143);
});
{
let instance_for_sigterm = instance.clone();
runtime.spawn(async move {
signal(SignalKind::terminate())
.expect("Failed to listen to SIGTERM")
.recv()
.await;

runtime.block_on(
Arc::new(OriginContext::new())
.wrap_async(trace_span!("main"), inner_main(cfg, server_start_time)),
)
let instance_guard = instance_for_sigterm.lock().await;
instance_guard.graceful_shutdown().await;
eprintln!("Process terminated via SIGTERM");
std::process::exit(143);
});
}

runtime.block_on(Arc::new(OriginContext::new()).wrap_async(
trace_span!("main"),
inner_main(cfg, server_start_time, instance.clone()),
))
}
}

struct Instance {
workers: Vec<Arc<LocalWorker<WorkerApiClientWrapper, RunningActionsManagerImpl>>>,
}

impl Instance {
fn new(
workers: Vec<Arc<LocalWorker<WorkerApiClientWrapper, RunningActionsManagerImpl>>>,
) -> Self {
Instance { workers }
}

async fn graceful_shutdown(&self) {
for worker in &self.workers {
worker.shutdown().await;
}
}
}
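
Putting the pieces together, a runnable sketch of the signal wiring the commit adds to main(): a shared Instance, and a Ctrl-C task that runs graceful_shutdown() before exiting with 130 (128 + SIGINT). This is a simplified illustration, assuming tokio with the signal feature; Worker is a placeholder for the real LocalWorker:

use std::sync::Arc;

use tokio::sync::Mutex as AsyncMutex;

struct Worker;

impl Worker {
    async fn shutdown(&self) {
        println!("worker shut down");
    }
}

struct Instance {
    workers: Vec<Arc<Worker>>,
}

impl Instance {
    async fn graceful_shutdown(&self) {
        for worker in &self.workers {
            worker.shutdown().await;
        }
    }
}

#[tokio::main]
async fn main() {
    let instance = Arc::new(AsyncMutex::new(Instance {
        workers: vec![Arc::new(Worker)],
    }));

    let instance_for_ctrl_c = instance.clone();
    tokio::spawn(async move {
        tokio::signal::ctrl_c().await.expect("Failed to listen to SIGINT");
        instance_for_ctrl_c.lock().await.graceful_shutdown().await;
        eprintln!("User terminated process via SIGINT");
        std::process::exit(130);
    });

    // The real binary runs its servers here; pending() stands in for that work.
    std::future::pending::<()>().await;
}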
