use flaredb_proto::chainfire::kv_client::KvClient as ChainfireKvClient; use flaredb_proto::chainfire::RangeRequest as ChainfireRangeRequest; use flaredb_proto::kvrpc::kv_cas_client::KvCasClient; use flaredb_proto::kvrpc::kv_raw_client::KvRawClient; use flaredb_proto::kvrpc::{ CasRequest, DeleteRequest, GetRequest, RawDeleteRequest, RawGetRequest, RawPutRequest, RawScanRequest, }; use flaredb_proto::pdpb::Store; use std::collections::HashMap; use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; use serde::Deserialize; use tokio::sync::Mutex; use tonic::transport::Channel; use flaredb_proto::pdpb::pd_client::PdClient; use flaredb_proto::pdpb::tso_client::TsoClient; use flaredb_proto::pdpb::{GetRegionRequest, Region, TsoRequest}; use std::future::Future; use std::time::Duration; use crate::region_cache::RegionCache; pub struct RdbClient { // We need a map of addr -> Channel/Client to reuse connections // Or just create on fly for MVP? Connection creation is expensive. // Let's cache channels. channels: Arc>>, direct_addr: Option, // Clients for placement and routing metadata. tso_client: Option>, pd_client: Option>, chainfire_kv_client: Option>, region_cache: RegionCache, namespace: String, } #[derive(Debug, Clone, Deserialize)] struct ChainfireStoreInfo { id: u64, addr: String, } #[derive(Debug, Clone, Deserialize)] struct ChainfireRegionInfo { id: u64, start_key: Vec, end_key: Vec, peers: Vec, leader_id: u64, } impl RdbClient { const ROUTE_RETRY_LIMIT: usize = 12; const ROUTE_RETRY_BASE_DELAY_MS: u64 = 100; const ROUTED_RPC_TIMEOUT: Duration = Duration::from_secs(1); pub async fn connect_with_pd( _server_addr: String, pd_addr: String, ) -> Result { Self::connect_with_pd_namespace(_server_addr, pd_addr, String::new()).await } pub async fn connect_with_pd_namespace( server_addr: String, pd_addr: String, namespace: impl Into, ) -> Result { // A number of in-repo callers still pass the same address for both server and PD. // In that case, prefer direct routing and skip the PD lookup path entirely. let direct_addr = if !server_addr.is_empty() && server_addr == pd_addr { Some(server_addr) } else { None }; let (tso_client, pd_client, chainfire_kv_client) = if direct_addr.is_some() { (None, None, None) } else { let pd_channel = Channel::from_shared(transport_endpoint(&pd_addr)) .unwrap() .connect() .await?; let mut probe_client = PdClient::new(pd_channel.clone()); let probe = probe_client .get_region(GetRegionRequest { key: Vec::new() }) .await; match probe { Err(status) if status.code() == tonic::Code::Unimplemented => ( None, None, Some(ChainfireKvClient::new(pd_channel)), ), _ => ( Some(TsoClient::new(pd_channel.clone())), Some(PdClient::new(pd_channel)), None, ), } }; Ok(Self { channels: Arc::new(Mutex::new(HashMap::new())), direct_addr, tso_client, pd_client, chainfire_kv_client, region_cache: RegionCache::new(), namespace: namespace.into(), }) } /// Connect directly to a single FlareDB server without PD/region lookup. pub async fn connect_direct( server_addr: String, namespace: impl Into, ) -> Result { let ep = transport_endpoint(&server_addr); let channel = Channel::from_shared(ep).unwrap().connect().await?; Ok(Self { channels: Arc::new(Mutex::new(HashMap::new())), direct_addr: Some(server_addr), tso_client: Some(TsoClient::new(channel.clone())), pd_client: Some(PdClient::new(channel)), chainfire_kv_client: None, region_cache: RegionCache::new(), namespace: namespace.into(), }) } async fn resolve_addr(&self, key: &[u8]) -> Result { if let Some(addr) = &self.direct_addr { return Ok(addr.clone()); } if let Some(addr) = self.region_cache.get_store_addr(key).await { return Ok(addr); } if let Some(chainfire_kv_client) = &self.chainfire_kv_client { return self.resolve_addr_via_chainfire(key, chainfire_kv_client.clone()).await; } if let Some(pd_client) = &self.pd_client { let mut pd_c = pd_client.clone(); let req = GetRegionRequest { key: key.to_vec() }; let resp = pd_c.get_region(req).await?.into_inner(); if let (Some(region), Some(leader)) = (resp.region, resp.leader) { self.region_cache.update(region, leader.clone()).await; return Ok(leader.addr); } } Err(tonic::Status::not_found("region not found")) } async fn resolve_addr_uncached(&self, key: &[u8]) -> Result { if let Some(addr) = &self.direct_addr { return Ok(addr.clone()); } self.region_cache.clear().await; if let Some(chainfire_kv_client) = &self.chainfire_kv_client { return self.resolve_addr_via_chainfire(key, chainfire_kv_client.clone()).await; } if let Some(pd_client) = &self.pd_client { let mut pd_c = pd_client.clone(); let req = GetRegionRequest { key: key.to_vec() }; let resp = pd_c.get_region(req).await?.into_inner(); if let (Some(region), Some(leader)) = (resp.region, resp.leader) { self.region_cache.update(region, leader.clone()).await; return Ok(leader.addr); } } Err(tonic::Status::not_found("region not found")) } async fn get_channel(&self, addr: &str) -> Result { Self::get_channel_from_map(&self.channels, addr).await } async fn get_channel_from_map( channels: &Arc>>, addr: &str, ) -> Result { let mut map = channels.lock().await; if let Some(chan) = map.get(addr) { return Ok(chan.clone()); } let ep = transport_endpoint(addr); let chan = Channel::from_shared(ep).unwrap().connect().await?; map.insert(addr.to_string(), chan.clone()); Ok(chan) } async fn evict_channel_from_map(channels: &Arc>>, addr: &str) { let mut map = channels.lock().await; map.remove(addr); } async fn with_routed_addr(&self, key: &[u8], mut op: F) -> Result where F: FnMut(String) -> Fut, Fut: Future>, { let mut addr = self.resolve_addr(key).await?; let mut refreshed = false; let mut last_status = None; for attempt in 0..Self::ROUTE_RETRY_LIMIT { match tokio::time::timeout(Self::ROUTED_RPC_TIMEOUT, op(addr.clone())).await { Err(_) => { Self::evict_channel_from_map(&self.channels, &addr).await; let status = tonic::Status::unavailable(format!( "transport error: routed request to {} timed out after {}ms", addr, Self::ROUTED_RPC_TIMEOUT.as_millis() )); if !refreshed && self.direct_addr.is_none() { refreshed = true; if let Ok(fresh_addr) = self.resolve_addr_uncached(key).await { addr = fresh_addr; last_status = Some(status); tokio::time::sleep(Self::retry_delay(attempt)).await; continue; } } last_status = Some(status); tokio::time::sleep(Self::retry_delay(attempt)).await; continue; } Ok(Err(status)) => { let transport_error = Self::is_transport_error(&status); if transport_error { Self::evict_channel_from_map(&self.channels, &addr).await; } if let Some(redirect_addr) = Self::extract_forward_addr(status.message()) { self.region_cache .override_store_addr(key, redirect_addr.clone()) .await; if redirect_addr != addr { addr = redirect_addr; last_status = Some(status); tokio::time::sleep(Self::retry_delay(attempt)).await; continue; } } if !refreshed && self.direct_addr.is_none() && Self::is_retryable_route_error(&status) { refreshed = true; if let Ok(fresh_addr) = self.resolve_addr_uncached(key).await { addr = fresh_addr; last_status = Some(status); tokio::time::sleep(Self::retry_delay(attempt)).await; continue; } } if transport_error { last_status = Some(status); tokio::time::sleep(Self::retry_delay(attempt)).await; continue; } if Self::is_retryable_route_error(&status) { last_status = Some(status); tokio::time::sleep(Self::retry_delay(attempt)).await; continue; } return Err(status); } Ok(Ok(value)) => return Ok(value), } } Err(last_status.unwrap_or_else(|| tonic::Status::internal("routing retry exhausted"))) } fn is_retryable_route_error(status: &tonic::Status) -> bool { if !matches!( status.code(), tonic::Code::FailedPrecondition | tonic::Code::Unavailable | tonic::Code::Internal | tonic::Code::Unknown ) { return false; } let message = status.message(); message.contains("forward request") || message.contains("redirect required") || message.contains("Linearizable read failed") || message.contains("not leader") || message.contains("NotLeader") || message.contains("leader_id: None") || message.contains("transport error") } fn retry_delay(attempt: usize) -> Duration { Duration::from_millis( Self::ROUTE_RETRY_BASE_DELAY_MS.saturating_mul((attempt as u64) + 1), ) } fn is_transport_error(status: &tonic::Status) -> bool { matches!( status.code(), tonic::Code::Unavailable | tonic::Code::Internal | tonic::Code::Unknown ) && status.message().contains("transport error") } fn extract_forward_addr(message: &str) -> Option { const ADDR_MARKER: &str = "addr: \""; let start = message.find(ADDR_MARKER)? + ADDR_MARKER.len(); let end = message[start..].find('"')?; Some(message[start..start + end].to_string()) } pub async fn get_tso(&mut self) -> Result { if self.chainfire_kv_client.is_some() { return Ok(SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap_or_default() .as_millis() as u64); } let Some(tso_client) = &mut self.tso_client else { return Err(tonic::Status::failed_precondition( "timestamp oracle unavailable in direct mode", )); }; let req = TsoRequest { count: 1 }; let resp = tso_client.get_timestamp(req).await?.into_inner(); Ok(resp.timestamp) } pub async fn raw_put(&mut self, key: Vec, value: Vec) -> Result<(), tonic::Status> { let route_key = key.clone(); let channels = Arc::clone(&self.channels); let namespace = self.namespace.clone(); self.with_routed_addr(&route_key, |addr| { let channels = Arc::clone(&channels); let key = key.clone(); let value = value.clone(); let namespace = namespace.clone(); async move { let channel = Self::get_channel_from_map(&channels, &addr) .await .map_err(|e| tonic::Status::internal(e.to_string()))?; let mut client = KvRawClient::new(channel); let req = RawPutRequest { key, value, namespace, }; client.raw_put(req).await?; Ok(()) } }) .await } pub async fn raw_get(&mut self, key: Vec) -> Result>, tonic::Status> { let route_key = key.clone(); let channels = Arc::clone(&self.channels); let namespace = self.namespace.clone(); self.with_routed_addr(&route_key, |addr| { let channels = Arc::clone(&channels); let key = key.clone(); let namespace = namespace.clone(); async move { let channel = Self::get_channel_from_map(&channels, &addr) .await .map_err(|e| tonic::Status::internal(e.to_string()))?; let mut client = KvRawClient::new(channel); let req = RawGetRequest { key, namespace }; let resp = client.raw_get(req).await?.into_inner(); if resp.found { Ok(Some(resp.value)) } else { Ok(None) } } }) .await } pub async fn raw_delete(&mut self, key: Vec) -> Result { let route_key = key.clone(); let channels = Arc::clone(&self.channels); let namespace = self.namespace.clone(); self.with_routed_addr(&route_key, |addr| { let channels = Arc::clone(&channels); let key = key.clone(); let namespace = namespace.clone(); async move { let channel = Self::get_channel_from_map(&channels, &addr) .await .map_err(|e| tonic::Status::internal(e.to_string()))?; let mut client = KvRawClient::new(channel); let req = RawDeleteRequest { key, namespace }; let resp = client.raw_delete(req).await?.into_inner(); Ok(resp.existed) } }) .await } /// Scan a range of keys in raw (eventual consistency) mode. /// /// Returns (keys, values, next_key if has_more). pub async fn raw_scan( &mut self, start_key: Vec, end_key: Vec, limit: u32, ) -> Result<(Vec>, Vec>, Option>), tonic::Status> { let route_key = start_key.clone(); let channels = Arc::clone(&self.channels); let namespace = self.namespace.clone(); self.with_routed_addr(&route_key, |addr| { let channels = Arc::clone(&channels); let start_key = start_key.clone(); let end_key = end_key.clone(); let namespace = namespace.clone(); async move { let channel = Self::get_channel_from_map(&channels, &addr) .await .map_err(|e| tonic::Status::internal(e.to_string()))?; let mut client = KvRawClient::new(channel); let req = RawScanRequest { start_key, end_key, limit, namespace, }; let resp = client.raw_scan(req).await?.into_inner(); let next = if resp.has_more { Some(resp.next_key) } else { None }; Ok((resp.keys, resp.values, next)) } }) .await } pub async fn cas( &mut self, key: Vec, value: Vec, expected_version: u64, ) -> Result<(bool, u64, u64), tonic::Status> { let route_key = key.clone(); let channels = Arc::clone(&self.channels); let namespace = self.namespace.clone(); self.with_routed_addr(&route_key, |addr| { let channels = Arc::clone(&channels); let key = key.clone(); let value = value.clone(); let namespace = namespace.clone(); async move { let channel = Self::get_channel_from_map(&channels, &addr) .await .map_err(|e| tonic::Status::internal(e.to_string()))?; let mut client = KvCasClient::new(channel); let req = CasRequest { key, value, expected_version, namespace, }; let resp = client.compare_and_swap(req).await?.into_inner(); Ok((resp.success, resp.current_version, resp.new_version)) } }) .await } pub async fn cas_get(&mut self, key: Vec) -> Result)>, tonic::Status> { let route_key = key.clone(); let channels = Arc::clone(&self.channels); let namespace = self.namespace.clone(); self.with_routed_addr(&route_key, |addr| { let channels = Arc::clone(&channels); let key = key.clone(); let namespace = namespace.clone(); async move { let channel = Self::get_channel_from_map(&channels, &addr) .await .map_err(|e| tonic::Status::internal(e.to_string()))?; let mut client = KvCasClient::new(channel); let req = GetRequest { key, namespace }; let resp = client.get(req).await?.into_inner(); if resp.found { Ok(Some((resp.version, resp.value))) } else { Ok(None) } } }) .await } pub async fn cas_scan( &mut self, start_key: Vec, end_key: Vec, limit: u32, ) -> Result<(Vec<(Vec, Vec, u64)>, Option>), tonic::Status> { let route_key = start_key.clone(); let channels = Arc::clone(&self.channels); let namespace = self.namespace.clone(); self.with_routed_addr(&route_key, |addr| { let channels = Arc::clone(&channels); let start_key = start_key.clone(); let end_key = end_key.clone(); let namespace = namespace.clone(); async move { let channel = Self::get_channel_from_map(&channels, &addr) .await .map_err(|e| tonic::Status::internal(e.to_string()))?; let mut client = KvCasClient::new(channel); let req = flaredb_proto::kvrpc::ScanRequest { start_key, end_key, limit, namespace, }; let resp = client.scan(req).await?.into_inner(); let entries: Vec<(Vec, Vec, u64)> = resp .entries .into_iter() .map(|kv| (kv.key, kv.value, kv.version)) .collect(); let next = if resp.has_more { Some(resp.next_key) } else { None }; Ok((entries, next)) } }) .await } pub async fn cas_delete( &mut self, key: Vec, expected_version: u64, ) -> Result<(bool, u64, bool), tonic::Status> { let route_key = key.clone(); let channels = Arc::clone(&self.channels); let namespace = self.namespace.clone(); self.with_routed_addr(&route_key, |addr| { let channels = Arc::clone(&channels); let key = key.clone(); let namespace = namespace.clone(); async move { let channel = Self::get_channel_from_map(&channels, &addr) .await .map_err(|e| tonic::Status::internal(e.to_string()))?; let mut client = KvCasClient::new(channel); let req = DeleteRequest { key, expected_version, namespace, }; let resp = client.delete(req).await?.into_inner(); Ok((resp.success, resp.current_version, resp.existed)) } }) .await } async fn resolve_addr_via_chainfire( &self, key: &[u8], mut kv_client: ChainfireKvClient, ) -> Result { let regions = list_chainfire_regions(&mut kv_client).await?; let stores = list_chainfire_stores(&mut kv_client).await?; let region = regions .into_iter() .find(|region| { let start_ok = region.start_key.is_empty() || key >= region.start_key.as_slice(); let end_ok = region.end_key.is_empty() || key < region.end_key.as_slice(); start_ok && end_ok }) .ok_or_else(|| tonic::Status::not_found("region not found"))?; let leader = stores .get(®ion.leader_id) .ok_or_else(|| tonic::Status::not_found("leader store not found"))?; self.region_cache .update( Region { id: region.id, start_key: region.start_key, end_key: region.end_key, peers: region.peers, leader_id: region.leader_id, }, Store { id: leader.id, addr: leader.addr.clone(), }, ) .await; Ok(leader.addr.clone()) } } fn transport_endpoint(addr: &str) -> String { if addr.starts_with("http://") || addr.starts_with("https://") { addr.to_string() } else { format!("http://{}", addr) } } fn prefix_range_end(prefix: &str) -> Vec { let mut end = prefix.as_bytes().to_vec(); if let Some(last) = end.last_mut() { *last = last.saturating_add(1); } end } async fn list_chainfire_stores( kv_client: &mut ChainfireKvClient, ) -> Result, tonic::Status> { const PREFIX: &str = "/flaredb/stores/"; let response = kv_client .range(ChainfireRangeRequest { key: PREFIX.as_bytes().to_vec(), range_end: prefix_range_end(PREFIX), limit: 0, revision: 0, keys_only: false, count_only: false, }) .await? .into_inner(); let mut stores = HashMap::new(); for kv in response.kvs { if let Ok(store) = serde_json::from_slice::(&kv.value) { stores.insert(store.id, store); } } Ok(stores) } async fn list_chainfire_regions( kv_client: &mut ChainfireKvClient, ) -> Result, tonic::Status> { const PREFIX: &str = "/flaredb/regions/"; let response = kv_client .range(ChainfireRangeRequest { key: PREFIX.as_bytes().to_vec(), range_end: prefix_range_end(PREFIX), limit: 0, revision: 0, keys_only: false, count_only: false, }) .await? .into_inner(); let mut regions = Vec::new(); for kv in response.kvs { if let Ok(region) = serde_json::from_slice::(&kv.value) { regions.push(region); } } Ok(regions) } #[cfg(test)] mod tests { use super::RdbClient; #[test] fn unknown_transport_errors_are_treated_as_retryable_routes() { let status = tonic::Status::unknown("transport error"); assert!(RdbClient::is_retryable_route_error(&status)); assert!(RdbClient::is_transport_error(&status)); } #[test] fn not_leader_errors_remain_retryable() { let status = tonic::Status::failed_precondition("NotLeader { leader_id: Some(1) }"); assert!(RdbClient::is_retryable_route_error(&status)); assert!(!RdbClient::is_transport_error(&status)); } }