//! ChainFire-based Placement Driver client
//!
//! This module provides a client for interacting with ChainFire as the
//! Placement Driver (PD) for FlareDB cluster management.
//!
//! ## Watch Integration
//!
//! The client supports real-time notifications of metadata changes via
//! ChainFire's Watch API, enabling event-driven updates instead of polling.
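//!
//! A minimal usage sketch (marked `ignore`; the endpoint list, store ID, and
//! key below are illustrative, not part of this module):
//!
//! ```ignore
//! let endpoints = vec!["127.0.0.1:2379".to_string()];
//! let mut pd = PdClient::connect_any(&endpoints).await?;
//! pd.register_store(1, "127.0.0.1:20160".to_string()).await?;
//! if let Some(region) = pd.get_region_for_key(b"user/42").await {
//!     let leader_addr = pd.get_leader_addr(region.id).await;
//! }
//! ```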

use flaredb_proto::chainfire::cluster_client::ClusterClient;
use flaredb_proto::chainfire::kv_client::KvClient;
use flaredb_proto::chainfire::watch_client::WatchClient;
use flaredb_proto::chainfire::{
    Event, PutRequest, RangeRequest, StatusRequest, WatchCreateRequest, WatchRequest,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc, RwLock};
use tonic::transport::Channel;

/// Key prefixes for cluster metadata in ChainFire
const PREFIX_STORES: &str = "/flaredb/stores/";
const PREFIX_REGIONS: &str = "/flaredb/regions/";
const KEY_CLUSTER_ID: &str = "/flaredb/cluster/id";

/// Store information stored in ChainFire
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoreInfo {
    pub id: u64,
    pub addr: String,
    pub last_heartbeat: u64,
}

/// Region information stored in ChainFire
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegionInfo {
    pub id: u64,
    pub start_key: Vec<u8>,
    pub end_key: Vec<u8>,
    pub peers: Vec<u64>,
    pub leader_id: u64,
}

/// Events emitted by the PD client when metadata changes
#[derive(Debug, Clone)]
pub enum PdEvent {
    /// A store was added or updated
    StoreUpdated(StoreInfo),
    /// A store was removed
    StoreRemoved(u64),
    /// A region was added or updated
    RegionUpdated(RegionInfo),
    /// A region was removed
    RegionRemoved(u64),
}

/// Cached metadata for faster lookups
struct MetadataCache {
    stores: HashMap<u64, StoreInfo>,
    regions: HashMap<u64, RegionInfo>,
}

impl MetadataCache {
    fn new() -> Self {
        Self {
            stores: HashMap::new(),
            regions: HashMap::new(),
        }
    }
}

/// Client for interacting with ChainFire as a Placement Driver
pub struct PdClient {
    kv_client: KvClient<Channel>,
    watch_client: WatchClient<Channel>,
    cluster_id: u64,
    /// Cached metadata for fast access
    cache: Arc<RwLock<MetadataCache>>,
    /// Channel to broadcast metadata change events
    event_tx: broadcast::Sender<PdEvent>,
}

impl PdClient {
    /// Probe each endpoint's cluster status and return the current leader's
    /// address, if one can be determined.
    async fn select_leader_addr(addrs: &[String]) -> Option<String> {
        let mut member_addrs: HashMap<u64, String> = HashMap::new();
        let mut leader_id = None;
        for addr in addrs {
            let endpoint = if addr.starts_with("http") {
                addr.clone()
            } else {
                format!("http://{}", addr)
            };
            let channel = match Channel::from_shared(endpoint) {
                Ok(endpoint) => match endpoint.connect().await {
                    Ok(channel) => channel,
                    Err(_) => continue,
                },
                Err(_) => continue,
            };
            let mut cluster_client = ClusterClient::new(channel);
            let status = match cluster_client.status(StatusRequest {}).await {
                Ok(resp) => resp.into_inner(),
                Err(_) => continue,
            };
            let member_id = status.header.as_ref().map(|h| h.member_id).unwrap_or(0);
            if member_id != 0 {
                member_addrs.insert(member_id, addr.clone());
            }
            if status.leader != 0 {
                leader_id = Some(status.leader);
                if status.leader == member_id && member_id != 0 {
                    return Some(addr.clone());
                }
            }
        }
        leader_id.and_then(|id| member_addrs.get(&id).cloned())
    }

    /// Connect to a ChainFire cluster
    pub async fn connect(addr: String) -> Result<Self, Box<dyn std::error::Error>> {
        let endpoint = if addr.starts_with("http") {
            addr
        } else {
            format!("http://{}", addr)
        };
        let channel = Channel::from_shared(endpoint)?.connect().await?;
        let kv_client = KvClient::new(channel.clone());
        let watch_client = WatchClient::new(channel);
        let (event_tx, _) = broadcast::channel(256);

        let mut client = Self {
            kv_client,
            watch_client,
            cluster_id: 0,
            cache: Arc::new(RwLock::new(MetadataCache::new())),
            event_tx,
        };

        // Try to get cluster ID, or generate a new one
        client.cluster_id = client.get_or_init_cluster_id().await?;

        // Load initial metadata into cache
        client.refresh_cache().await?;

        Ok(client)
    }

    /// Connect to the first reachable ChainFire endpoint that can serve PD metadata.
    pub async fn connect_any(addrs: &[String]) -> Result<Self, Box<dyn std::error::Error>> {
        if let Some(leader_addr) = Self::select_leader_addr(addrs).await {
            return Self::connect(leader_addr).await;
        }
        let mut last_error: Option<String> = None;
        for addr in addrs {
            match Self::connect(addr.clone()).await {
                Ok(client) => return Ok(client),
                Err(err) => {
                    last_error = Some(format!("{}: {}", addr, err));
                }
            }
        }
        Err(last_error
            .unwrap_or_else(|| "no PD endpoints configured".to_string())
            .into())
    }

    /// Subscribe to metadata change events
    pub fn subscribe(&self) -> broadcast::Receiver<PdEvent> {
        self.event_tx.subscribe()
    }

    /// Refresh the local cache from ChainFire
    async fn refresh_cache(&mut self) -> Result<(), Box<dyn std::error::Error>> {
        let stores = self.list_stores_remote().await?;
        let regions = self.list_regions_remote().await?;
        let mut cache = self.cache.write().await;
        cache.stores.clear();
        for store in stores {
            cache.stores.insert(store.id, store);
        }
        cache.regions.clear();
        for region in regions {
            cache.regions.insert(region.id, region);
        }
        Ok(())
    }

    /// Start watching for metadata changes in the background
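    ///
    /// Events observed by the spawned task are applied to the cache and fanned
    /// out on the broadcast channel, so subscribers should be registered before
    /// this is called (broadcast sends with no receivers are dropped). A minimal
    /// sketch (marked `ignore`; the address is illustrative):
    ///
    /// ```ignore
    /// let mut pd = PdClient::connect("127.0.0.1:2379".to_string()).await?;
    /// // Subscribe first so no change events are missed.
    /// let mut events = pd.subscribe();
    /// pd.start_watch().await?;
    /// while let Ok(event) = events.recv().await {
    ///     println!("PD event: {:?}", event);
    /// }
    /// ```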
    pub async fn start_watch(&mut self) -> Result<(), Box<dyn std::error::Error>> {
        let (tx, rx) = mpsc::channel::<WatchRequest>(32);

        // Create watch requests for the stores and regions prefixes
        let stores_watch = WatchRequest {
            request_union: Some(
                flaredb_proto::chainfire::watch_request::RequestUnion::CreateRequest(
                    WatchCreateRequest {
                        key: PREFIX_STORES.as_bytes().to_vec(),
                        range_end: prefix_range_end(PREFIX_STORES),
                        start_revision: 0,
                        progress_notify: false,
                        prev_kv: true,
                        watch_id: 1,
                    },
                ),
            ),
        };
        let regions_watch = WatchRequest {
            request_union: Some(
                flaredb_proto::chainfire::watch_request::RequestUnion::CreateRequest(
                    WatchCreateRequest {
                        key: PREFIX_REGIONS.as_bytes().to_vec(),
                        range_end: prefix_range_end(PREFIX_REGIONS),
                        start_revision: 0,
                        progress_notify: false,
                        prev_kv: true,
                        watch_id: 2,
                    },
                ),
            ),
        };

        // Send initial watch requests
        tx.send(stores_watch).await.ok();
        tx.send(regions_watch).await.ok();

        // Convert mpsc receiver to stream
        let request_stream = tokio_stream::wrappers::ReceiverStream::new(rx);

        // Start watching
        let response = self.watch_client.watch(request_stream).await?;
        let mut stream = response.into_inner();

        let cache = self.cache.clone();
        let event_tx = self.event_tx.clone();

        // Spawn background task to process watch events
        tokio::spawn(async move {
            while let Ok(Some(resp)) = stream.message().await {
                for event in resp.events {
                    if let Err(e) = process_watch_event(&event, &cache, &event_tx).await {
                        tracing::warn!("Failed to process watch event: {}", e);
                    }
                }
            }
            tracing::info!("PD watch stream ended");
        });

        Ok(())
    }

    /// Get or initialize the cluster ID
    async fn get_or_init_cluster_id(&mut self) -> Result<u64, Box<dyn std::error::Error>> {
        let req = RangeRequest {
            key: KEY_CLUSTER_ID.as_bytes().to_vec(),
            range_end: Vec::new(),
            limit: 1,
            revision: 0,
            keys_only: false,
            count_only: false,
        };
        let resp = self.kv_client.range(req).await?.into_inner();
        if let Some(kv) = resp.kvs.first() {
            let id_str = String::from_utf8_lossy(&kv.value);
            Ok(id_str.parse().unwrap_or(1))
        } else {
            // Initialize with cluster ID 1
            let put_req = PutRequest {
                key: KEY_CLUSTER_ID.as_bytes().to_vec(),
                value: b"1".to_vec(),
                lease: 0,
                prev_kv: false,
            };
            self.kv_client.put(put_req).await?;
            Ok(1)
        }
    }

    /// Get the cluster ID
    pub fn cluster_id(&self) -> u64 {
        self.cluster_id
    }

    /// Register a store with the cluster
    pub async fn register_store(
        &mut self,
        store_id: u64,
        addr: String,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let info = StoreInfo {
            id: store_id,
            addr,
            last_heartbeat: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_secs(),
        };
        let key = format!("{}{}", PREFIX_STORES, store_id);
        let value = serde_json::to_vec(&info)?;
        let req = PutRequest {
            key: key.into_bytes(),
            value,
            lease: 0,
            prev_kv: false,
        };
        self.kv_client.put(req).await?;
        Ok(())
    }

    /// Send a heartbeat for a store
    pub async fn heartbeat(
        &mut self,
        store_id: u64,
        addr: String,
    ) -> Result<(), Box<dyn std::error::Error>> {
        // Re-register with updated timestamp
        self.register_store(store_id, addr).await
    }

    /// List all registered stores from cache
    pub async fn list_stores(&self) -> Vec<StoreInfo> {
        let cache = self.cache.read().await;
        cache.stores.values().cloned().collect()
    }

    /// List all registered stores from remote (ChainFire)
    async fn list_stores_remote(&mut self) -> Result<Vec<StoreInfo>, Box<dyn std::error::Error>> {
        let key = PREFIX_STORES.as_bytes().to_vec();
        let range_end = prefix_range_end(PREFIX_STORES);
        let req = RangeRequest {
            key,
            range_end,
            limit: 0,
            revision: 0,
            keys_only: false,
            count_only: false,
        };
        let resp = self.kv_client.range(req).await?.into_inner();
        let mut stores = Vec::new();
        for kv in resp.kvs {
            if let Ok(info) = serde_json::from_slice::<StoreInfo>(&kv.value) {
                stores.push(info);
            }
        }
        Ok(stores)
    }

    /// Register or update a region
    pub async fn put_region(
        &mut self,
        region: RegionInfo,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let key = format!("{}{}", PREFIX_REGIONS, region.id);
        let value = serde_json::to_vec(&region)?;
        let req = PutRequest {
            key: key.into_bytes(),
            value,
            lease: 0,
            prev_kv: false,
        };
        self.kv_client.put(req).await?;
        Ok(())
    }

    /// List all regions from cache
    pub async fn list_regions(&self) -> Vec<RegionInfo> {
        let cache = self.cache.read().await;
        cache.regions.values().cloned().collect()
    }

    /// List all regions from remote (ChainFire)
    async fn list_regions_remote(&mut self) -> Result<Vec<RegionInfo>, Box<dyn std::error::Error>> {
        let key = PREFIX_REGIONS.as_bytes().to_vec();
        let range_end = prefix_range_end(PREFIX_REGIONS);
        let req = RangeRequest {
            key,
            range_end,
            limit: 0,
            revision: 0,
            keys_only: false,
            count_only: false,
        };
        let resp = self.kv_client.range(req).await?.into_inner();
        let mut regions = Vec::new();
        for kv in resp.kvs {
            if let Ok(info) = serde_json::from_slice::<RegionInfo>(&kv.value) {
                regions.push(info);
            }
        }
        Ok(regions)
    }

    /// Get a region by ID
    pub async fn get_region(
        &mut self,
        region_id: u64,
    ) -> Result<Option<RegionInfo>, Box<dyn std::error::Error>> {
        let key = format!("{}{}", PREFIX_REGIONS, region_id);
        let req = RangeRequest {
            key: key.into_bytes(),
            range_end: Vec::new(),
            limit: 1,
            revision: 0,
            keys_only: false,
            count_only: false,
        };
        let resp = self.kv_client.range(req).await?.into_inner();
        if let Some(kv) = resp.kvs.first() {
            Ok(serde_json::from_slice(&kv.value).ok())
        } else {
            Ok(None)
        }
    }

    /// Find the region containing a key (from cache)
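    ///
    /// Regions own half-open key ranges `[start_key, end_key)`, with an empty
    /// `end_key` treated as unbounded on the right. A sketch of a lookup
    /// (marked `ignore`; the region bounds and keys are made up for illustration):
    ///
    /// ```ignore
    /// // Given a single cached region { start_key: b"a", end_key: b"m" }:
    /// assert!(pd.get_region_for_key(b"apple").await.is_some()); // "a" <= "apple" < "m"
    /// assert!(pd.get_region_for_key(b"zebra").await.is_none()); // "zebra" >= "m"
    /// ```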
    pub async fn get_region_for_key(&self, key: &[u8]) -> Option<RegionInfo> {
        let regions = self.list_regions().await;
        for region in regions {
            let start_ok = key >= region.start_key.as_slice();
            let end_ok = region.end_key.is_empty() || key < region.end_key.as_slice();
            if start_ok && end_ok {
                return Some(region);
            }
        }
        None
    }

    /// Get the leader address for a region (from cache)
    pub async fn get_leader_addr(&self, region_id: u64) -> Option<String> {
        let cache = self.cache.read().await;
        if let Some(region) = cache.regions.get(&region_id) {
            if let Some(store) = cache.stores.get(&region.leader_id) {
                return Some(store.addr.clone());
            }
        }
        None
    }

    /// Report that this store is the leader for a region
    pub async fn report_leader(
        &mut self,
        region_id: u64,
        leader_id: u64,
    ) -> Result<(), Box<dyn std::error::Error>> {
        if let Some(mut region) = self.get_region(region_id).await? {
            region.leader_id = leader_id;
            self.put_region(region).await?;
        }
        Ok(())
    }

    /// Initialize the default region if none exist
    pub async fn init_default_region(
        &mut self,
        peers: Vec<u64>,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let regions = self.list_regions().await;
        if regions.is_empty() {
            let region = RegionInfo {
                id: 1,
                start_key: Vec::new(),
                end_key: Vec::new(),
                peers,
                leader_id: 0,
            };
            self.put_region(region).await?;
        }
        Ok(())
    }

    /// Get a region from cache by ID
    pub async fn get_region_cached(&self, region_id: u64) -> Option<RegionInfo> {
        let cache = self.cache.read().await;
        cache.regions.get(&region_id).cloned()
    }

    /// Get a store from cache by ID
    pub async fn get_store_cached(&self, store_id: u64) -> Option<StoreInfo> {
        let cache = self.cache.read().await;
        cache.stores.get(&store_id).cloned()
    }
}

// ============================================================================
// Helper functions
// ============================================================================

/// Compute the range end for a prefix (etcd-style prefix matching)
fn prefix_range_end(prefix: &str) -> Vec<u8> {
    let mut end = prefix.as_bytes().to_vec();
    if let Some(last) = end.last_mut() {
        // Assumes the final prefix byte is not 0xFF, which holds for the
        // ASCII key prefixes used in this module.
        *last += 1;
    }
    end
}

/// Process a watch event and update the cache
async fn process_watch_event(
    event: &Event,
    cache: &Arc<RwLock<MetadataCache>>,
    event_tx: &broadcast::Sender<PdEvent>,
) -> Result<(), Box<dyn std::error::Error>> {
    let kv = match &event.kv {
        Some(kv) => kv,
        None => return Ok(()),
    };
    let key_str = String::from_utf8_lossy(&kv.key);
    let is_delete = event.r#type == 1; // DELETE = 1

    if key_str.starts_with(PREFIX_STORES) {
        let store_id_str = key_str.strip_prefix(PREFIX_STORES).unwrap_or("");
        if let Ok(store_id) = store_id_str.parse::<u64>() {
            let mut cache = cache.write().await;
            if is_delete {
                cache.stores.remove(&store_id);
                let _ = event_tx.send(PdEvent::StoreRemoved(store_id));
            } else if let Ok(info) = serde_json::from_slice::<StoreInfo>(&kv.value) {
                cache.stores.insert(store_id, info.clone());
                let _ = event_tx.send(PdEvent::StoreUpdated(info));
            }
        }
    } else if key_str.starts_with(PREFIX_REGIONS) {
        let region_id_str = key_str.strip_prefix(PREFIX_REGIONS).unwrap_or("");
        if let Ok(region_id) = region_id_str.parse::<u64>() {
            let mut cache = cache.write().await;
            if is_delete {
                cache.regions.remove(&region_id);
                let _ = event_tx.send(PdEvent::RegionRemoved(region_id));
            } else if let Ok(info) = serde_json::from_slice::<RegionInfo>(&kv.value) {
                cache.regions.insert(region_id, info.clone());
                let _ = event_tx.send(PdEvent::RegionUpdated(info));
            }
        }
    }
    Ok(())
}
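
// A small sanity check for the prefix -> range-end computation above. This is
// a minimal sketch added for illustration; it only exercises the pure helper
// and assumes the standard etcd convention of bumping the last prefix byte.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn prefix_range_end_bumps_last_byte() {
        // "/flaredb/stores/" ends in '/' (0x2f), so the range end ends in '0' (0x30).
        let end = prefix_range_end(PREFIX_STORES);
        let mut expected = PREFIX_STORES.as_bytes().to_vec();
        *expected.last_mut().unwrap() += 1;
        assert_eq!(end, expected);

        // The half-open range [prefix, range_end) covers exactly the prefixed keys.
        assert!(b"/flaredb/stores/42".as_slice() >= PREFIX_STORES.as_bytes());
        assert!(b"/flaredb/stores/42".as_slice() < end.as_slice());
    }
}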