//! gRPC client for Raft RPC //! //! This module provides a gRPC-based implementation of RaftRpcClient //! for node-to-node Raft communication with retry and backoff support. use crate::internal_proto::{ raft_service_client::RaftServiceClient, AppendEntriesRequest as ProtoAppendEntriesRequest, LogEntry as ProtoLogEntry, VoteRequest as ProtoVoteRequest, }; use chainfire_raft::network::{RaftNetworkError, RaftRpcClient}; use chainfire_types::NodeId; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use tokio::sync::RwLock; use tonic::transport::Channel; use tracing::{debug, trace, warn}; // Custom Raft imports use chainfire_raft::core::{ AppendEntriesRequest, AppendEntriesResponse, VoteRequest, VoteResponse, }; /// Configuration for RPC retry behavior with exponential backoff. #[derive(Debug, Clone)] pub struct RetryConfig { /// Initial timeout for RPC calls (default: 500ms) pub initial_timeout: Duration, /// Maximum timeout after backoff (default: 30s) pub max_timeout: Duration, /// Maximum number of retry attempts (default: 3) pub max_retries: u32, /// Backoff multiplier between retries (default: 2.0) pub backoff_multiplier: f64, } impl Default for RetryConfig { fn default() -> Self { Self { initial_timeout: Duration::from_millis(500), max_timeout: Duration::from_secs(30), max_retries: 3, backoff_multiplier: 2.0, } } } impl RetryConfig { /// Create a new RetryConfig with custom values pub fn new( initial_timeout: Duration, max_timeout: Duration, max_retries: u32, backoff_multiplier: f64, ) -> Self { Self { initial_timeout, max_timeout, max_retries, backoff_multiplier, } } /// Calculate timeout for a given retry attempt (0-indexed) fn timeout_for_attempt(&self, attempt: u32) -> Duration { let multiplier = self.backoff_multiplier.powi(attempt as i32); let timeout_millis = (self.initial_timeout.as_millis() as f64 * multiplier) as u64; let timeout = Duration::from_millis(timeout_millis); timeout.min(self.max_timeout) } } /// gRPC-based Raft RPC client with retry support pub struct GrpcRaftClient { /// Cached gRPC clients per node clients: Arc>>>, /// Node address mapping node_addrs: Arc>>, /// Retry configuration retry_config: RetryConfig, } impl GrpcRaftClient { /// Create a new gRPC Raft client with default retry config pub fn new() -> Self { Self { clients: Arc::new(RwLock::new(HashMap::new())), node_addrs: Arc::new(RwLock::new(HashMap::new())), retry_config: RetryConfig::default(), } } /// Create a new gRPC Raft client with custom retry config pub fn new_with_retry(retry_config: RetryConfig) -> Self { Self { clients: Arc::new(RwLock::new(HashMap::new())), node_addrs: Arc::new(RwLock::new(HashMap::new())), retry_config, } } /// Add or update a node's address pub async fn add_node(&self, id: NodeId, addr: String) { debug!(node_id = id, addr = %addr, "Adding node address"); self.node_addrs.write().await.insert(id, addr); } /// Remove a node pub async fn remove_node(&self, id: NodeId) { self.node_addrs.write().await.remove(&id); self.clients.write().await.remove(&id); } /// Get or create a gRPC client for the target node async fn get_client(&self, target: NodeId) -> Result, RaftNetworkError> { // Check cache first { let clients = self.clients.read().await; if let Some(client) = clients.get(&target) { return Ok(client.clone()); } } // Get address let addr = { let addrs = self.node_addrs.read().await; addrs.get(&target).cloned() }; let addr = addr.ok_or(RaftNetworkError::NodeNotFound(target))?; // Create new connection let endpoint = format!("http://{}", addr); trace!(target = target, endpoint = %endpoint, "Connecting to node"); let channel = Channel::from_shared(endpoint.clone()) .map_err(|e| RaftNetworkError::ConnectionFailed { node_id: target, reason: e.to_string(), })? .connect() .await .map_err(|e| RaftNetworkError::ConnectionFailed { node_id: target, reason: e.to_string(), })?; let client = RaftServiceClient::new(channel); // Cache the client self.clients.write().await.insert(target, client.clone()); Ok(client) } /// Invalidate cached client for a node (e.g., on connection failure) async fn invalidate_client(&self, target: NodeId) { self.clients.write().await.remove(&target); } /// Execute an async operation with retry and exponential backoff async fn with_retry( &self, target: NodeId, rpc_name: &str, mut operation: F, ) -> Result where F: FnMut() -> Fut, Fut: std::future::Future>, { let mut last_error = None; for attempt in 0..=self.retry_config.max_retries { let timeout = self.retry_config.timeout_for_attempt(attempt); trace!( target = target, rpc = rpc_name, attempt = attempt, timeout_ms = timeout.as_millis(), "Attempting RPC" ); match tokio::time::timeout(timeout, operation()).await { Ok(Ok(result)) => return Ok(result), Ok(Err(e)) => { warn!( target = target, rpc = rpc_name, attempt = attempt, error = %e, "RPC failed" ); // Invalidate cached client on failure self.invalidate_client(target).await; last_error = Some(e); } Err(_) => { warn!( target = target, rpc = rpc_name, attempt = attempt, timeout_ms = timeout.as_millis(), "RPC timed out" ); // Invalidate cached client on timeout self.invalidate_client(target).await; last_error = Some(RaftNetworkError::RpcFailed(format!( "{} timed out after {}ms", rpc_name, timeout.as_millis() ))); } } // Wait before retry (backoff delay) if attempt < self.retry_config.max_retries { let backoff_delay = self.retry_config.timeout_for_attempt(attempt); tokio::time::sleep(backoff_delay).await; } } Err(last_error.unwrap_or_else(|| { RaftNetworkError::RpcFailed(format!( "{} failed after {} retries", rpc_name, self.retry_config.max_retries )) })) } } impl Default for GrpcRaftClient { fn default() -> Self { Self::new() } } #[async_trait::async_trait] impl RaftRpcClient for GrpcRaftClient { async fn vote( &self, target: NodeId, req: VoteRequest, ) -> Result { trace!(target = target, term = req.term, "Sending vote request"); self.with_retry(target, "vote", || async { let mut client = self.get_client(target).await?; // Convert to proto request let proto_req = ProtoVoteRequest { term: req.term, candidate_id: req.candidate_id, last_log_index: req.last_log_index, last_log_term: req.last_log_term, }; let response = client .vote(proto_req) .await .map_err(|e| RaftNetworkError::RpcFailed(e.to_string()))?; let resp = response.into_inner(); Ok(VoteResponse { term: resp.term, vote_granted: resp.vote_granted, }) }) .await } async fn append_entries( &self, target: NodeId, req: AppendEntriesRequest, ) -> Result { trace!( target = target, entries = req.entries.len(), "Sending append entries" ); // Clone entries once for potential retries let entries_data: Vec<(u64, u64, Vec)> = req .entries .iter() .map(|e| { use chainfire_storage::EntryPayload; let data = match &e.payload { EntryPayload::Blank => vec![], EntryPayload::Normal(cmd) => { bincode::serialize(cmd).unwrap_or_default() } EntryPayload::Membership(_) => vec![], }; (e.log_id.index, e.log_id.term, data) }) .collect(); let term = req.term; let leader_id = req.leader_id; let prev_log_index = req.prev_log_index; let prev_log_term = req.prev_log_term; let leader_commit = req.leader_commit; self.with_retry(target, "append_entries", || { let entries_data = entries_data.clone(); async move { let mut client = self.get_client(target).await?; let entries: Vec = entries_data .into_iter() .map(|(index, term, data)| ProtoLogEntry { index, term, data }) .collect(); let proto_req = ProtoAppendEntriesRequest { term, leader_id, prev_log_index, prev_log_term, entries, leader_commit, }; let response = client .append_entries(proto_req) .await .map_err(|e| RaftNetworkError::RpcFailed(e.to_string()))?; let resp = response.into_inner(); Ok(AppendEntriesResponse { term: resp.term, success: resp.success, conflict_index: if resp.conflict_index > 0 { Some(resp.conflict_index) } else { None }, conflict_term: if resp.conflict_term > 0 { Some(resp.conflict_term) } else { None }, }) } }) .await } }