diff --git a/alioth/src/board/board.rs b/alioth/src/board/board.rs index 34a47ebe..bea1809b 100644 --- a/alioth/src/board/board.rs +++ b/alioth/src/board/board.rs @@ -23,7 +23,7 @@ mod x86_64; use std::collections::HashMap; use std::ffi::CStr; use std::sync::Arc; -use std::sync::mpsc::{Receiver, Sender}; +use std::sync::mpsc::Sender; use std::thread::JoinHandle; use libc::{MAP_PRIVATE, MAP_SHARED}; @@ -97,6 +97,8 @@ pub enum Error { NotifyVmm, #[snafu(display("Another VCPU thread has signaled failure"))] PeerFailure, + #[snafu(display("Unexpected state: {state:?}, want {want:?}"))] + UnexpectedState { state: BoardState, want: BoardState }, } type Result = std::result::Result; @@ -157,17 +159,17 @@ impl CpuConfig { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum BoardState { +pub enum BoardState { Created, Running, Shutdown, RebootPending, + Fatal, } #[derive(Debug)] struct MpSync { state: BoardState, - fatal: bool, count: u16, } @@ -189,8 +191,8 @@ impl BoardConfig { } } -type VcpuGuard<'a> = RwLockReadGuard<'a, Vec<(JoinHandle>, Sender<()>)>>; -type VcpuHandle = (JoinHandle>, Sender<()>); +type VcpuGuard<'a> = RwLockReadGuard<'a, Vec>; +type VcpuHandle = JoinHandle>; pub struct Board where @@ -243,19 +245,23 @@ where mp_sync: Mutex::new(MpSync { state: BoardState::Created, count: 0, - fatal: false, }), cond_var: Condvar::new(), } } pub fn boot(&self) -> Result<()> { - let vcpus = self.vcpus.read(); let mut mp_sync = self.mp_sync.lock(); - mp_sync.state = BoardState::Running; - for (_, boot_tx) in vcpus.iter() { - boot_tx.send(()).unwrap(); + if mp_sync.state == BoardState::Created { + mp_sync.state = BoardState::Running; + } else { + return error::UnexpectedState { + state: mp_sync.state, + want: BoardState::Created, + } + .fail(); } + self.cond_var.notify_all(); Ok(()) } @@ -353,7 +359,7 @@ where fn sync_vcpus(&self, vcpus: &VcpuGuard) -> Result<()> { let mut mp_sync = self.mp_sync.lock(); - if mp_sync.fatal { + if mp_sync.state == BoardState::Fatal { return error::PeerFailure.fail(); } @@ -365,24 +371,35 @@ where self.cond_var.wait(&mut mp_sync) } - if mp_sync.fatal { + if mp_sync.state == BoardState::Fatal { return error::PeerFailure.fail(); } Ok(()) } - fn run_vcpu_inner( - &self, - index: u16, - vcpu: &mut V::Vcpu, - boot_rx: &Receiver<()>, - ) -> Result<(), Error> { - self.init_vcpu(index, vcpu)?; - boot_rx.recv().unwrap(); - if self.mp_sync.lock().state != BoardState::Running { + fn notify_vmm(&self, index: u16, event_tx: &Sender) -> Result<()> { + if event_tx.send(index).is_err() { + error::NotifyVmm.fail() + } else { + Ok(()) + } + } + + fn run_vcpu_inner(&self, index: u16, event_tx: &Sender) -> Result<(), Error> { + let mut vcpu = self.create_vcpu(index)?; + self.notify_vmm(index, event_tx)?; + self.init_vcpu(index, &mut vcpu)?; + + let mut mp_sync = self.mp_sync.lock(); + while mp_sync.state == BoardState::Created { + self.cond_var.wait(&mut mp_sync); + } + if mp_sync.state != BoardState::Running { return Ok(()); } + drop(mp_sync); + loop { let vcpus = self.vcpus.read(); self.coco_init(index)?; @@ -397,15 +414,15 @@ where } self.add_pci_devs()?; let init_state = self.load_payload()?; - self.init_boot_vcpu(vcpu, &init_state)?; + self.init_boot_vcpu(&mut vcpu, &init_state)?; self.create_firmware_data(&init_state)?; } - self.init_ap(index, vcpu, &vcpus)?; + self.init_ap(index, &mut vcpu, &vcpus)?; self.coco_finalize(index, &vcpus)?; self.sync_vcpus(&vcpus)?; drop(vcpus); - let maybe_reboot = self.vcpu_loop(vcpu, index); + let maybe_reboot = self.vcpu_loop(&mut vcpu, index); let vcpus = self.vcpus.read(); let mut mp_sync = self.mp_sync.lock(); @@ -415,7 +432,7 @@ where } else { BoardState::Shutdown }; - for (another, (handle, _)) in vcpus.iter().enumerate() { + for (another, handle) in vcpus.iter().enumerate() { if index == another as u16 { continue; } @@ -434,7 +451,7 @@ where self.pci_bus.segment.reset().context(error::ResetPci)?; self.memory.reset()?; } - self.reset_vcpu(index, vcpu)?; + self.reset_vcpu(index, &mut vcpu)?; if let Err(e) = maybe_reboot { break Err(e); @@ -448,29 +465,19 @@ where } } - fn create_vcpu(&self, index: u16, event_tx: &Sender) -> Result { + fn create_vcpu(&self, index: u16) -> Result { let identity = self.encode_cpu_identity(index); let vcpu = self .vm .create_vcpu(index, identity) .context(error::CreateVcpu { index })?; - if event_tx.send(index).is_err() { - error::NotifyVmm.fail() - } else { - Ok(vcpu) - } + Ok(vcpu) } - pub fn run_vcpu( - &self, - index: u16, - event_tx: Sender, - boot_rx: Receiver<()>, - ) -> Result<(), Error> { - let mut vcpu = self.create_vcpu(index, &event_tx)?; + pub fn run_vcpu(&self, index: u16, event_tx: Sender) -> Result<(), Error> { + let ret = self.run_vcpu_inner(index, &event_tx); - let ret = self.run_vcpu_inner(index, &mut vcpu, &boot_rx); - event_tx.send(index).unwrap(); + let _ = self.notify_vmm(index, &event_tx); if matches!(ret, Ok(_) | Err(Error::PeerFailure { .. })) { return Ok(()); @@ -478,7 +485,7 @@ where log::warn!("VCPU-{index} reported error, unblocking other VCPUs..."); let mut mp_sync = self.mp_sync.lock(); - mp_sync.fatal = true; + mp_sync.state = BoardState::Fatal; if mp_sync.count > 0 { self.cond_var.notify_all(); } diff --git a/alioth/src/vm.rs b/alioth/src/vm.rs index cd0bc34b..af374795 100644 --- a/alioth/src/vm.rs +++ b/alioth/src/vm.rs @@ -140,18 +140,17 @@ where let mut vcpus = board.vcpus.write(); for index in 0..board.config.cpu.count { - let (boot_tx, boot_rx) = mpsc::channel(); let event_tx = event_tx.clone(); let board = board.clone(); let handle = thread::Builder::new() .name(format!("vcpu_{index}")) - .spawn(move || board.run_vcpu(index, event_tx, boot_rx)) + .spawn(move || board.run_vcpu(index, event_tx)) .context(error::VcpuThread { index })?; if event_rx.recv_timeout(Duration::from_secs(2)).is_err() { let err = std::io::ErrorKind::TimedOut.into(); Err(err).context(error::VcpuThread { index })?; } - vcpus.push((handle, boot_tx)); + vcpus.push(handle); } drop(vcpus); @@ -294,7 +293,7 @@ where drop(vcpus); let mut vcpus = self.board.vcpus.write(); let mut ret = Ok(()); - for (index, (handle, _)) in vcpus.drain(..).enumerate() { + for (index, handle) in vcpus.drain(..).enumerate() { let Ok(r) = handle.join() else { log::error!("Cannot join VCPU-{index}"); continue;