photoncloud-monorepo/flaredb/crates/flaredb-server/src/pd_client.rs

//! ChainFire-based Placement Driver client
//!
//! This module provides a client for interacting with ChainFire as the
//! Placement Driver (PD) for FlareDB cluster management.
//!
//! ## Watch Integration
//!
//! The client supports real-time notifications of metadata changes via
//! ChainFire's Watch API, enabling event-driven updates instead of polling.
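//!
//! A minimal end-to-end sketch (the endpoint addresses are illustrative and
//! the crate path is assumed from this file's location):
//!
//! ```no_run
//! # use flaredb_server::pd_client::PdClient;
//! # async fn demo() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
//! // Connect to whichever ChainFire endpoint can serve PD metadata.
//! let mut pd = PdClient::connect_any(&["127.0.0.1:2379".to_string()]).await?;
//! // Register this store and start the background watch loop.
//! pd.register_store(1, "127.0.0.1:20160".to_string()).await?;
//! pd.start_watch().await?;
//! // React to metadata changes pushed by the watch stream.
//! let mut events = pd.subscribe();
//! while let Ok(event) = events.recv().await {
//!     tracing::info!(?event, "PD metadata changed");
//! }
//! # Ok(())
//! # }
//! ```
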
use flaredb_proto::chainfire::cluster_client::ClusterClient;
use flaredb_proto::chainfire::kv_client::KvClient;
use flaredb_proto::chainfire::watch_client::WatchClient;
use flaredb_proto::chainfire::{
    Event, PutRequest, RangeRequest, StatusRequest, WatchCreateRequest, WatchRequest,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc, RwLock};
use tonic::transport::Channel;

/// Key prefixes for cluster metadata in ChainFire
const PREFIX_STORES: &str = "/flaredb/stores/";
const PREFIX_REGIONS: &str = "/flaredb/regions/";
const KEY_CLUSTER_ID: &str = "/flaredb/cluster/id";

/// Store information stored in ChainFire
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoreInfo {
    pub id: u64,
    pub addr: String,
    pub last_heartbeat: u64,
}

/// Region information stored in ChainFire
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegionInfo {
    pub id: u64,
    pub start_key: Vec<u8>,
    pub end_key: Vec<u8>,
    pub peers: Vec<u64>,
    pub leader_id: u64,
}

/// Events emitted by the PD client when metadata changes
#[derive(Debug, Clone)]
pub enum PdEvent {
    /// A store was added or updated
    StoreUpdated(StoreInfo),
    /// A store was removed
    StoreRemoved(u64),
    /// A region was added or updated
    RegionUpdated(RegionInfo),
    /// A region was removed
    RegionRemoved(u64),
}

/// Cached metadata for faster lookups
struct MetadataCache {
    stores: HashMap<u64, StoreInfo>,
    regions: HashMap<u64, RegionInfo>,
}

impl MetadataCache {
    fn new() -> Self {
        Self {
            stores: HashMap::new(),
            regions: HashMap::new(),
        }
    }
}

/// Client for interacting with ChainFire as a Placement Driver
pub struct PdClient {
    kv_client: KvClient<Channel>,
    watch_client: WatchClient<Channel>,
    cluster_id: u64,
    /// Cached metadata for fast access
    cache: Arc<RwLock<MetadataCache>>,
    /// Broadcast sender for metadata change events
    event_tx: broadcast::Sender<PdEvent>,
}

impl PdClient {
    /// Probe each endpoint's cluster status and return the address of the
    /// current leader, if one can be determined.
    async fn select_leader_addr(addrs: &[String]) -> Option<String> {
        let mut member_addrs = HashMap::new();
        let mut leader_id = None;
        for addr in addrs {
            let endpoint = if addr.starts_with("http") {
                addr.clone()
            } else {
                format!("http://{}", addr)
            };
            let channel = match Channel::from_shared(endpoint) {
                Ok(channel) => match channel.connect().await {
                    Ok(channel) => channel,
                    Err(_) => continue,
                },
                Err(_) => continue,
            };
            let mut cluster_client = ClusterClient::new(channel);
            let status = match cluster_client.status(StatusRequest {}).await {
                Ok(resp) => resp.into_inner(),
                Err(_) => continue,
            };
            let member_id = status.header.as_ref().map(|h| h.member_id).unwrap_or(0);
            if member_id != 0 {
                member_addrs.insert(member_id, addr.clone());
            }
            if status.leader != 0 {
                leader_id = Some(status.leader);
                // The endpoint we just asked is itself the leader.
                if status.leader == member_id && member_id != 0 {
                    return Some(addr.clone());
                }
            }
        }
        // Otherwise, map the reported leader ID back to an address we probed.
        leader_id.and_then(|id| member_addrs.get(&id).cloned())
    }

    /// Connect to a ChainFire cluster
    pub async fn connect(addr: String) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        let endpoint = if addr.starts_with("http") {
            addr
        } else {
            format!("http://{}", addr)
        };
        let channel = Channel::from_shared(endpoint)?.connect().await?;
        let kv_client = KvClient::new(channel.clone());
        let watch_client = WatchClient::new(channel);
        let (event_tx, _) = broadcast::channel(256);
        let mut client = Self {
            kv_client,
            watch_client,
            cluster_id: 0,
            cache: Arc::new(RwLock::new(MetadataCache::new())),
            event_tx,
        };
        // Fetch the cluster ID, initializing it if absent
        client.cluster_id = client.get_or_init_cluster_id().await?;
        // Load initial metadata into the cache
        client.refresh_cache().await?;
        Ok(client)
    }

    /// Connect to the first reachable ChainFire endpoint that can serve PD metadata.
    pub async fn connect_any(
        addrs: &[String],
    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        // Prefer the current leader if we can discover one.
        if let Some(leader_addr) = Self::select_leader_addr(addrs).await {
            return Self::connect(leader_addr).await;
        }
        let mut last_error: Option<String> = None;
        for addr in addrs {
            match Self::connect(addr.clone()).await {
                Ok(client) => return Ok(client),
                Err(err) => {
                    last_error = Some(format!("{}: {}", addr, err));
                }
            }
        }
        Err(last_error
            .unwrap_or_else(|| "no PD endpoints configured".to_string())
            .into())
    }

    /// Subscribe to metadata change events
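    ///
    /// A consumption sketch (the match arms just log; the crate path is
    /// assumed from this file's location):
    ///
    /// ```no_run
    /// # use flaredb_server::pd_client::{PdClient, PdEvent};
    /// # async fn demo(pd: &PdClient) {
    /// let mut rx = pd.subscribe();
    /// while let Ok(event) = rx.recv().await {
    ///     match event {
    ///         PdEvent::StoreUpdated(s) => println!("store {} now at {}", s.id, s.addr),
    ///         PdEvent::StoreRemoved(id) => println!("store {} removed", id),
    ///         PdEvent::RegionUpdated(r) => println!("region {} updated", r.id),
    ///         PdEvent::RegionRemoved(id) => println!("region {} removed", id),
    ///     }
    /// }
    /// # }
    /// ```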
    pub fn subscribe(&self) -> broadcast::Receiver<PdEvent> {
        self.event_tx.subscribe()
    }

    /// Refresh the local cache from ChainFire
    async fn refresh_cache(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let stores = self.list_stores_remote().await?;
        let regions = self.list_regions_remote().await?;
        let mut cache = self.cache.write().await;
        cache.stores.clear();
        for store in stores {
            cache.stores.insert(store.id, store);
        }
        cache.regions.clear();
        for region in regions {
            cache.regions.insert(region.id, region);
        }
        Ok(())
    }

    /// Start watching for metadata changes in the background
    pub async fn start_watch(
        &mut self,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let (tx, rx) = mpsc::channel::<WatchRequest>(32);
        // Create watch requests for the stores and regions prefixes
        let stores_watch = WatchRequest {
            request_union: Some(
                flaredb_proto::chainfire::watch_request::RequestUnion::CreateRequest(
                    WatchCreateRequest {
                        key: PREFIX_STORES.as_bytes().to_vec(),
                        range_end: prefix_range_end(PREFIX_STORES),
                        start_revision: 0,
                        progress_notify: false,
                        prev_kv: true,
                        watch_id: 1,
                    },
                ),
            ),
        };
        let regions_watch = WatchRequest {
            request_union: Some(
                flaredb_proto::chainfire::watch_request::RequestUnion::CreateRequest(
                    WatchCreateRequest {
                        key: PREFIX_REGIONS.as_bytes().to_vec(),
                        range_end: prefix_range_end(PREFIX_REGIONS),
                        start_revision: 0,
                        progress_notify: false,
                        prev_kv: true,
                        watch_id: 2,
                    },
                ),
            ),
        };
        // Send initial watch requests
        tx.send(stores_watch).await.ok();
        tx.send(regions_watch).await.ok();
        // Convert the mpsc receiver into a stream for the bidirectional RPC
        let request_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
        // Start watching
        let response = self.watch_client.watch(request_stream).await?;
        let mut stream = response.into_inner();
        let cache = self.cache.clone();
        let event_tx = self.event_tx.clone();
        // Spawn a background task to process watch events. The task also owns
        // the request sender: dropping it would half-close the stream, which
        // etcd-style watch servers treat as the end of the watch session.
        tokio::spawn(async move {
            let _tx = tx;
            while let Ok(Some(resp)) = stream.message().await {
                for event in resp.events {
                    if let Err(e) = process_watch_event(&event, &cache, &event_tx).await {
                        tracing::warn!("Failed to process watch event: {}", e);
                    }
                }
            }
            tracing::info!("PD watch stream ended");
        });
        Ok(())
    }

    /// Get or initialize the cluster ID
    async fn get_or_init_cluster_id(
        &mut self,
    ) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
        let req = RangeRequest {
            key: KEY_CLUSTER_ID.as_bytes().to_vec(),
            range_end: Vec::new(),
            limit: 1,
            revision: 0,
            keys_only: false,
            count_only: false,
        };
        let resp = self.kv_client.range(req).await?.into_inner();
        if let Some(kv) = resp.kvs.first() {
            let id_str = String::from_utf8_lossy(&kv.value);
            Ok(id_str.parse().unwrap_or(1))
        } else {
            // Initialize with cluster ID 1
            let put_req = PutRequest {
                key: KEY_CLUSTER_ID.as_bytes().to_vec(),
                value: b"1".to_vec(),
                lease: 0,
                prev_kv: false,
            };
            self.kv_client.put(put_req).await?;
            Ok(1)
        }
    }

    /// Get the cluster ID
    pub fn cluster_id(&self) -> u64 {
        self.cluster_id
    }

    /// Register a store with the cluster
    pub async fn register_store(
        &mut self,
        store_id: u64,
        addr: String,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let info = StoreInfo {
            id: store_id,
            addr,
            last_heartbeat: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_secs(),
        };
        let key = format!("{}{}", PREFIX_STORES, store_id);
        let value = serde_json::to_vec(&info)?;
        let req = PutRequest {
            key: key.into_bytes(),
            value,
            lease: 0,
            prev_kv: false,
        };
        self.kv_client.put(req).await?;
        Ok(())
    }

    /// Send a heartbeat for a store
    pub async fn heartbeat(
        &mut self,
        store_id: u64,
        addr: String,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // Re-register with an updated timestamp
        self.register_store(store_id, addr).await
    }

    /// List all registered stores from cache
    pub async fn list_stores(&self) -> Vec<StoreInfo> {
        let cache = self.cache.read().await;
        cache.stores.values().cloned().collect()
    }

    /// List all registered stores from remote (ChainFire)
    async fn list_stores_remote(
        &mut self,
    ) -> Result<Vec<StoreInfo>, Box<dyn std::error::Error + Send + Sync>> {
        let key = PREFIX_STORES.as_bytes().to_vec();
        let range_end = prefix_range_end(PREFIX_STORES);
        let req = RangeRequest {
            key,
            range_end,
            limit: 0,
            revision: 0,
            keys_only: false,
            count_only: false,
        };
        let resp = self.kv_client.range(req).await?.into_inner();
        let mut stores = Vec::new();
        for kv in resp.kvs {
            if let Ok(info) = serde_json::from_slice::<StoreInfo>(&kv.value) {
                stores.push(info);
            }
        }
        Ok(stores)
    }

    /// Register or update a region
    pub async fn put_region(
        &mut self,
        region: RegionInfo,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let key = format!("{}{}", PREFIX_REGIONS, region.id);
        let value = serde_json::to_vec(&region)?;
        let req = PutRequest {
            key: key.into_bytes(),
            value,
            lease: 0,
            prev_kv: false,
        };
        self.kv_client.put(req).await?;
        Ok(())
    }

    /// List all regions from cache
    pub async fn list_regions(&self) -> Vec<RegionInfo> {
        let cache = self.cache.read().await;
        cache.regions.values().cloned().collect()
    }

    /// List all regions from remote (ChainFire)
    async fn list_regions_remote(
        &mut self,
    ) -> Result<Vec<RegionInfo>, Box<dyn std::error::Error + Send + Sync>> {
        let key = PREFIX_REGIONS.as_bytes().to_vec();
        let range_end = prefix_range_end(PREFIX_REGIONS);
        let req = RangeRequest {
            key,
            range_end,
            limit: 0,
            revision: 0,
            keys_only: false,
            count_only: false,
        };
        let resp = self.kv_client.range(req).await?.into_inner();
        let mut regions = Vec::new();
        for kv in resp.kvs {
            if let Ok(info) = serde_json::from_slice::<RegionInfo>(&kv.value) {
                regions.push(info);
            }
        }
        Ok(regions)
    }

    /// Get a region by ID
    pub async fn get_region(
        &mut self,
        region_id: u64,
    ) -> Result<Option<RegionInfo>, Box<dyn std::error::Error + Send + Sync>> {
        let key = format!("{}{}", PREFIX_REGIONS, region_id);
        let req = RangeRequest {
            key: key.into_bytes(),
            range_end: Vec::new(),
            limit: 1,
            revision: 0,
            keys_only: false,
            count_only: false,
        };
        let resp = self.kv_client.range(req).await?.into_inner();
        if let Some(kv) = resp.kvs.first() {
            Ok(serde_json::from_slice(&kv.value).ok())
        } else {
            Ok(None)
        }
    }

    /// Find the region containing a key (from cache).
    ///
    /// Regions cover half-open ranges `[start_key, end_key)`; an empty
    /// `end_key` means the range is unbounded on the right.
    pub async fn get_region_for_key(&self, key: &[u8]) -> Option<RegionInfo> {
        let regions = self.list_regions().await;
        for region in regions {
            let start_ok = key >= region.start_key.as_slice();
            let end_ok = region.end_key.is_empty() || key < region.end_key.as_slice();
            if start_ok && end_ok {
                return Some(region);
            }
        }
        None
    }

    /// Get the leader address for a region (from cache)
    pub async fn get_leader_addr(&self, region_id: u64) -> Option<String> {
        let cache = self.cache.read().await;
        if let Some(region) = cache.regions.get(&region_id) {
            if let Some(store) = cache.stores.get(&region.leader_id) {
                return Some(store.addr.clone());
            }
        }
        None
    }

    /// Report that this store is the leader for a region.
    ///
    /// Note: this is a non-transactional read-modify-write, so concurrent
    /// reporters for the same region can race; the last write wins.
    pub async fn report_leader(
        &mut self,
        region_id: u64,
        leader_id: u64,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        if let Some(mut region) = self.get_region(region_id).await? {
            region.leader_id = leader_id;
            self.put_region(region).await?;
        }
        Ok(())
    }

    /// Initialize the default region if none exist
    pub async fn init_default_region(
        &mut self,
        peers: Vec<u64>,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let regions = self.list_regions().await;
        if regions.is_empty() {
            // Region 1 spans the whole keyspace: empty start and end keys
            let region = RegionInfo {
                id: 1,
                start_key: Vec::new(),
                end_key: Vec::new(),
                peers,
                leader_id: 0,
            };
            self.put_region(region).await?;
        }
        Ok(())
    }

    /// Get a region from cache by ID
    pub async fn get_region_cached(&self, region_id: u64) -> Option<RegionInfo> {
        let cache = self.cache.read().await;
        cache.regions.get(&region_id).cloned()
    }

    /// Get a store from cache by ID
    pub async fn get_store_cached(&self, store_id: u64) -> Option<StoreInfo> {
        let cache = self.cache.read().await;
        cache.stores.get(&store_id).cloned()
    }
}

// ============================================================================
// Helper functions
// ============================================================================

/// Compute the range end for a prefix (etcd-style prefix matching): the
/// prefix with its last byte incremented, so a range scan of
/// `[prefix, range_end)` returns exactly the keys under the prefix.
fn prefix_range_end(prefix: &str) -> Vec<u8> {
    let mut end = prefix.as_bytes().to_vec();
    // Drop trailing 0xff bytes, then increment the last remaining byte,
    // mirroring etcd's prefix range-end computation. An empty result means
    // "to the end of the keyspace"; the prefixes used here all end in '/'.
    while let Some(last) = end.last_mut() {
        if *last < 0xff {
            *last += 1;
            return end;
        }
        end.pop();
    }
    end
}

/// Process a watch event and update the cache
async fn process_watch_event(
    event: &Event,
    cache: &Arc<RwLock<MetadataCache>>,
    event_tx: &broadcast::Sender<PdEvent>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let kv = match &event.kv {
        Some(kv) => kv,
        None => return Ok(()),
    };
    let key_str = String::from_utf8_lossy(&kv.key);
    let is_delete = event.r#type == 1; // DELETE = 1
    if key_str.starts_with(PREFIX_STORES) {
        let store_id_str = key_str.strip_prefix(PREFIX_STORES).unwrap_or("");
        if let Ok(store_id) = store_id_str.parse::<u64>() {
            let mut cache = cache.write().await;
            if is_delete {
                cache.stores.remove(&store_id);
                let _ = event_tx.send(PdEvent::StoreRemoved(store_id));
            } else if let Ok(info) = serde_json::from_slice::<StoreInfo>(&kv.value) {
                cache.stores.insert(store_id, info.clone());
                let _ = event_tx.send(PdEvent::StoreUpdated(info));
            }
        }
    } else if key_str.starts_with(PREFIX_REGIONS) {
        let region_id_str = key_str.strip_prefix(PREFIX_REGIONS).unwrap_or("");
        if let Ok(region_id) = region_id_str.parse::<u64>() {
            let mut cache = cache.write().await;
            if is_delete {
                cache.regions.remove(&region_id);
                let _ = event_tx.send(PdEvent::RegionRemoved(region_id));
            } else if let Ok(info) = serde_json::from_slice::<RegionInfo>(&kv.value) {
                cache.regions.insert(region_id, info.clone());
                let _ = event_tx.send(PdEvent::RegionUpdated(info));
            }
        }
    }
    Ok(())
}
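
// A small sanity check for the prefix arithmetic above; these tests exercise
// only the pure helper and assume nothing beyond this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn prefix_range_end_increments_last_byte() {
        // '/' is 0x2f, so the range end replaces the trailing '/' with '0'.
        assert_eq!(prefix_range_end(PREFIX_STORES), b"/flaredb/stores0".to_vec());
        assert_eq!(prefix_range_end(PREFIX_REGIONS), b"/flaredb/regions0".to_vec());
    }
}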