diff --git a/chainfire/chainfire-client/src/client.rs b/chainfire/chainfire-client/src/client.rs index 9a645f2..55f0d08 100644 --- a/chainfire/chainfire-client/src/client.rs +++ b/chainfire/chainfire-client/src/client.rs @@ -3,24 +3,13 @@ use crate::error::{ClientError, Result}; use crate::watch::WatchHandle; use chainfire_proto::proto::{ - cluster_client::ClusterClient, - compare, - kv_client::KvClient, - request_op, - response_op, - watch_client::WatchClient, - Compare, - DeleteRangeRequest, - MemberAddRequest, - PutRequest, - RangeRequest, - RequestOp, - StatusRequest, - TxnRequest, + cluster_client::ClusterClient, compare, kv_client::KvClient, request_op, response_op, + watch_client::WatchClient, Compare, DeleteRangeRequest, MemberAddRequest, PutRequest, + RangeRequest, RequestOp, StatusRequest, TxnRequest, }; use std::time::Duration; -use tonic::Code; use tonic::transport::Channel; +use tonic::Code; use tracing::{debug, warn}; /// Chainfire client @@ -64,7 +53,9 @@ impl Client { } } - Err(last_error.unwrap_or_else(|| ClientError::Connection("no Chainfire endpoints configured".to_string()))) + Err(last_error.unwrap_or_else(|| { + ClientError::Connection("no Chainfire endpoints configured".to_string()) + })) } async fn with_kv_retry(&mut self, mut op: F) -> Result @@ -88,14 +79,17 @@ impl Client { "retrying Chainfire KV RPC on alternate endpoint" ); last_status = Some(status); - self.recover_after_status(last_status.as_ref().unwrap()).await?; + self.recover_after_status(last_status.as_ref().unwrap()) + .await?; tokio::time::sleep(retry_delay(attempt)).await; } Err(status) => return Err(status.into()), } } - Err(last_status.unwrap_or_else(|| tonic::Status::unavailable("Chainfire KV retry exhausted")).into()) + Err(last_status + .unwrap_or_else(|| tonic::Status::unavailable("Chainfire KV retry exhausted")) + .into()) } async fn with_cluster_retry(&mut self, mut op: F) -> Result @@ -119,14 +113,17 @@ impl Client { "retrying Chainfire cluster RPC on alternate endpoint" ); last_status = Some(status); - self.recover_after_status(last_status.as_ref().unwrap()).await?; + self.recover_after_status(last_status.as_ref().unwrap()) + .await?; tokio::time::sleep(retry_delay(attempt)).await; } Err(status) => return Err(status.into()), } } - Err(last_status.unwrap_or_else(|| tonic::Status::unavailable("Chainfire cluster retry exhausted")).into()) + Err(last_status + .unwrap_or_else(|| tonic::Status::unavailable("Chainfire cluster retry exhausted")) + .into()) } async fn recover_after_status(&mut self, status: &tonic::Status) -> Result<()> { @@ -150,7 +147,9 @@ impl Client { let endpoint = self .endpoints .get(index) - .ok_or_else(|| ClientError::Connection(format!("invalid Chainfire endpoint index {index}")))? + .ok_or_else(|| { + ClientError::Connection(format!("invalid Chainfire endpoint index {index}")) + })? .clone(); let (channel, kv, cluster) = connect_endpoint(&endpoint).await?; self.current_endpoint = index; @@ -182,7 +181,11 @@ impl Client { match cluster.status(StatusRequest {}).await { Ok(response) => { let status = response.into_inner(); - let member_id = status.header.as_ref().map(|header| header.member_id).unwrap_or(0); + let member_id = status + .header + .as_ref() + .map(|header| header.member_id) + .unwrap_or(0); if status.leader != 0 && status.leader == member_id { return Ok(Some(index)); } @@ -232,10 +235,7 @@ impl Client { /// Get a value by key pub async fn get(&mut self, key: impl AsRef<[u8]>) -> Result>> { - Ok(self - .get_with_revision(key) - .await? 
- .map(|(value, _)| value)) + Ok(self.get_with_revision(key).await?.map(|(value, _)| value)) } /// Get a value by key along with its current revision @@ -263,13 +263,14 @@ impl Client { }) .await?; - Ok(resp.kvs.into_iter().next().map(|kv| (kv.value, kv.mod_revision as u64))) + Ok(resp + .kvs + .into_iter() + .next() + .map(|kv| (kv.value, kv.mod_revision as u64))) } /// Put a key-value pair only if the key's mod_revision matches. - /// - /// This is a best-effort compare-and-set. The server may not return - /// a reliable success flag, so callers should treat this as "attempted". pub async fn put_if_revision( &mut self, key: impl AsRef<[u8]>, @@ -277,6 +278,7 @@ impl Client { expected_mod_revision: u64, ) -> Result<()> { let key_bytes = key.as_ref().to_vec(); + let key_display = String::from_utf8_lossy(&key_bytes).to_string(); let compare = Compare { result: compare::CompareResult::Equal as i32, target: compare::CompareTarget::Mod as i32, @@ -288,29 +290,63 @@ impl Client { let put_op = RequestOp { request: Some(request_op::Request::RequestPut(PutRequest { - key: key_bytes, + key: key_bytes.clone(), value: value.as_ref().to_vec(), lease: 0, prev_kv: false, })), }; - self.with_kv_retry(|mut kv| { - let compare = compare.clone(); - let put_op = put_op.clone(); - async move { - kv.txn(TxnRequest { - compare: vec![compare], - success: vec![put_op], - failure: vec![], - }) - .await - .map(|resp| resp.into_inner()) - } - }) - .await?; + let read_on_fail = RequestOp { + request: Some(request_op::Request::RequestRange(RangeRequest { + key: key_bytes.clone(), + range_end: vec![], + limit: 1, + revision: 0, + keys_only: false, + count_only: false, + serializable: false, + })), + }; - Ok(()) + let resp = self + .with_kv_retry(|mut kv| { + let compare = compare.clone(); + let put_op = put_op.clone(); + let read_on_fail = read_on_fail.clone(); + async move { + kv.txn(TxnRequest { + compare: vec![compare], + success: vec![put_op], + failure: vec![read_on_fail], + }) + .await + .map(|resp| resp.into_inner()) + } + }) + .await?; + + if resp.succeeded { + return Ok(()); + } + + let current_revision = resp + .responses + .into_iter() + .filter_map(|op| match op.response { + Some(response_op::Response::ResponseRange(range)) => range + .kvs + .into_iter() + .next() + .map(|kv| kv.mod_revision as u64), + _ => None, + }) + .next() + .unwrap_or(0); + + Err(ClientError::Conflict(format!( + "mod_revision mismatch for key {key_display}: expected {expected_mod_revision}, current {current_revision}" + ))) } /// Get a value as string @@ -341,7 +377,10 @@ impl Client { } /// Get all keys with a prefix - pub async fn get_prefix(&mut self, prefix: impl AsRef<[u8]>) -> Result, Vec)>> { + pub async fn get_prefix( + &mut self, + prefix: impl AsRef<[u8]>, + ) -> Result, Vec)>> { let prefix = prefix.as_ref(); let range_end = prefix_end(prefix); @@ -404,12 +443,11 @@ impl Client { .map(|kv| (kv.key, kv.value, kv.mod_revision as u64)) .collect(); let next_key = if more { - kvs.last() - .map(|(k, _, _)| { - let mut nk = k.clone(); - nk.push(0); - nk - }) + kvs.last().map(|(k, _, _)| { + let mut nk = k.clone(); + nk.push(0); + nk + }) } else { None }; @@ -451,12 +489,11 @@ impl Client { .map(|kv| (kv.key, kv.value, kv.mod_revision as u64)) .collect(); let next_key = if more { - kvs.last() - .map(|(k, _, _)| { - let mut nk = k.clone(); - nk.push(0); - nk - }) + kvs.last().map(|(k, _, _)| { + let mut nk = k.clone(); + nk.push(0); + nk + }) } else { None }; @@ -519,11 +556,7 @@ impl Client { .await?; if resp.succeeded { - let 
new_version = resp - .header - .as_ref() - .map(|h| h.revision as u64) - .unwrap_or(0); + let new_version = resp.header.as_ref().map(|h| h.revision as u64).unwrap_or(0); return Ok(CasOutcome { success: true, current_version: new_version, @@ -536,11 +569,9 @@ impl Client { .responses .into_iter() .filter_map(|op| match op.response { - Some(response_op::Response::ResponseRange(r)) => r - .kvs - .into_iter() - .next() - .map(|kv| kv.mod_revision as u64), + Some(response_op::Response::ResponseRange(r)) => { + r.kvs.into_iter().next().map(|kv| kv.mod_revision as u64) + } _ => None, }) .next() @@ -594,7 +625,12 @@ impl Client { /// /// # Returns /// The node ID of the added member - pub async fn member_add(&mut self, node_id: u64, peer_url: impl AsRef, is_learner: bool) -> Result { + pub async fn member_add( + &mut self, + node_id: u64, + peer_url: impl AsRef, + is_learner: bool, + ) -> Result { let peer_url = peer_url.as_ref().to_string(); let resp = self .with_cluster_retry(|mut cluster| { @@ -660,7 +696,9 @@ fn parse_endpoints(input: &str) -> Result> { .collect(); if endpoints.is_empty() { - return Err(ClientError::Connection("no Chainfire endpoints configured".to_string())); + return Err(ClientError::Connection( + "no Chainfire endpoints configured".to_string(), + )); } Ok(endpoints) @@ -674,7 +712,9 @@ fn normalize_endpoint(endpoint: &str) -> String { } } -async fn connect_endpoint(endpoint: &str) -> Result<(Channel, KvClient, ClusterClient)> { +async fn connect_endpoint( + endpoint: &str, +) -> Result<(Channel, KvClient, ClusterClient)> { let channel = Channel::from_shared(endpoint.to_string()) .map_err(|e| ClientError::Connection(e.to_string()))? .connect() @@ -693,7 +733,11 @@ fn retry_delay(attempt: usize) -> Duration { fn is_retryable_status(status: &tonic::Status) -> bool { matches!( status.code(), - Code::Unavailable | Code::DeadlineExceeded | Code::Internal | Code::Aborted | Code::FailedPrecondition + Code::Unavailable + | Code::DeadlineExceeded + | Code::Internal + | Code::Aborted + | Code::FailedPrecondition ) || retryable_message(status.message()) } @@ -734,8 +778,14 @@ mod tests { #[test] fn normalize_endpoint_adds_http_scheme() { - assert_eq!(normalize_endpoint("127.0.0.1:2379"), "http://127.0.0.1:2379"); - assert_eq!(normalize_endpoint("http://127.0.0.1:2379"), "http://127.0.0.1:2379"); + assert_eq!( + normalize_endpoint("127.0.0.1:2379"), + "http://127.0.0.1:2379" + ); + assert_eq!( + normalize_endpoint("http://127.0.0.1:2379"), + "http://127.0.0.1:2379" + ); } #[test] diff --git a/chainfire/chainfire-client/src/error.rs b/chainfire/chainfire-client/src/error.rs index 9329c3c..ff547e2 100644 --- a/chainfire/chainfire-client/src/error.rs +++ b/chainfire/chainfire-client/src/error.rs @@ -31,4 +31,8 @@ pub enum ClientError { /// Internal error #[error("Internal error: {0}")] Internal(String), + + /// Compare-and-set conflict + #[error("Conflict: {0}")] + Conflict(String), } diff --git a/deployer/crates/node-agent/src/agent.rs b/deployer/crates/node-agent/src/agent.rs index ae668d5..13f8319 100644 --- a/deployer/crates/node-agent/src/agent.rs +++ b/deployer/crates/node-agent/src/agent.rs @@ -5,7 +5,7 @@ use std::process::Stdio; use std::time::Duration; use anyhow::{Context, Result}; -use chainfire_client::Client; +use chainfire_client::{Client, ClientError}; use chrono::{DateTime, Utc}; use deployer_types::{ContainerSpec, HealthCheckSpec, ProcessSpec, ServiceInstanceSpec}; use serde::{Deserialize, Serialize}; @@ -254,7 +254,11 @@ impl Agent { } } - fn render_container_spec(&self, 
spec: &ContainerSpec, inst: &ServiceInstanceSpec) -> ContainerSpec { + fn render_container_spec( + &self, + spec: &ContainerSpec, + inst: &ServiceInstanceSpec, + ) -> ContainerSpec { let mut rendered = spec.clone(); rendered.image = self.render_template_value(&rendered.image, inst); rendered.command = rendered @@ -283,10 +287,7 @@ impl Agent { rendered } - fn desired_process_spec( - &self, - inst: &ServiceInstanceSpec, - ) -> Option { + fn desired_process_spec(&self, inst: &ServiceInstanceSpec) -> Option { match (&inst.container, &inst.process) { (Some(container), maybe_process) => { if maybe_process.is_some() { @@ -309,6 +310,93 @@ impl Agent { } } + fn apply_instance_health_fields( + inst_value: &mut Value, + started_at: &DateTime, + heartbeat_at: &DateTime, + health_status: &str, + ) { + let Some(obj) = inst_value.as_object_mut() else { + return; + }; + + let observed_at = Value::String(started_at.to_rfc3339()); + match obj.get_mut("observed_at") { + Some(slot) if slot.is_null() => *slot = observed_at, + Some(_) => {} + None => { + obj.insert("observed_at".to_string(), observed_at); + } + } + + obj.insert( + "state".to_string(), + Value::String(health_status.to_string()), + ); + obj.insert( + "last_heartbeat".to_string(), + Value::String(heartbeat_at.to_rfc3339()), + ); + } + + async fn persist_instance_health( + &self, + client: &mut Client, + key: &[u8], + inst: &ServiceInstanceSpec, + mut inst_value: Value, + mut mod_revision: u64, + started_at: &DateTime, + health_status: &str, + ) -> Result<()> { + for attempt in 0..3 { + let heartbeat_at = Utc::now(); + Self::apply_instance_health_fields( + &mut inst_value, + started_at, + &heartbeat_at, + health_status, + ); + + let updated = serde_json::to_vec(&inst_value)?; + match client.put_if_revision(key, &updated, mod_revision).await { + Ok(()) => return Ok(()), + Err(ClientError::Conflict(error)) if attempt < 2 => { + warn!( + service = %inst.service, + instance_id = %inst.instance_id, + mod_revision, + attempt = attempt + 1, + error = %error, + "instance health update raced with another writer; retrying with fresh revision" + ); + + let Some((latest_bytes, latest_revision)) = + client.get_with_revision(key).await? 
+ else { + warn!( + service = %inst.service, + instance_id = %inst.instance_id, + "instance disappeared while retrying health update" + ); + return Ok(()); + }; + + inst_value = serde_json::from_slice(&latest_bytes).with_context(|| { + format!( + "failed to parse refreshed instance JSON for {}", + inst.instance_id + ) + })?; + mod_revision = latest_revision; + } + Err(error) => return Err(error.into()), + } + } + + Ok(()) + } + /// ローカルファイル (/etc/photoncloud/instances.json) から ServiceInstance 定義を読み、 /// Chainfire 上の `photoncloud/clusters/{cluster_id}/instances/{service}/{instance_id}` に upsert する。 async fn sync_local_instances(&self, client: &mut Client) -> Result<()> { @@ -348,7 +436,9 @@ impl Agent { { for preserve_key in ["state", "last_heartbeat", "observed_at"] { if let Some(value) = existing_obj.get(preserve_key) { - desired_obj.entry(preserve_key.to_string()).or_insert(value.clone()); + desired_obj + .entry(preserve_key.to_string()) + .or_insert(value.clone()); } } } @@ -398,7 +488,6 @@ impl Agent { // Desired Stateに基づいてプロセスを管理 for (service, instance_id, proc_spec) in desired_instances { - if let Some(process) = self.process_manager.get_mut(&service, &instance_id) { if process.spec != proc_spec { process.spec = proc_spec; @@ -526,7 +615,7 @@ impl Agent { let mut seen = HashSet::new(); for (key, value, mod_revision) in kvs { - let mut inst_value: Value = match serde_json::from_slice(&value) { + let inst_value: Value = match serde_json::from_slice(&value) { Ok(v) => v, Err(e) => { warn!(error = %e, "failed to parse instance json"); @@ -582,29 +671,25 @@ impl Agent { "healthy".to_string() // デフォルトはhealthy }; - // Chainfire上のServiceInstanceに状態を反映 - if let Some(obj) = inst_value.as_object_mut() { - obj.entry("observed_at".to_string()) - .or_insert_with(|| Value::String(started_at.to_rfc3339())); - obj.insert( - "state".to_string(), - Value::String(health_status.clone()), + if let Err(e) = self + .persist_instance_health( + client, + &key, + &inst, + inst_value, + mod_revision, + &started_at, + &health_status, + ) + .await + { + warn!( + service = %inst.service, + instance_id = %inst.instance_id, + mod_revision, + error = ?e, + "failed to update instance health status" ); - obj.insert( - "last_heartbeat".to_string(), - Value::String(now.to_rfc3339()), - ); - } - - let updated = serde_json::to_vec(&inst_value)?; - if let Err(e) = client.put_if_revision(&key, &updated, mod_revision).await { - warn!( - service = %inst.service, - instance_id = %inst.instance_id, - mod_revision, - error = ?e, - "failed to update instance health status" - ); } info!( @@ -615,8 +700,7 @@ impl Agent { ); } - self.next_health_checks - .retain(|key, _| seen.contains(key)); + self.next_health_checks.retain(|key, _| seen.contains(key)); Ok(()) } @@ -732,7 +816,10 @@ mod tests { assert_eq!(rendered.args[2], "18080"); assert_eq!(rendered.args[4], "127.0.0.2"); assert_eq!(rendered.working_dir.as_deref(), Some("/srv/api")); - assert_eq!(rendered.env.get("INSTANCE").map(String::as_str), Some("api-node01")); + assert_eq!( + rendered.env.get("INSTANCE").map(String::as_str), + Some("api-node01") + ); } #[test] @@ -779,4 +866,59 @@ mod tests { .insert(key, Utc::now() - chrono::Duration::seconds(1)); assert!(agent.health_check_due(&instance, &health_check)); } + + #[test] + fn test_apply_instance_health_fields_replaces_nulls_and_preserves_observed_at() { + let started_at = DateTime::parse_from_rfc3339("2026-03-31T03:00:00Z") + .unwrap() + .with_timezone(&Utc); + let heartbeat_at = 
DateTime::parse_from_rfc3339("2026-03-31T03:00:05Z") + .unwrap() + .with_timezone(&Utc); + let original_observed_at = "2026-03-31T02:59:59Z"; + + let mut null_observed = serde_json::json!({ + "observed_at": null, + "state": null, + "last_heartbeat": null + }); + Agent::apply_instance_health_fields( + &mut null_observed, + &started_at, + &heartbeat_at, + "healthy", + ); + assert_eq!( + null_observed.get("observed_at").and_then(Value::as_str), + Some("2026-03-31T03:00:00+00:00") + ); + assert_eq!( + null_observed.get("state").and_then(Value::as_str), + Some("healthy") + ); + assert_eq!( + null_observed.get("last_heartbeat").and_then(Value::as_str), + Some("2026-03-31T03:00:05+00:00") + ); + + let mut existing_observed = serde_json::json!({ + "observed_at": original_observed_at, + "state": "starting", + "last_heartbeat": null + }); + Agent::apply_instance_health_fields( + &mut existing_observed, + &started_at, + &heartbeat_at, + "healthy", + ); + assert_eq!( + existing_observed.get("observed_at").and_then(Value::as_str), + Some(original_observed_at) + ); + assert_eq!( + existing_observed.get("state").and_then(Value::as_str), + Some("healthy") + ); + } } diff --git a/flaredb/crates/flaredb-client/src/client.rs b/flaredb/crates/flaredb-client/src/client.rs index 9b11273..fdb48b2 100644 --- a/flaredb/crates/flaredb-client/src/client.rs +++ b/flaredb/crates/flaredb-client/src/client.rs @@ -7,10 +7,10 @@ use flaredb_proto::kvrpc::{ RawScanRequest, }; use flaredb_proto::pdpb::Store; +use serde::Deserialize; use std::collections::HashMap; use std::sync::Arc; use std::time::{Instant, SystemTime, UNIX_EPOCH}; -use serde::Deserialize; use tokio::sync::Mutex; use tonic::transport::Channel; @@ -61,6 +61,74 @@ struct ChainfireRouteSnapshot { fetched_at: Instant, } +#[derive(Debug, Clone)] +struct ResolvedRoute { + region: Region, + leader: Store, + candidate_addrs: Vec, +} + +fn push_unique_addr(addrs: &mut Vec, addr: &str) { + if !addrs.iter().any(|existing| existing == addr) { + addrs.push(addr.to_string()); + } +} + +fn resolve_chainfire_route_from_snapshot( + key: &[u8], + snapshot: &ChainfireRouteSnapshot, +) -> Result { + let region = snapshot + .regions + .iter() + .find(|region| { + let start_ok = region.start_key.is_empty() || key >= region.start_key.as_slice(); + let end_ok = region.end_key.is_empty() || key < region.end_key.as_slice(); + start_ok && end_ok + }) + .cloned() + .ok_or_else(|| tonic::Status::not_found("region not found"))?; + + let mut candidate_addrs = Vec::new(); + let mut selected_store = snapshot.stores.get(®ion.leader_id).cloned(); + + if let Some(store) = &selected_store { + push_unique_addr(&mut candidate_addrs, &store.addr); + } + + for peer_id in ®ion.peers { + if let Some(store) = snapshot.stores.get(peer_id) { + if selected_store.is_none() { + selected_store = Some(store.clone()); + } + push_unique_addr(&mut candidate_addrs, &store.addr); + } + } + + let selected_store = selected_store + .ok_or_else(|| tonic::Status::not_found("region peer store not found"))?; + if candidate_addrs.is_empty() { + return Err(tonic::Status::not_found( + "region has no candidate store addresses", + )); + } + + Ok(ResolvedRoute { + region: Region { + id: region.id, + start_key: region.start_key, + end_key: region.end_key, + peers: region.peers, + leader_id: region.leader_id, + }, + leader: Store { + id: selected_store.id, + addr: selected_store.addr, + }, + candidate_addrs, + }) +} + impl RdbClient { const ROUTE_RETRY_LIMIT: usize = 12; const ROUTE_RETRY_BASE_DELAY_MS: u64 = 100; @@ -116,11 
+184,9 @@ impl RdbClient { .await; clients = Some(match probe { - Err(status) if status.code() == tonic::Code::Unimplemented => ( - None, - None, - Some(ChainfireKvClient::new(pd_channel)), - ), + Err(status) if status.code() == tonic::Code::Unimplemented => { + (None, None, Some(ChainfireKvClient::new(pd_channel))) + } _ => ( Some(TsoClient::new(pd_channel.clone())), Some(PdClient::new(pd_channel)), @@ -134,13 +200,11 @@ impl RdbClient { } else if let Some(error) = last_error { return Err(error); } else { - return Err( - Channel::from_shared("http://127.0.0.1:1".to_string()) - .unwrap() - .connect() - .await - .expect_err("unreachable fallback endpoint should fail to connect"), - ); + return Err(Channel::from_shared("http://127.0.0.1:1".to_string()) + .unwrap() + .connect() + .await + .expect_err("unreachable fallback endpoint should fail to connect")); } }; @@ -187,13 +251,11 @@ impl RdbClient { } else if let Some(error) = last_error { return Err(error); } else { - return Err( - Channel::from_shared("http://127.0.0.1:1".to_string()) - .unwrap() - .connect() - .await - .expect_err("unreachable fallback endpoint should fail to connect"), - ); + return Err(Channel::from_shared("http://127.0.0.1:1".to_string()) + .unwrap() + .connect() + .await + .expect_err("unreachable fallback endpoint should fail to connect")); }; let channel = channel.expect("direct connect should produce a channel when selected"); @@ -219,7 +281,13 @@ impl RdbClient { } if let Some(chainfire_kv_client) = &self.chainfire_kv_client { - return self.resolve_addr_via_chainfire(key, chainfire_kv_client.clone()).await; + let route = self + .resolve_route_via_chainfire(key, chainfire_kv_client.clone(), false) + .await?; + self.region_cache + .update(route.region.clone(), route.leader.clone()) + .await; + return Ok(route.leader.addr); } if let Some(pd_client) = &self.pd_client { @@ -244,7 +312,13 @@ impl RdbClient { self.invalidate_chainfire_route_cache().await; if let Some(chainfire_kv_client) = &self.chainfire_kv_client { - return self.resolve_addr_via_chainfire(key, chainfire_kv_client.clone()).await; + let route = self + .resolve_route_via_chainfire(key, chainfire_kv_client.clone(), true) + .await?; + self.region_cache + .update(route.region.clone(), route.leader.clone()) + .await; + return Ok(route.leader.addr); } if let Some(pd_client) = &self.pd_client { @@ -310,53 +384,21 @@ impl RdbClient { Ok(snapshot) } - fn resolve_addr_from_chainfire_snapshot( - &self, - key: &[u8], - snapshot: &ChainfireRouteSnapshot, - ) -> Result<(Region, Store), tonic::Status> { - let region = snapshot - .regions - .iter() - .find(|region| { - let start_ok = region.start_key.is_empty() || key >= region.start_key.as_slice(); - let end_ok = region.end_key.is_empty() || key < region.end_key.as_slice(); - start_ok && end_ok - }) - .cloned() - .ok_or_else(|| tonic::Status::not_found("region not found"))?; - - let leader = snapshot - .stores - .get(®ion.leader_id) - .cloned() - .ok_or_else(|| tonic::Status::not_found("leader store not found"))?; - - Ok(( - Region { - id: region.id, - start_key: region.start_key, - end_key: region.end_key, - peers: region.peers, - leader_id: region.leader_id, - }, - Store { - id: leader.id, - addr: leader.addr, - }, - )) - } - async fn with_routed_addr(&self, key: &[u8], mut op: F) -> Result where F: FnMut(String) -> Fut, Fut: Future>, { - let mut addr = self.resolve_addr(key).await?; + let mut candidate_addrs = self.resolve_route_candidates(key, false).await?; + let mut candidate_idx = 0usize; let mut refreshed = 
false; let mut last_status = None; for attempt in 0..Self::ROUTE_RETRY_LIMIT { + let addr = candidate_addrs + .get(candidate_idx) + .cloned() + .ok_or_else(|| tonic::Status::internal("routing candidate list exhausted"))?; match tokio::time::timeout(Self::ROUTED_RPC_TIMEOUT, op(addr.clone())).await { Err(_) => { Self::evict_channel_from_map(&self.channels, &addr).await; @@ -366,10 +408,19 @@ impl RdbClient { Self::ROUTED_RPC_TIMEOUT.as_millis() )); + if candidate_idx + 1 < candidate_addrs.len() { + candidate_idx += 1; + last_status = Some(status); + tokio::time::sleep(Self::retry_delay(attempt)).await; + continue; + } + if !refreshed && self.direct_addr.is_none() { refreshed = true; - if let Ok(fresh_addr) = self.resolve_addr_uncached(key).await { - addr = fresh_addr; + if let Ok(fresh_candidates) = self.resolve_route_candidates(key, true).await + { + candidate_addrs = fresh_candidates; + candidate_idx = 0; last_status = Some(status); tokio::time::sleep(Self::retry_delay(attempt)).await; continue; @@ -391,20 +442,33 @@ impl RdbClient { .override_store_addr(key, redirect_addr.clone()) .await; if redirect_addr != addr { - addr = redirect_addr; + candidate_addrs.retain(|candidate| candidate != &redirect_addr); + candidate_addrs.insert(0, redirect_addr); + candidate_idx = 0; last_status = Some(status); tokio::time::sleep(Self::retry_delay(attempt)).await; continue; } } + if (transport_error || Self::is_retryable_route_error(&status)) + && candidate_idx + 1 < candidate_addrs.len() + { + candidate_idx += 1; + last_status = Some(status); + tokio::time::sleep(Self::retry_delay(attempt)).await; + continue; + } + if !refreshed && self.direct_addr.is_none() && Self::is_retryable_route_error(&status) { refreshed = true; - if let Ok(fresh_addr) = self.resolve_addr_uncached(key).await { - addr = fresh_addr; + if let Ok(fresh_candidates) = self.resolve_route_candidates(key, true).await + { + candidate_addrs = fresh_candidates; + candidate_idx = 0; last_status = Some(status); tokio::time::sleep(Self::retry_delay(attempt)).await; continue; @@ -425,13 +489,54 @@ impl RdbClient { return Err(status); } - Ok(Ok(value)) => return Ok(value), + Ok(Ok(value)) => { + if candidate_idx > 0 { + self.region_cache.override_store_addr(key, addr).await; + } + return Ok(value); + } } } Err(last_status.unwrap_or_else(|| tonic::Status::internal("routing retry exhausted"))) } + async fn resolve_route_candidates( + &self, + key: &[u8], + force_refresh: bool, + ) -> Result, tonic::Status> { + if let Some(addr) = &self.direct_addr { + return Ok(vec![addr.clone()]); + } + + if !force_refresh { + if let Some(addr) = self.region_cache.get_store_addr(key).await { + return Ok(vec![addr]); + } + } else { + self.region_cache.clear().await; + self.invalidate_chainfire_route_cache().await; + } + + if let Some(chainfire_kv_client) = &self.chainfire_kv_client { + let route = self + .resolve_route_via_chainfire(key, chainfire_kv_client.clone(), force_refresh) + .await?; + self.region_cache + .update(route.region.clone(), route.leader.clone()) + .await; + return Ok(route.candidate_addrs); + } + + let addr = if force_refresh { + self.resolve_addr_uncached(key).await? + } else { + self.resolve_addr(key).await? 
+ }; + Ok(vec![addr]) + } + fn is_retryable_route_error(status: &tonic::Status) -> bool { if !matches!( status.code(), @@ -454,9 +559,7 @@ impl RdbClient { } fn retry_delay(attempt: usize) -> Duration { - Duration::from_millis( - Self::ROUTE_RETRY_BASE_DELAY_MS.saturating_mul((attempt as u64) + 1), - ) + Duration::from_millis(Self::ROUTE_RETRY_BASE_DELAY_MS.saturating_mul((attempt as u64) + 1)) } fn is_transport_error(status: &tonic::Status) -> bool { @@ -691,7 +794,11 @@ impl RdbClient { .into_iter() .map(|kv| (kv.key, kv.value, kv.version)) .collect(); - let next = if resp.has_more { Some(resp.next_key) } else { None }; + let next = if resp.has_more { + Some(resp.next_key) + } else { + None + }; Ok((entries, next)) } }) @@ -727,20 +834,25 @@ impl RdbClient { .await } - async fn resolve_addr_via_chainfire( + async fn resolve_route_via_chainfire( &self, key: &[u8], kv_client: ChainfireKvClient, - ) -> Result { - for force_refresh in [false, true] { + force_refresh: bool, + ) -> Result { + if force_refresh { let snapshot = self - .chainfire_route_snapshot(kv_client.clone(), force_refresh) + .chainfire_route_snapshot(kv_client, true) .await?; - if let Ok((region, leader)) = - self.resolve_addr_from_chainfire_snapshot(key, &snapshot) - { - self.region_cache.update(region, leader.clone()).await; - return Ok(leader.addr); + return resolve_chainfire_route_from_snapshot(key, &snapshot); + } + + for refresh in [false, true] { + let snapshot = self + .chainfire_route_snapshot(kv_client.clone(), refresh) + .await?; + if let Ok(route) = resolve_chainfire_route_from_snapshot(key, &snapshot) { + return Ok(route); } } @@ -833,7 +945,13 @@ async fn list_chainfire_regions( #[cfg(test)] mod tests { - use super::{RdbClient, normalize_transport_addr, parse_transport_endpoints}; + use super::{ + normalize_transport_addr, parse_transport_endpoints, + resolve_chainfire_route_from_snapshot, ChainfireRegionInfo, ChainfireRouteSnapshot, + ChainfireStoreInfo, RdbClient, + }; + use std::collections::HashMap; + use std::time::Instant; #[test] fn unknown_transport_errors_are_treated_as_retryable_routes() { @@ -864,4 +982,77 @@ mod tests { "10.0.0.1:2479".to_string() ); } + + #[test] + fn chainfire_routes_try_leader_then_other_peers() { + let snapshot = ChainfireRouteSnapshot { + stores: HashMap::from([ + ( + 1, + ChainfireStoreInfo { + id: 1, + addr: "10.0.0.1:2479".to_string(), + }, + ), + ( + 2, + ChainfireStoreInfo { + id: 2, + addr: "10.0.0.2:2479".to_string(), + }, + ), + ( + 3, + ChainfireStoreInfo { + id: 3, + addr: "10.0.0.3:2479".to_string(), + }, + ), + ]), + regions: vec![ChainfireRegionInfo { + id: 1, + start_key: Vec::new(), + end_key: Vec::new(), + peers: vec![1, 2, 3], + leader_id: 2, + }], + fetched_at: Instant::now(), + }; + + let route = resolve_chainfire_route_from_snapshot(b"tenant/key", &snapshot).unwrap(); + assert_eq!(route.leader.id, 2); + assert_eq!( + route.candidate_addrs, + vec![ + "10.0.0.2:2479".to_string(), + "10.0.0.1:2479".to_string(), + "10.0.0.3:2479".to_string(), + ] + ); + } + + #[test] + fn chainfire_routes_fall_back_when_reported_leader_store_is_missing() { + let snapshot = ChainfireRouteSnapshot { + stores: HashMap::from([( + 1, + ChainfireStoreInfo { + id: 1, + addr: "10.0.0.1:2479".to_string(), + }, + )]), + regions: vec![ChainfireRegionInfo { + id: 1, + start_key: Vec::new(), + end_key: Vec::new(), + peers: vec![1, 2], + leader_id: 2, + }], + fetched_at: Instant::now(), + }; + + let route = resolve_chainfire_route_from_snapshot(b"tenant/key", &snapshot).unwrap(); + 
assert_eq!(route.leader.id, 1); + assert_eq!(route.candidate_addrs, vec!["10.0.0.1:2479".to_string()]); + } } diff --git a/flaredb/crates/flaredb-server/src/main.rs b/flaredb/crates/flaredb-server/src/main.rs index 6264357..dcddcf0 100644 --- a/flaredb/crates/flaredb-server/src/main.rs +++ b/flaredb/crates/flaredb-server/src/main.rs @@ -1,3 +1,4 @@ +use anyhow::Result; use clap::Parser; use flaredb_proto::kvrpc::kv_cas_server::KvCasServer; use flaredb_proto::kvrpc::kv_raw_server::KvRawServer; @@ -17,8 +18,7 @@ use tokio::time::{sleep, Duration}; use tonic::transport::{Certificate, Channel, Identity, Server, ServerTlsConfig}; use tonic_health::server::health_reporter; use tracing::{info, warn}; // Import warn -use tracing_subscriber::EnvFilter; -use anyhow::Result; // Import anyhow +use tracing_subscriber::EnvFilter; // Import anyhow mod heartbeat; mod merkle; @@ -273,12 +273,13 @@ async fn main() -> Result<(), Box> { let store = Arc::new(store::Store::new( server_config.store_id, engine.clone(), - server_config.clone(), // Pass server_config + server_config.clone(), // Pass server_config namespace_manager.clone(), // Pass namespace manager peer_addrs.clone(), )); - let service = service::KvServiceImpl::new(engine.clone(), namespace_manager.clone(), store.clone()); + let service = + service::KvServiceImpl::new(engine.clone(), namespace_manager.clone(), store.clone()); let raft_service = raft_service::RaftServiceImpl::new(store.clone(), server_config.store_id); let pd_endpoints = server_config.resolved_pd_endpoints(); @@ -389,6 +390,7 @@ async fn main() -> Result<(), Box> { let store_id = server_config.store_id; let server_addr_string = server_config.addr.to_string(); tokio::spawn(async move { + let mut last_reported_leaders: HashMap = HashMap::new(); let client = Arc::new(Mutex::new( ChainfirePdClient::connect_any(&pd_endpoints_for_task) .await @@ -399,14 +401,14 @@ async fn main() -> Result<(), Box> { let mut guard = client.lock().await; if let Some(ref mut c) = *guard { // Send heartbeat - let heartbeat_ok = - match c.heartbeat(store_id, server_addr_string.clone()).await { - Ok(_) => true, - Err(e) => { - warn!("Heartbeat failed: {}", e); - false - } - }; + let heartbeat_ok = match c.heartbeat(store_id, server_addr_string.clone()).await + { + Ok(_) => true, + Err(e) => { + warn!("Heartbeat failed: {}", e); + false + } + }; // If heartbeat failed, try to reconnect on next cycle if !heartbeat_ok { @@ -417,15 +419,38 @@ async fn main() -> Result<(), Box> { // Report observed leader status so routing metadata converges // even when followers are the first nodes to notice a leadership change. 
let region_ids = store_clone.list_region_ids().await; + let mut observed_regions = HashMap::new(); for region_id in region_ids { if let Some(node) = store_clone.get_raft_node(region_id).await { if let Some(observed_leader) = node.leader_id().await { - if let Err(e) = c.report_leader(region_id, observed_leader).await { - warn!("Report leader failed: {}", e); + observed_regions.insert(region_id, observed_leader); + let previous = last_reported_leaders.get(®ion_id).copied(); + if previous != Some(observed_leader) { + info!( + region_id, + previous_leader = ?previous, + observed_leader, + "Reporting FlareDB region leader to PD" + ); + } + match c.report_leader(region_id, observed_leader).await { + Ok(_) => { + last_reported_leaders.insert(region_id, observed_leader); + } + Err(e) => { + warn!( + region_id, + observed_leader, + error = %e, + "Report leader failed" + ); + } } } } } + last_reported_leaders + .retain(|region_id, _| observed_regions.contains_key(region_id)); // Refresh regions from PD (from cache, updated via watch) let regions = c.list_regions().await; @@ -577,12 +602,9 @@ async fn main() -> Result<(), Box> { let tls = if tls_config.require_client_cert { info!("mTLS enabled, requiring client certificates"); - let ca_cert = tokio::fs::read( - tls_config - .ca_file - .as_ref() - .ok_or_else(|| anyhow::anyhow!("ca_file required when require_client_cert=true"))?, - ) + let ca_cert = tokio::fs::read(tls_config.ca_file.as_ref().ok_or_else(|| { + anyhow::anyhow!("ca_file required when require_client_cert=true") + })?) .await .map_err(|e| anyhow::anyhow!("Failed to read CA file: {}", e))?; let ca = Certificate::from_pem(ca_cert); @@ -650,6 +672,8 @@ async fn main() -> Result<(), Box> { fn init_logging(level: &str) { tracing_subscriber::fmt() - .with_env_filter(EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(level))) + .with_env_filter( + EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(level)), + ) .init(); } diff --git a/nix/test-cluster/run-cluster.sh b/nix/test-cluster/run-cluster.sh index 2fbc53c..91a6db9 100755 --- a/nix/test-cluster/run-cluster.sh +++ b/nix/test-cluster/run-cluster.sh @@ -2633,6 +2633,91 @@ try_get_volume_json() { 127.0.0.1:${vm_port} plasmavmc.v1.VolumeService/GetVolume } +start_plasmavmc_vm_watch() { + local node="$1" + local proto_root="$2" + local token="$3" + local org_id="$4" + local project_id="$5" + local vm_id="$6" + local output_path="$7" + + ssh_node_script "${node}" "${proto_root}" "${token}" "${org_id}" "${project_id}" "${vm_id}" "${output_path}" <<'EOS' +set -euo pipefail +proto_root="$1" +token="$2" +org_id="$3" +project_id="$4" +vm_id="$5" +output_path="$6" +rm -f "${output_path}" "${output_path}.pid" "${output_path}.stderr" +nohup timeout 600 grpcurl -plaintext \ + -H "authorization: Bearer ${token}" \ + -import-path "${proto_root}/plasmavmc" \ + -proto "${proto_root}/plasmavmc/plasmavmc.proto" \ + -d "$(jq -cn --arg org "${org_id}" --arg project "${project_id}" --arg vm "${vm_id}" '{orgId:$org, projectId:$project, vmId:$vm}')" \ + 127.0.0.1:50082 plasmavmc.v1.VmService/WatchVm \ + >"${output_path}" 2>"${output_path}.stderr" & +echo $! >"${output_path}.pid" +EOS +} + +wait_for_plasmavmc_vm_watch_completion() { + local node="$1" + local output_path="$2" + local timeout="${3:-60}" + local deadline=$((SECONDS + timeout)) + + while true; do + if ssh_node_script "${node}" "${output_path}" <<'EOS' +set -euo pipefail +output_path="$1" +if [[ ! 
-f "${output_path}.pid" ]]; then + exit 0 +fi +pid="$(cat "${output_path}.pid")" +if kill -0 "${pid}" >/dev/null 2>&1; then + exit 1 +fi +EOS + then + return 0 + fi + if (( SECONDS >= deadline )); then + ssh_node "${node}" "test -f ${output_path}.stderr && cat ${output_path}.stderr || true" >&2 || true + ssh_node "${node}" "test -f ${output_path} && cat ${output_path} || true" >&2 || true + die "timed out waiting for PlasmaVMC watch stream to exit" + fi + sleep 1 + done +} + +assert_plasmavmc_vm_watch_events() { + local node="$1" + local output_path="$2" + local vm_id="$3" + + ssh_node_script "${node}" "${output_path}" "${vm_id}" <<'EOS' +set -euo pipefail +output_path="$1" +vm_id="$2" +[[ -s "${output_path}" ]] || { + echo "PlasmaVMC watch output is empty" >&2 + test -f "${output_path}.stderr" && cat "${output_path}.stderr" >&2 || true + exit 1 +} +jq -s --arg vm "${vm_id}" ' + any(.vmId == $vm and .eventType == "VM_EVENT_TYPE_STATE_CHANGED" and .vm.state == "VM_STATE_RUNNING") and + any(.vmId == $vm and .eventType == "VM_EVENT_TYPE_STATE_CHANGED" and .vm.state == "VM_STATE_STOPPED") and + any(.vmId == $vm and .eventType == "VM_EVENT_TYPE_DELETED") +' "${output_path}" >/dev/null || { + cat "${output_path}" >&2 + test -f "${output_path}.stderr" && cat "${output_path}.stderr" >&2 || true + exit 1 +} +EOS +} + wait_requested() { local nodes mapfile -t nodes < <(all_or_requested_nodes "$@") @@ -3707,6 +3792,7 @@ validate_vm_storage_flow() { node04_coronafs_tunnel="$(start_ssh_tunnel node04 25088 "${CORONAFS_API_PORT}")" node05_coronafs_tunnel="$(start_ssh_tunnel node05 35088 "${CORONAFS_API_PORT}")" local image_source_path="" + local vm_watch_output="" local node01_proto_root="/var/lib/plasmavmc/test-protos" local vpc_id="" subnet_id="" port_id="" port_ip="" port_mac="" cleanup_vm_storage_flow() { @@ -3737,6 +3823,9 @@ validate_vm_storage_flow() { if [[ -n "${image_source_path}" && "${image_source_path}" != /nix/store/* ]]; then ssh_node node01 "rm -f ${image_source_path}" >/dev/null 2>&1 || true fi + if [[ -n "${vm_watch_output}" ]]; then + ssh_node node01 "rm -f ${vm_watch_output} ${vm_watch_output}.pid ${vm_watch_output}.stderr" >/dev/null 2>&1 || true + fi stop_ssh_tunnel node05 "${node05_coronafs_tunnel}" stop_ssh_tunnel node04 "${node04_coronafs_tunnel}" stop_ssh_tunnel node01 "${coronafs_tunnel}" @@ -3993,6 +4082,9 @@ EOS )" vm_id="$(printf '%s' "${create_response}" | jq -r '.id')" [[ -n "${vm_id}" && "${vm_id}" != "null" ]] || die "failed to create VM through PlasmaVMC" + vm_watch_output="/tmp/plasmavmc-watch-${vm_id}.json" + start_plasmavmc_vm_watch node01 "${node01_proto_root}" "${token}" "${org_id}" "${project_id}" "${vm_id}" "${vm_watch_output}" + sleep 2 local get_vm_json get_vm_json="$(jq -cn --arg org "${org_id}" --arg project "${project_id}" --arg vm "${vm_id}" '{orgId:$org, projectId:$project, vmId:$vm}')" @@ -4420,6 +4512,8 @@ EOS fi sleep 2 done + wait_for_plasmavmc_vm_watch_completion node01 "${vm_watch_output}" 60 + assert_plasmavmc_vm_watch_events node01 "${vm_watch_output}" "${vm_id}" wait_for_prismnet_port_detachment "${token}" "${org_id}" "${project_id}" "${subnet_id}" "${port_id}" >/dev/null ssh_node "${node_id}" "bash -lc '[[ ! -d $(printf '%q' "$(vm_runtime_dir_path "${vm_id}")") ]]'" diff --git a/plasmavmc/crates/plasmavmc-server/src/main.rs b/plasmavmc/crates/plasmavmc-server/src/main.rs index 9c8cdcb..1adbedd 100644 --- a/plasmavmc/crates/plasmavmc-server/src/main.rs +++ b/plasmavmc/crates/plasmavmc-server/src/main.rs @@ -1,32 +1,32 @@ //! 
PlasmaVMC control plane server binary use clap::Parser; +use iam_service_auth::AuthService; use metrics_exporter_prometheus::PrometheusBuilder; use plasmavmc_api::proto::image_service_server::ImageServiceServer; -use plasmavmc_api::proto::node_service_server::NodeServiceServer; use plasmavmc_api::proto::node_service_client::NodeServiceClient; -use plasmavmc_api::proto::volume_service_server::VolumeServiceServer; +use plasmavmc_api::proto::node_service_server::NodeServiceServer; use plasmavmc_api::proto::vm_service_server::VmServiceServer; +use plasmavmc_api::proto::volume_service_server::VolumeServiceServer; use plasmavmc_api::proto::{ HeartbeatNodeRequest, HypervisorType as ProtoHypervisorType, NodeCapacity, NodeState as ProtoNodeState, VolumeDriverKind as ProtoVolumeDriverKind, }; +use plasmavmc_firecracker::FireCrackerBackend; use plasmavmc_hypervisor::HypervisorRegistry; use plasmavmc_kvm::KvmBackend; -use plasmavmc_firecracker::FireCrackerBackend; -use iam_service_auth::AuthService; use plasmavmc_server::config::ServerConfig; -use plasmavmc_server::VmServiceImpl; use plasmavmc_server::watcher::{StateSynchronizer, StateWatcher, WatcherConfig}; +use plasmavmc_server::VmServiceImpl; use std::net::SocketAddr; use std::path::PathBuf; use std::sync::Arc; -use tonic::transport::{Certificate, Endpoint, Identity, Server, ServerTlsConfig}; -use tonic_health::server::health_reporter; -use tonic::{Request, Status}; -use tracing_subscriber::EnvFilter; use std::time::Duration; use std::{collections::HashMap, fs}; +use tonic::transport::{Certificate, Endpoint, Identity, Server, ServerTlsConfig}; +use tonic::{Request, Status}; +use tonic_health::server::health_reporter; +use tracing_subscriber::EnvFilter; /// PlasmaVMC control plane server #[derive(Parser, Debug)] @@ -175,7 +175,10 @@ async fn main() -> Result<(), Box> { let contents = tokio::fs::read_to_string(&args.config).await?; toml::from_str(&contents)? 
} else { - tracing::info!("Config file not found: {}, using defaults", args.config.display()); + tracing::info!( + "Config file not found: {}, using defaults", + args.config.display() + ); ServerConfig::default() }; @@ -246,13 +249,13 @@ async fn main() -> Result<(), Box> { tracing::debug!("FireCracker backend not available (missing kernel/rootfs paths)"); } - tracing::info!( - "Registered hypervisors: {:?}", - registry.available() - ); + tracing::info!("Registered hypervisors: {:?}", registry.available()); // Initialize IAM authentication service - tracing::info!("Connecting to IAM server at {}", config.auth.iam_server_addr); + tracing::info!( + "Connecting to IAM server at {}", + config.auth.iam_server_addr + ); let auth_service = AuthService::new(&config.auth.iam_server_addr) .await .map_err(|e| format!("Failed to connect to IAM server: {}", e))?; @@ -278,8 +281,12 @@ async fn main() -> Result<(), Box> { // Create services let vm_service = Arc::new( - VmServiceImpl::new(registry, auth_service.clone(), config.auth.iam_server_addr.clone()) - .await?, + VmServiceImpl::new( + registry, + auth_service.clone(), + config.auth.iam_server_addr.clone(), + ) + .await?, ); // Optional: start state watcher for multi-instance HA sync @@ -288,7 +295,7 @@ async fn main() -> Result<(), Box> { .unwrap_or(false) { let config = WatcherConfig::default(); - let (watcher, rx) = StateWatcher::new(config); + let (watcher, rx) = StateWatcher::new(vm_service.store(), config); let synchronizer = StateSynchronizer::new(vm_service.clone()); tokio::spawn(async move { if let Err(e) = watcher.start().await { @@ -307,7 +314,9 @@ async fn main() -> Result<(), Box> { .and_then(|v| v.parse::().ok()) { if secs > 0 { - vm_service.clone().start_health_monitor(Duration::from_secs(secs)); + vm_service + .clone() + .start_health_monitor(Duration::from_secs(secs)); } } @@ -321,12 +330,10 @@ async fn main() -> Result<(), Box> { .ok() .and_then(|v| v.parse::().ok()) .unwrap_or(60); - vm_service - .clone() - .start_node_health_monitor( - Duration::from_secs(interval_secs), - Duration::from_secs(timeout_secs), - ); + vm_service.clone().start_node_health_monitor( + Duration::from_secs(interval_secs), + Duration::from_secs(timeout_secs), + ); } } diff --git a/plasmavmc/crates/plasmavmc-server/src/storage.rs b/plasmavmc/crates/plasmavmc-server/src/storage.rs index 82c9afc..a091129 100644 --- a/plasmavmc/crates/plasmavmc-server/src/storage.rs +++ b/plasmavmc/crates/plasmavmc-server/src/storage.rs @@ -68,6 +68,9 @@ pub trait VmStore: Send + Sync { /// List all VMs for a tenant async fn list_vms(&self, org_id: &str, project_id: &str) -> StorageResult>; + /// List all VMs across every tenant. + async fn list_all_vms(&self) -> StorageResult>; + /// Save a VM handle async fn save_handle( &self, @@ -431,6 +434,17 @@ impl VmStore for FlareDBStore { Ok(vms) } + async fn list_all_vms(&self) -> StorageResult> { + let mut vms = Vec::new(); + for value in self.cas_scan_values("/plasmavmc/vms/").await? { + if let Ok(vm) = serde_json::from_slice::(&value) { + vms.push(vm); + } + } + + Ok(vms) + } + async fn save_handle( &self, org_id: &str, @@ -580,7 +594,9 @@ impl VmStore for FlareDBStore { client .cas(key.as_bytes().to_vec(), value, expected_version) .await - .map_err(|e| StorageError::FlareDB(format!("FlareDB CAS volume save failed: {}", e)))? + .map_err(|e| { + StorageError::FlareDB(format!("FlareDB CAS volume save failed: {}", e)) + })? 
}; Ok(success) } @@ -714,6 +730,10 @@ impl VmStore for FileStore { .collect()) } + async fn list_all_vms(&self) -> StorageResult> { + Ok(self.load_state().unwrap_or_default().vms) + } + async fn save_handle( &self, _org_id: &str, diff --git a/plasmavmc/crates/plasmavmc-server/src/vm_service.rs b/plasmavmc/crates/plasmavmc-server/src/vm_service.rs index a8d7b13..9853b87 100644 --- a/plasmavmc/crates/plasmavmc-server/src/vm_service.rs +++ b/plasmavmc/crates/plasmavmc-server/src/vm_service.rs @@ -30,10 +30,11 @@ use plasmavmc_api::proto::{ NodeState as ProtoNodeState, OsType as ProtoOsType, PrepareVmMigrationRequest, RebootVmRequest, RecoverVmRequest, RegisterExternalVolumeRequest, ResetVmRequest, ResizeVolumeRequest, StartVmRequest, StopVmRequest, UncordonNodeRequest, UpdateImageRequest, UpdateVmRequest, - VirtualMachine, Visibility as ProtoVisibility, VmEvent, VmSpec as ProtoVmSpec, - VmState as ProtoVmState, VmStatus as ProtoVmStatus, Volume as ProtoVolume, - VolumeBacking as ProtoVolumeBacking, VolumeDriverKind as ProtoVolumeDriverKind, - VolumeFormat as ProtoVolumeFormat, VolumeStatus as ProtoVolumeStatus, WatchVmRequest, + VirtualMachine, Visibility as ProtoVisibility, VmEvent, VmEventType as ProtoVmEventType, + VmSpec as ProtoVmSpec, VmState as ProtoVmState, VmStatus as ProtoVmStatus, + Volume as ProtoVolume, VolumeBacking as ProtoVolumeBacking, + VolumeDriverKind as ProtoVolumeDriverKind, VolumeFormat as ProtoVolumeFormat, + VolumeStatus as ProtoVolumeStatus, WatchVmRequest, }; use plasmavmc_hypervisor::HypervisorRegistry; use plasmavmc_types::{ @@ -247,6 +248,10 @@ impl VmServiceImpl { .unwrap_or(true) } + pub fn store(&self) -> Arc { + Arc::clone(&self.store) + } + fn to_status_code(err: plasmavmc_types::Error) -> Status { Status::internal(err.to_string()) } @@ -287,6 +292,61 @@ impl VmServiceImpl { .as_secs() } + fn watch_poll_interval() -> Duration { + let poll_interval_ms = std::env::var("PLASMAVMC_VM_WATCH_POLL_INTERVAL_MS") + .ok() + .and_then(|value| value.parse::().ok()) + .unwrap_or(500) + .max(100); + Duration::from_millis(poll_interval_ms) + } + + fn vm_values_differ( + previous: &plasmavmc_types::VirtualMachine, + current: &plasmavmc_types::VirtualMachine, + ) -> Result { + let previous = serde_json::to_value(previous).map_err(|error| { + Status::internal(format!("failed to serialize VM snapshot: {error}")) + })?; + let current = serde_json::to_value(current).map_err(|error| { + Status::internal(format!("failed to serialize VM snapshot: {error}")) + })?; + Ok(previous != current) + } + + fn build_vm_event( + previous: Option<&plasmavmc_types::VirtualMachine>, + current: Option<&plasmavmc_types::VirtualMachine>, + ) -> Result, Status> { + let event_type = match (previous, current) { + (None, Some(_)) => ProtoVmEventType::Created, + (Some(_), None) => ProtoVmEventType::Deleted, + (Some(previous), Some(current)) => { + if previous.state != current.state + || previous.status.actual_state != current.status.actual_state + { + ProtoVmEventType::StateChanged + } else if Self::vm_values_differ(previous, current)? 
{ + ProtoVmEventType::Updated + } else { + return Ok(None); + } + } + (None, None) => return Ok(None), + }; + + let vm = current + .or(previous) + .ok_or_else(|| Status::internal("watch event builder received an empty VM snapshot"))?; + + Ok(Some(VmEvent { + vm_id: vm.id.to_string(), + event_type: event_type as i32, + vm: Some(Self::to_proto_vm(vm, vm.status.clone())), + timestamp: Self::now_epoch() as i64, + })) + } + fn endpoint_host(endpoint: &str) -> Result { let authority = endpoint .split("://") @@ -1925,6 +1985,7 @@ impl VmServiceImpl { #[cfg(test)] mod tests { use super::*; + use plasmavmc_types::VmSpec; #[test] fn unspecified_disk_cache_defaults_to_writeback() { @@ -1945,6 +2006,47 @@ mod tests { prismnet_api::proto::DeviceType::None as i32 ); } + + #[test] + fn build_vm_event_classifies_lifecycle_changes() { + let vm = plasmavmc_types::VirtualMachine::new( + "watch-vm", + "watch-org", + "watch-project", + VmSpec::default(), + ); + + let created = VmServiceImpl::build_vm_event(None, Some(&vm)) + .unwrap() + .unwrap(); + assert_eq!(created.event_type, ProtoVmEventType::Created as i32); + + let mut updated_vm = vm.clone(); + updated_vm.updated_at += 1; + updated_vm + .metadata + .insert("note".to_string(), "changed".to_string()); + let updated = VmServiceImpl::build_vm_event(Some(&vm), Some(&updated_vm)) + .unwrap() + .unwrap(); + assert_eq!(updated.event_type, ProtoVmEventType::Updated as i32); + + let mut running_vm = updated_vm.clone(); + running_vm.state = VmState::Running; + running_vm.status.actual_state = VmState::Running; + let state_changed = VmServiceImpl::build_vm_event(Some(&updated_vm), Some(&running_vm)) + .unwrap() + .unwrap(); + assert_eq!( + state_changed.event_type, + ProtoVmEventType::StateChanged as i32 + ); + + let deleted = VmServiceImpl::build_vm_event(Some(&running_vm), None) + .unwrap() + .unwrap(); + assert_eq!(deleted.event_type, ProtoVmEventType::Deleted as i32); + } } impl StateSink for VmServiceImpl { @@ -3742,7 +3844,7 @@ impl VmService for VmServiceImpl { vm_id = %req.vm_id, org_id = %req.org_id, project_id = %req.project_id, - "WatchVm request (stub implementation)" + "WatchVm request" ); self.auth .authorize( @@ -3752,8 +3854,68 @@ impl VmService for VmServiceImpl { ) .await?; - // TODO: Implement VM watch via ChainFire watch - Err(Status::unimplemented("VM watch not yet implemented")) + let initial_vm = self + .ensure_vm_loaded(&req.org_id, &req.project_id, &req.vm_id) + .await + .ok_or_else(|| Status::not_found("VM not found"))?; + let poll_interval = Self::watch_poll_interval(); + let store = Arc::clone(&self.store); + let org_id = req.org_id.clone(); + let project_id = req.project_id.clone(); + let vm_id = req.vm_id.clone(); + let (tx, rx) = tokio::sync::mpsc::channel(32); + + tokio::spawn(async move { + let mut last_seen = Some(initial_vm); + let mut ticker = tokio::time::interval(poll_interval); + ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + loop { + ticker.tick().await; + let next_vm = match tokio::time::timeout( + STORE_OP_TIMEOUT, + store.load_vm(&org_id, &project_id, &vm_id), + ) + .await + { + Ok(Ok(vm)) => vm, + Ok(Err(error)) => { + let _ = tx + .send(Err(Status::unavailable(format!( + "VM watch poll failed: {error}" + )))) + .await; + break; + } + Err(_) => { + let _ = tx + .send(Err(Status::deadline_exceeded("VM watch poll timed out"))) + .await; + break; + } + }; + + let event = match Self::build_vm_event(last_seen.as_ref(), next_vm.as_ref()) { + Ok(event) => event, + Err(status) => { + let _ = 
tx.send(Err(status)).await; + break; + } + }; + if let Some(event) = event { + if tx.send(Ok(event)).await.is_err() { + break; + } + } + + match next_vm { + Some(vm) => last_seen = Some(vm), + None => break, + } + } + }); + + Ok(Response::new(ReceiverStream::new(rx))) } } diff --git a/plasmavmc/crates/plasmavmc-server/src/watcher.rs b/plasmavmc/crates/plasmavmc-server/src/watcher.rs index 2ff00aa..1a4f04a 100644 --- a/plasmavmc/crates/plasmavmc-server/src/watcher.rs +++ b/plasmavmc/crates/plasmavmc-server/src/watcher.rs @@ -1,14 +1,23 @@ -//! ChainFire state watcher for PlasmaVMC +//! Storage-backed state watcher for PlasmaVMC. //! -//! Provides state synchronization across multiple PlasmaVMC instances -//! by watching ChainFire for VM and handle changes made by other nodes. +//! PlasmaVMC persists VM intent and node heartbeats in the shared metadata +//! store. This watcher polls that store and mirrors external changes into each +//! process-local cache so multiple control-plane or agent instances converge on +//! the same durable view. -use chainfire_client::{Client as ChainFireClient, EventType, WatchEvent}; +use crate::storage::VmStore; use plasmavmc_types::{Node, VirtualMachine, VmHandle}; +use serde::Serialize; +use std::collections::HashMap; use std::sync::Arc; +use std::time::Duration; use tokio::sync::mpsc; +use tokio::time::MissedTickBehavior; use tracing::{debug, info, warn}; +type VmSnapshot = HashMap; +type NodeSnapshot = HashMap; + /// Event types from the state watcher #[derive(Debug, Clone)] pub enum StateEvent { @@ -39,255 +48,246 @@ pub enum StateEvent { vm_id: String, }, /// A node was updated - NodeUpdated { - node_id: String, - node: Node, - }, + NodeUpdated { node_id: String, node: Node }, /// A node was deleted - NodeDeleted { - node_id: String, - }, + NodeDeleted { node_id: String }, } /// Configuration for the state watcher #[derive(Debug, Clone)] pub struct WatcherConfig { - /// ChainFire endpoint - pub chainfire_endpoint: String, + /// Poll interval for metadata refresh. + pub poll_interval: Duration, /// Channel buffer size for events pub buffer_size: usize, } impl Default for WatcherConfig { fn default() -> Self { + let poll_interval_ms = std::env::var("PLASMAVMC_STATE_WATCHER_POLL_INTERVAL_MS") + .ok() + .and_then(|value| value.parse::().ok()) + .unwrap_or(1000) + .max(100); Self { - chainfire_endpoint: std::env::var("PLASMAVMC_CHAINFIRE_ENDPOINT") - .unwrap_or_else(|_| "http://127.0.0.1:2379".to_string()), + poll_interval: Duration::from_millis(poll_interval_ms), buffer_size: 256, } } } -/// State watcher that monitors ChainFire for external changes +/// State watcher that monitors the shared VM store for external changes. pub struct StateWatcher { + store: Arc, config: WatcherConfig, event_tx: mpsc::Sender, } impl StateWatcher { - /// Create a new state watcher and return the event receiver - pub fn new(config: WatcherConfig) -> (Self, mpsc::Receiver) { + /// Create a new state watcher and return the event receiver. + pub fn new( + store: Arc, + config: WatcherConfig, + ) -> (Self, mpsc::Receiver) { let (event_tx, event_rx) = mpsc::channel(config.buffer_size); - (Self { config, event_tx }, event_rx) + ( + Self { + store, + config, + event_tx, + }, + event_rx, + ) } - /// Start watching for state changes + /// Start watching for state changes. 
/// - /// This spawns background tasks that watch: - /// - `/plasmavmc/vms/` prefix for VM changes - /// - `/plasmavmc/handles/` prefix for handle changes - /// - `/plasmavmc/nodes/` prefix for node changes + /// VM handles remain local process state because they point at ephemeral + /// runtime artifacts such as QMP sockets and per-node directories. Durable + /// VM intent and node heartbeats are the shared state synchronized here. pub async fn start(&self) -> Result<(), WatcherError> { - info!("Starting PlasmaVMC state watcher"); + info!( + poll_interval_ms = self.config.poll_interval.as_millis(), + "Starting storage-backed PlasmaVMC state watcher" + ); - // Connect to ChainFire - let mut client = ChainFireClient::connect(&self.config.chainfire_endpoint) - .await - .map_err(|e| WatcherError::Connection(e.to_string()))?; + let mut vm_snapshot = self.load_vms().await?; + let mut node_snapshot = self.load_nodes().await?; + let mut ticker = tokio::time::interval(self.config.poll_interval); + ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); - // Start watching VMs - let vm_watch = client - .watch_prefix(b"/plasmavmc/vms/") - .await - .map_err(|e| WatcherError::Watch(e.to_string()))?; + loop { + ticker.tick().await; - let event_tx_vm = self.event_tx.clone(); - tokio::spawn(async move { - Self::watch_loop(vm_watch, event_tx_vm, WatchType::Vm).await; - }); - - // Connect again for second watch (each watch uses its own stream) - let mut client2 = ChainFireClient::connect(&self.config.chainfire_endpoint) - .await - .map_err(|e| WatcherError::Connection(e.to_string()))?; - - // Start watching handles - let handle_watch = client2 - .watch_prefix(b"/plasmavmc/handles/") - .await - .map_err(|e| WatcherError::Watch(e.to_string()))?; - - let event_tx_handle = self.event_tx.clone(); - tokio::spawn(async move { - Self::watch_loop(handle_watch, event_tx_handle, WatchType::Handle).await; - }); - - // Connect again for node watch - let mut client3 = ChainFireClient::connect(&self.config.chainfire_endpoint) - .await - .map_err(|e| WatcherError::Connection(e.to_string()))?; - - let node_watch = client3 - .watch_prefix(b"/plasmavmc/nodes/") - .await - .map_err(|e| WatcherError::Watch(e.to_string()))?; - - let event_tx_node = self.event_tx.clone(); - tokio::spawn(async move { - Self::watch_loop(node_watch, event_tx_node, WatchType::Node).await; - }); - - info!("State watcher started successfully"); - Ok(()) - } - - /// Watch loop for processing events - async fn watch_loop( - mut watch: chainfire_client::WatchHandle, - event_tx: mpsc::Sender, - watch_type: WatchType, - ) { - debug!(?watch_type, "Starting watch loop"); - - while let Some(event) = watch.recv().await { - match Self::process_event(&event, &watch_type) { - Ok(Some(state_event)) => { - if event_tx.send(state_event).await.is_err() { - warn!("Event receiver dropped, stopping watch loop"); - break; + let next_vm_snapshot = match self.load_vms().await { + Ok(snapshot) => snapshot, + Err(error) => { + warn!(error = %error, "Failed to refresh PlasmaVMC VM snapshot"); + continue; + } + }; + match diff_vm_snapshots(&vm_snapshot, &next_vm_snapshot) { + Ok(events) => { + for event in events { + if self.event_tx.send(event).await.is_err() { + info!("State watcher receiver dropped, stopping"); + return Ok(()); + } } + vm_snapshot = next_vm_snapshot; } - Ok(None) => { - // Event was filtered or not relevant + Err(error) => { + warn!(error = %error, "Failed to diff PlasmaVMC VM snapshot"); } - Err(e) => { - warn!(?watch_type, error = %e, "Failed to 
process watch event"); + } + + let next_node_snapshot = match self.load_nodes().await { + Ok(snapshot) => snapshot, + Err(error) => { + warn!(error = %error, "Failed to refresh PlasmaVMC node snapshot"); + continue; + } + }; + match diff_node_snapshots(&node_snapshot, &next_node_snapshot) { + Ok(events) => { + for event in events { + if self.event_tx.send(event).await.is_err() { + info!("State watcher receiver dropped, stopping"); + return Ok(()); + } + } + node_snapshot = next_node_snapshot; + } + Err(error) => { + warn!(error = %error, "Failed to diff PlasmaVMC node snapshot"); } } } - - debug!(?watch_type, "Watch loop ended"); } - /// Process a watch event into a state event - fn process_event( - event: &WatchEvent, - watch_type: &WatchType, - ) -> Result, WatcherError> { - let key_str = String::from_utf8_lossy(&event.key); + async fn load_vms(&self) -> Result { + let vms = self + .store + .list_all_vms() + .await + .map_err(|error| WatcherError::Storage(error.to_string()))?; + let mut snapshot = HashMap::with_capacity(vms.len()); + for vm in vms { + snapshot.insert(VmKey::from_vm(&vm), vm); + } + Ok(snapshot) + } - // Parse the key to extract org_id, project_id, vm_id - let (org_id, project_id, vm_id) = match watch_type { - WatchType::Vm => parse_vm_key(&key_str)?, - WatchType::Handle => parse_handle_key(&key_str)?, - WatchType::Node => (String::new(), String::new(), parse_node_key(&key_str)?), + async fn load_nodes(&self) -> Result { + let nodes = self + .store + .list_nodes() + .await + .map_err(|error| WatcherError::Storage(error.to_string()))?; + let mut snapshot = HashMap::with_capacity(nodes.len()); + for node in nodes { + snapshot.insert(node.id.to_string(), node); + } + Ok(snapshot) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct VmKey { + org_id: String, + project_id: String, + vm_id: String, +} + +impl VmKey { + fn from_vm(vm: &VirtualMachine) -> Self { + Self { + org_id: vm.org_id.clone(), + project_id: vm.project_id.clone(), + vm_id: vm.id.to_string(), + } + } +} + +fn diff_vm_snapshots( + previous: &VmSnapshot, + current: &VmSnapshot, +) -> Result, WatcherError> { + let mut events = Vec::new(); + + for (key, vm) in current { + let changed = match previous.get(key) { + Some(old_vm) => values_differ(old_vm, vm)?, + None => true, }; - - match event.event_type { - EventType::Put => { - match watch_type { - WatchType::Vm => { - let vm: VirtualMachine = serde_json::from_slice(&event.value) - .map_err(|e| WatcherError::Deserialize(e.to_string()))?; - Ok(Some(StateEvent::VmUpdated { - org_id, - project_id, - vm_id, - vm, - })) - } - WatchType::Handle => { - let handle: VmHandle = serde_json::from_slice(&event.value) - .map_err(|e| WatcherError::Deserialize(e.to_string()))?; - Ok(Some(StateEvent::HandleUpdated { - org_id, - project_id, - vm_id, - handle, - })) - } - WatchType::Node => { - let node: Node = serde_json::from_slice(&event.value) - .map_err(|e| WatcherError::Deserialize(e.to_string()))?; - Ok(Some(StateEvent::NodeUpdated { - node_id: vm_id, - node, - })) - } - } - } - EventType::Delete => { - match watch_type { - WatchType::Vm => Ok(Some(StateEvent::VmDeleted { - org_id, - project_id, - vm_id, - })), - WatchType::Handle => Ok(Some(StateEvent::HandleDeleted { - org_id, - project_id, - vm_id, - })), - WatchType::Node => Ok(Some(StateEvent::NodeDeleted { node_id: vm_id })), - } - } + if changed { + events.push(StateEvent::VmUpdated { + org_id: key.org_id.clone(), + project_id: key.project_id.clone(), + vm_id: key.vm_id.clone(), + vm: vm.clone(), + }); } } -} 
-#[derive(Debug, Clone, Copy)] -enum WatchType { - Vm, - Handle, - Node, -} - -/// Parse VM key: /plasmavmc/vms/{org_id}/{project_id}/{vm_id} -fn parse_vm_key(key: &str) -> Result<(String, String, String), WatcherError> { - let parts: Vec<&str> = key.trim_start_matches('/').split('/').collect(); - if parts.len() < 5 || parts[0] != "plasmavmc" || parts[1] != "vms" { - return Err(WatcherError::InvalidKey(key.to_string())); + for key in previous.keys() { + if !current.contains_key(key) { + events.push(StateEvent::VmDeleted { + org_id: key.org_id.clone(), + project_id: key.project_id.clone(), + vm_id: key.vm_id.clone(), + }); + } } - Ok(( - parts[2].to_string(), - parts[3].to_string(), - parts[4].to_string(), - )) + + Ok(events) } -/// Parse handle key: /plasmavmc/handles/{org_id}/{project_id}/{vm_id} -fn parse_handle_key(key: &str) -> Result<(String, String, String), WatcherError> { - let parts: Vec<&str> = key.trim_start_matches('/').split('/').collect(); - if parts.len() < 5 || parts[0] != "plasmavmc" || parts[1] != "handles" { - return Err(WatcherError::InvalidKey(key.to_string())); +fn diff_node_snapshots( + previous: &NodeSnapshot, + current: &NodeSnapshot, +) -> Result<Vec<StateEvent>, WatcherError> { + let mut events = Vec::new(); + + for (node_id, node) in current { + let changed = match previous.get(node_id) { + Some(old_node) => values_differ(old_node, node)?, + None => true, + }; + if changed { + events.push(StateEvent::NodeUpdated { + node_id: node_id.clone(), + node: node.clone(), + }); + } } - Ok(( - parts[2].to_string(), - parts[3].to_string(), - parts[4].to_string(), - )) + + for node_id in previous.keys() { + if !current.contains_key(node_id) { + events.push(StateEvent::NodeDeleted { + node_id: node_id.clone(), + }); + } + } + + Ok(events) } -/// Parse node key: /plasmavmc/nodes/{node_id} -fn parse_node_key(key: &str) -> Result<String, WatcherError> { - let parts: Vec<&str> = key.trim_start_matches('/').split('/').collect(); - if parts.len() < 3 || parts[0] != "plasmavmc" || parts[1] != "nodes" { - return Err(WatcherError::InvalidKey(key.to_string())); - } - Ok(parts[2].to_string()) +fn values_differ<T: Serialize>(lhs: &T, rhs: &T) -> Result<bool, WatcherError> { + let lhs = + serde_json::to_value(lhs).map_err(|error| WatcherError::Serialize(error.to_string()))?; + let rhs = + serde_json::to_value(rhs).map_err(|error| WatcherError::Serialize(error.to_string()))?; + Ok(lhs != rhs) } /// Watcher errors #[derive(Debug, thiserror::Error)] pub enum WatcherError { - #[error("Connection error: {0}")] - Connection(String), - #[error("Watch error: {0}")] - Watch(String), - #[error("Invalid key format: {0}")] - InvalidKey(String), - #[error("Deserialization error: {0}")] - Deserialize(String), + #[error("Storage error: {0}")] + Storage(String), + #[error("Serialization error: {0}")] + Serialize(String), } /// State synchronizer that applies watch events to local state @@ -323,20 +323,42 @@ impl StateSynchronizer { while let Some(event) = event_rx.recv().await { match event { - StateEvent::VmUpdated { org_id, project_id, vm_id, vm } => { + StateEvent::VmUpdated { + org_id, + project_id, + vm_id, + vm, + } => { debug!(org_id, project_id, vm_id, "External VM update received"); self.sink.on_vm_updated(&org_id, &project_id, &vm_id, vm); } - StateEvent::VmDeleted { org_id, project_id, vm_id } => { + StateEvent::VmDeleted { + org_id, + project_id, + vm_id, + } => { debug!(org_id, project_id, vm_id, "External VM deletion received"); self.sink.on_vm_deleted(&org_id, &project_id, &vm_id); } - StateEvent::HandleUpdated { org_id, project_id, vm_id, handle } => { +
StateEvent::HandleUpdated { + org_id, + project_id, + vm_id, + handle, + } => { debug!(org_id, project_id, vm_id, "External handle update received"); - self.sink.on_handle_updated(&org_id, &project_id, &vm_id, handle); + self.sink + .on_handle_updated(&org_id, &project_id, &vm_id, handle); } - StateEvent::HandleDeleted { org_id, project_id, vm_id } => { - debug!(org_id, project_id, vm_id, "External handle deletion received"); + StateEvent::HandleDeleted { + org_id, + project_id, + vm_id, + } => { + debug!( + org_id, + project_id, vm_id, "External handle deletion received" + ); self.sink.on_handle_deleted(&org_id, &project_id, &vm_id); } StateEvent::NodeUpdated { node_id, node } => { @@ -357,33 +379,82 @@ impl StateSynchronizer { #[cfg(test)] mod tests { use super::*; + use plasmavmc_types::{VmSpec, VmState}; - #[test] - fn test_parse_vm_key() { - let (org, proj, vm) = parse_vm_key("/plasmavmc/vms/org1/proj1/vm-123").unwrap(); - assert_eq!(org, "org1"); - assert_eq!(proj, "proj1"); - assert_eq!(vm, "vm-123"); + fn sample_vm() -> VirtualMachine { + VirtualMachine::new("vm-a", "org-a", "project-a", VmSpec::default()) + } + + fn sample_node() -> Node { + Node::new("node-a") } #[test] - fn test_parse_handle_key() { - let (org, proj, vm) = parse_handle_key("/plasmavmc/handles/org1/proj1/vm-123").unwrap(); - assert_eq!(org, "org1"); - assert_eq!(proj, "proj1"); - assert_eq!(vm, "vm-123"); + fn vm_snapshot_diff_emits_update_for_new_and_changed_vms() { + let mut previous = VmSnapshot::new(); + let mut current = VmSnapshot::new(); + let vm = sample_vm(); + let key = VmKey::from_vm(&vm); + + current.insert(key.clone(), vm.clone()); + let events = diff_vm_snapshots(&previous, &current).unwrap(); + assert!(matches!( + events.as_slice(), + [StateEvent::VmUpdated { vm_id, .. }] if vm_id == &vm.id.to_string() + )); + + previous = current.clone(); + let mut updated_vm = vm.clone(); + updated_vm.state = VmState::Running; + updated_vm.status.actual_state = VmState::Running; + current.insert(key, updated_vm.clone()); + let events = diff_vm_snapshots(&previous, &current).unwrap(); + assert!(matches!( + events.as_slice(), + [StateEvent::VmUpdated { vm, .. }] if vm.state == VmState::Running + )); } #[test] - fn test_parse_node_key() { - let node_id = parse_node_key("/plasmavmc/nodes/node-1").unwrap(); - assert_eq!(node_id, "node-1"); + fn vm_snapshot_diff_emits_delete_for_removed_vms() { + let vm = sample_vm(); + let key = VmKey::from_vm(&vm); + let previous = HashMap::from([(key, vm.clone())]); + let current = VmSnapshot::new(); + + let events = diff_vm_snapshots(&previous, &current).unwrap(); + assert!(matches!( + events.as_slice(), + [StateEvent::VmDeleted { vm_id, .. }] if vm_id == &vm.id.to_string() + )); } #[test] - fn test_invalid_key() { - assert!(parse_vm_key("/invalid/key").is_err()); - assert!(parse_handle_key("/plasmavmc/wrong/a/b/c").is_err()); - assert!(parse_node_key("/plasmavmc/wrong").is_err()); + fn node_snapshot_diff_emits_update_and_delete() { + let node = sample_node(); + let mut previous = NodeSnapshot::new(); + let mut current = NodeSnapshot::from([(node.id.to_string(), node.clone())]); + + let events = diff_node_snapshots(&previous, &current).unwrap(); + assert!(matches!( + events.as_slice(), + [StateEvent::NodeUpdated { node_id, ..
}] if node_id == &node.id.to_string() + )); + + previous = current.clone(); + let mut updated_node = node.clone(); + updated_node.last_heartbeat = 42; + current.insert(updated_node.id.to_string(), updated_node); + let events = diff_node_snapshots(&previous, &current).unwrap(); + assert!(matches!( + events.as_slice(), + [StateEvent::NodeUpdated { node, .. }] if node.last_heartbeat == 42 + )); + + let events = diff_node_snapshots(&current, &NodeSnapshot::new()).unwrap(); + assert!(matches!( + events.as_slice(), + [StateEvent::NodeDeleted { node_id }] if node_id == &node.id.to_string() + )); } }
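Illustrative usage sketch (not part of the diff above): assuming the API lands as shown, a caller builds the watcher from an existing Arc<VmStore>, drains the event receiver on a separate task, and then awaits start(), which now runs the poll/diff loop itself and returns only once the receiver side is dropped. The run_watcher helper and the logging below are hypothetical; real deployments route events through StateSynchronizer and its sink callbacks (on_vm_updated, on_node_deleted, and so on).

// Hypothetical wiring sketch; `run_watcher` is illustrative only and not part of the PR.
use std::sync::Arc;

async fn run_watcher(store: Arc<VmStore>) -> Result<(), WatcherError> {
    // WatcherConfig::default() reads PLASMAVMC_STATE_WATCHER_POLL_INTERVAL_MS (min 100 ms).
    let (watcher, mut event_rx) = StateWatcher::new(store, WatcherConfig::default());

    // Drain events on a separate task; the real code hands them to StateSynchronizer.
    tokio::spawn(async move {
        while let Some(event) = event_rx.recv().await {
            match event {
                StateEvent::VmUpdated { vm_id, .. } => tracing::debug!(vm_id, "vm updated"),
                StateEvent::VmDeleted { vm_id, .. } => tracing::debug!(vm_id, "vm deleted"),
                StateEvent::NodeUpdated { node_id, .. } => tracing::debug!(node_id, "node updated"),
                StateEvent::NodeDeleted { node_id } => tracing::debug!(node_id, "node deleted"),
                _ => {}
            }
        }
    });

    // start() blocks on the poll loop and returns Ok(()) once the receiver is gone.
    watcher.start().await
}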