//! Network implementation for Raft RPC //! //! This module provides network adapters for OpenRaft to communicate between nodes. use crate::config::TypeConfig; use chainfire_types::NodeId; use openraft::error::{InstallSnapshotError, NetworkError, RaftError, RPCError, StreamingError, Fatal}; use openraft::network::{RPCOption, RaftNetwork, RaftNetworkFactory}; use openraft::raft::{ AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest, InstallSnapshotResponse, SnapshotResponse, VoteRequest, VoteResponse, }; use openraft::BasicNode; use std::collections::HashMap; use std::sync::Arc; use thiserror::Error; use tokio::sync::RwLock; use tracing::{debug, trace}; /// Network error type #[derive(Error, Debug)] pub enum RaftNetworkError { #[error("Connection failed to node {node_id}: {reason}")] ConnectionFailed { node_id: NodeId, reason: String }, #[error("RPC failed: {0}")] RpcFailed(String), #[error("Timeout")] Timeout, #[error("Node {0} not found")] NodeNotFound(NodeId), } /// Trait for sending Raft RPCs /// This will be implemented by the gRPC client in chainfire-api #[async_trait::async_trait] pub trait RaftRpcClient: Send + Sync + 'static { async fn vote( &self, target: NodeId, req: VoteRequest, ) -> Result, RaftNetworkError>; async fn append_entries( &self, target: NodeId, req: AppendEntriesRequest, ) -> Result, RaftNetworkError>; async fn install_snapshot( &self, target: NodeId, req: InstallSnapshotRequest, ) -> Result, RaftNetworkError>; } /// Factory for creating network connections to Raft peers pub struct NetworkFactory { /// RPC client for sending requests client: Arc, /// Node address mapping nodes: Arc>>, } impl NetworkFactory { /// Create a new network factory pub fn new(client: Arc) -> Self { Self { client, nodes: Arc::new(RwLock::new(HashMap::new())), } } /// Add or update a node's address pub async fn add_node(&self, id: NodeId, node: BasicNode) { let mut nodes = self.nodes.write().await; nodes.insert(id, node); } /// Remove a node pub async fn remove_node(&self, id: NodeId) { let mut nodes = self.nodes.write().await; nodes.remove(&id); } } impl RaftNetworkFactory for NetworkFactory { type Network = NetworkConnection; async fn new_client(&mut self, target: NodeId, node: &BasicNode) -> Self::Network { // Update our node map self.nodes.write().await.insert(target, node.clone()); NetworkConnection { target, node: node.clone(), client: Arc::clone(&self.client), } } } /// A connection to a single Raft peer pub struct NetworkConnection { target: NodeId, node: BasicNode, client: Arc, } /// Convert our network error to OpenRaft's RPCError fn to_rpc_error(e: RaftNetworkError) -> RPCError> { RPCError::Network(NetworkError::new(&e)) } /// Convert our network error to OpenRaft's RPCError with InstallSnapshotError fn to_snapshot_rpc_error(e: RaftNetworkError) -> RPCError> { RPCError::Network(NetworkError::new(&e)) } impl RaftNetwork for NetworkConnection { async fn vote( &mut self, req: VoteRequest, _option: RPCOption, ) -> Result< VoteResponse, RPCError>, > { trace!(target = self.target, "Sending vote request"); self.client .vote(self.target, req) .await .map_err(to_rpc_error) } async fn append_entries( &mut self, req: AppendEntriesRequest, _option: RPCOption, ) -> Result< AppendEntriesResponse, RPCError>, > { trace!( target = self.target, entries = req.entries.len(), "Sending append entries" ); self.client .append_entries(self.target, req) .await .map_err(to_rpc_error) } async fn install_snapshot( &mut self, req: InstallSnapshotRequest, _option: RPCOption, ) -> Result< InstallSnapshotResponse, RPCError>, > { debug!( target = self.target, last_log_id = ?req.meta.last_log_id, "Sending install snapshot" ); self.client .install_snapshot(self.target, req) .await .map_err(to_snapshot_rpc_error) } async fn full_snapshot( &mut self, vote: openraft::Vote, snapshot: openraft::Snapshot, _cancel: impl std::future::Future + Send + 'static, _option: RPCOption, ) -> Result< SnapshotResponse, StreamingError>, > { // For simplicity, send snapshot in one chunk // In production, you'd want to chunk large snapshots let req = InstallSnapshotRequest { vote, meta: snapshot.meta.clone(), offset: 0, data: snapshot.snapshot.into_inner(), done: true, }; debug!( target = self.target, last_log_id = ?snapshot.meta.last_log_id, "Sending full snapshot" ); let resp = self .client .install_snapshot(self.target, req) .await .map_err(|e| StreamingError::Network(NetworkError::new(&e)))?; Ok(SnapshotResponse { vote: resp.vote }) } } /// In-memory RPC client for testing #[cfg(test)] pub mod test_client { use super::*; use std::collections::HashMap; use tokio::sync::mpsc; /// A simple in-memory RPC client for testing pub struct InMemoryRpcClient { /// Channel senders to each node channels: Arc>>>, } pub enum RpcMessage { Vote( VoteRequest, tokio::sync::oneshot::Sender>, ), AppendEntries( AppendEntriesRequest, tokio::sync::oneshot::Sender>, ), InstallSnapshot( InstallSnapshotRequest, tokio::sync::oneshot::Sender>, ), } impl InMemoryRpcClient { pub fn new() -> Self { Self { channels: Arc::new(RwLock::new(HashMap::new())), } } pub async fn register(&self, id: NodeId, tx: mpsc::Sender) { self.channels.write().await.insert(id, tx); } } #[async_trait::async_trait] impl RaftRpcClient for InMemoryRpcClient { async fn vote( &self, target: NodeId, req: VoteRequest, ) -> Result, RaftNetworkError> { let channels = self.channels.read().await; let tx = channels .get(&target) .ok_or(RaftNetworkError::NodeNotFound(target))?; let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); tx.send(RpcMessage::Vote(req, resp_tx)) .await .map_err(|_| RaftNetworkError::RpcFailed("Channel closed".into()))?; resp_rx .await .map_err(|_| RaftNetworkError::RpcFailed("Response channel closed".into())) } async fn append_entries( &self, target: NodeId, req: AppendEntriesRequest, ) -> Result, RaftNetworkError> { let channels = self.channels.read().await; let tx = channels .get(&target) .ok_or(RaftNetworkError::NodeNotFound(target))?; let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); tx.send(RpcMessage::AppendEntries(req, resp_tx)) .await .map_err(|_| RaftNetworkError::RpcFailed("Channel closed".into()))?; resp_rx .await .map_err(|_| RaftNetworkError::RpcFailed("Response channel closed".into())) } async fn install_snapshot( &self, target: NodeId, req: InstallSnapshotRequest, ) -> Result, RaftNetworkError> { let channels = self.channels.read().await; let tx = channels .get(&target) .ok_or(RaftNetworkError::NodeNotFound(target))?; let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); tx.send(RpcMessage::InstallSnapshot(req, resp_tx)) .await .map_err(|_| RaftNetworkError::RpcFailed("Channel closed".into()))?; resp_rx .await .map_err(|_| RaftNetworkError::RpcFailed("Response channel closed".into())) } } }