//! Raft state machine implementation //! //! The state machine applies committed Raft log entries to the KV store. use crate::{KvStore, LeaseStore, RocksStore}; use chainfire_types::command::{Compare, CompareResult, CompareTarget, RaftCommand, RaftResponse}; use chainfire_types::error::StorageError; use chainfire_types::watch::WatchEvent; use chainfire_types::Revision; use std::sync::Arc; use tokio::sync::mpsc; use tracing::warn; /// State machine that applies Raft commands to the KV store pub struct StateMachine { /// Underlying KV store kv: KvStore, /// Lease store for TTL management leases: Arc, /// Channel to send watch events watch_tx: Option>, } impl StateMachine { /// Create a new state machine pub fn new(store: RocksStore) -> Result { let kv = KvStore::new(store)?; Ok(Self { kv, leases: Arc::new(LeaseStore::new()), watch_tx: None, }) } /// Set the watch event sender pub fn set_watch_sender(&mut self, tx: mpsc::UnboundedSender) { self.watch_tx = Some(tx); } /// Get the underlying KV store pub fn kv(&self) -> &KvStore { &self.kv } /// Get the lease store pub fn leases(&self) -> &Arc { &self.leases } /// Get current revision pub fn current_revision(&self) -> Revision { self.kv.current_revision() } /// Apply a Raft command and return the response pub fn apply(&self, command: RaftCommand) -> Result { match command { RaftCommand::Put { key, value, lease_id, prev_kv, } => self.apply_put(key, value, lease_id, prev_kv), RaftCommand::Delete { key, prev_kv } => self.apply_delete(key, prev_kv), RaftCommand::DeleteRange { start, end, prev_kv, } => self.apply_delete_range(start, end, prev_kv), RaftCommand::Txn { compare, success, failure, } => self.apply_txn(compare, success, failure), RaftCommand::LeaseGrant { id, ttl } => self.apply_lease_grant(id, ttl), RaftCommand::LeaseRevoke { id } => self.apply_lease_revoke(id), RaftCommand::LeaseRefresh { id } => self.apply_lease_refresh(id), RaftCommand::Noop => Ok(RaftResponse::new(self.current_revision())), } } /// Apply a Put command fn apply_put( &self, key: Vec, value: Vec, lease_id: Option, return_prev: bool, ) -> Result { // If key previously had a lease, detach it if let Some(ref prev_entry) = self.kv.get(&key)? { if let Some(old_lease_id) = prev_entry.lease_id { self.leases.detach_key(old_lease_id, &key); } } let (revision, prev) = self.kv.put(key.clone(), value.clone(), lease_id)?; // Attach key to new lease if specified if let Some(lid) = lease_id { if let Err(e) = self.leases.attach_key(lid, key.clone()) { warn!("Failed to attach key to lease {}: {}", lid, e); } } // Emit watch event if let Some(tx) = &self.watch_tx { let entry = self.kv.get(&key)?.unwrap(); let event = WatchEvent::put(entry, if return_prev { prev.clone() } else { None }); if tx.send(event).is_err() { warn!("Watch event channel closed"); } } Ok(RaftResponse::with_prev_kv( revision, if return_prev { prev } else { None }, )) } /// Apply a Delete command fn apply_delete(&self, key: Vec, return_prev: bool) -> Result { // Detach from lease if attached if let Some(ref entry) = self.kv.get(&key)? { if let Some(lease_id) = entry.lease_id { self.leases.detach_key(lease_id, &key); } } let (revision, prev) = self.kv.delete(&key)?; // Emit watch event if key existed if let (Some(tx), Some(ref deleted)) = (&self.watch_tx, &prev) { let event = WatchEvent::delete( deleted.clone(), if return_prev { prev.clone() } else { None }, ); if tx.send(event).is_err() { warn!("Watch event channel closed"); } } let deleted = if prev.is_some() { 1 } else { 0 }; Ok(RaftResponse { revision, prev_kv: if return_prev { prev } else { None }, deleted, ..Default::default() }) } /// Apply a DeleteRange command fn apply_delete_range( &self, start: Vec, end: Vec, return_prev: bool, ) -> Result { let (revision, deleted_entries) = self.kv.delete_range(&start, &end)?; // Emit watch events for each deleted key if let Some(tx) = &self.watch_tx { for entry in &deleted_entries { let event = WatchEvent::delete(entry.clone(), None); if tx.send(event).is_err() { warn!("Watch event channel closed"); break; } } } Ok(RaftResponse::deleted( revision, deleted_entries.len() as u64, if return_prev { deleted_entries } else { vec![] }, )) } /// Apply a transaction fn apply_txn( &self, compare: Vec, success: Vec, failure: Vec, ) -> Result { use chainfire_types::command::TxnOpResponse; // Evaluate all comparisons let all_match = compare.iter().all(|c| self.evaluate_compare(c)); let ops = if all_match { &success } else { &failure }; // Apply operations and collect responses let mut txn_responses = Vec::with_capacity(ops.len()); for op in ops { match op { chainfire_types::command::TxnOp::Put { key, value, lease_id, } => { let resp = self.apply_put(key.clone(), value.clone(), *lease_id, true)?; txn_responses.push(TxnOpResponse::Put { prev_kv: resp.prev_kv, }); } chainfire_types::command::TxnOp::Delete { key } => { let resp = self.apply_delete(key.clone(), true)?; txn_responses.push(TxnOpResponse::Delete { deleted: resp.deleted, prev_kvs: resp.prev_kvs, }); } chainfire_types::command::TxnOp::DeleteRange { start, end } => { let resp = self.apply_delete_range(start.clone(), end.clone(), true)?; txn_responses.push(TxnOpResponse::Delete { deleted: resp.deleted, prev_kvs: resp.prev_kvs, }); } chainfire_types::command::TxnOp::Range { key, range_end, limit, keys_only, count_only, } => { // Range operations are read-only - perform the read here let entries = if range_end.is_empty() { // Single key lookup match self.kv.get(key)? { Some(entry) => vec![entry], None => vec![], } } else { // Range query let end_opt = if range_end.is_empty() { None } else { Some(range_end.as_slice()) }; let mut results = self.kv.range(key, end_opt)?; // Apply limit if *limit > 0 { results.truncate(*limit as usize); } results }; let count = entries.len() as u64; let kvs = if *count_only { vec![] } else if *keys_only { entries .into_iter() .map(|e| chainfire_types::kv::KvEntry { key: e.key, value: vec![], version: e.version, create_revision: e.create_revision, mod_revision: e.mod_revision, lease_id: e.lease_id, }) .collect() } else { entries }; txn_responses.push(TxnOpResponse::Range { kvs, count, more: false, // TODO: handle pagination }); } } } Ok(RaftResponse::txn( self.current_revision(), all_match, txn_responses, )) } /// Evaluate a single comparison fn evaluate_compare(&self, compare: &Compare) -> bool { let entry = match self.kv.get(&compare.key) { Ok(Some(e)) => e, Ok(None) => { // Key doesn't exist - special handling return match &compare.target { CompareTarget::Version(v) => match compare.result { CompareResult::Equal => *v == 0, CompareResult::NotEqual => *v != 0, CompareResult::Greater => false, CompareResult::Less => *v > 0, }, _ => false, }; } Err(_) => return false, }; match &compare.target { CompareTarget::Version(expected) => { self.compare_values(entry.version, *expected, compare.result) } CompareTarget::CreateRevision(expected) => { self.compare_values(entry.create_revision, *expected, compare.result) } CompareTarget::ModRevision(expected) => { self.compare_values(entry.mod_revision, *expected, compare.result) } CompareTarget::Value(expected) => match compare.result { CompareResult::Equal => entry.value == *expected, CompareResult::NotEqual => entry.value != *expected, CompareResult::Greater => entry.value.as_slice() > expected.as_slice(), CompareResult::Less => entry.value.as_slice() < expected.as_slice(), }, } } /// Compare two numeric values fn compare_values(&self, actual: u64, expected: u64, result: CompareResult) -> bool { match result { CompareResult::Equal => actual == expected, CompareResult::NotEqual => actual != expected, CompareResult::Greater => actual > expected, CompareResult::Less => actual < expected, } } /// Apply a lease grant command fn apply_lease_grant(&self, id: i64, ttl: i64) -> Result { let lease = self.leases.grant(id, ttl)?; Ok(RaftResponse::lease(self.current_revision(), lease.id, lease.ttl)) } /// Apply a lease revoke command fn apply_lease_revoke(&self, id: i64) -> Result { let keys = self.leases.revoke(id)?; // Delete all keys attached to the lease let mut deleted = 0u64; for key in keys { let (_, prev) = self.kv.delete(&key)?; if prev.is_some() { deleted += 1; // Emit watch event if let (Some(tx), Some(ref entry)) = (&self.watch_tx, &prev) { let event = WatchEvent::delete(entry.clone(), None); if tx.send(event).is_err() { warn!("Watch event channel closed"); } } } } Ok(RaftResponse { revision: self.current_revision(), deleted, ..Default::default() }) } /// Apply a lease refresh command fn apply_lease_refresh(&self, id: i64) -> Result { let ttl = self.leases.refresh(id)?; Ok(RaftResponse::lease(self.current_revision(), id, ttl)) } /// Delete keys by lease ID (called when lease expires) pub fn delete_keys_by_lease(&self, lease_id: i64) -> Result { if let Some(lease) = self.leases.get(lease_id) { let keys = lease.keys.clone(); // Revoke will also return the keys, but we already have them let _ = self.leases.revoke(lease_id); let mut deleted = 0u64; for key in keys { let (_, prev) = self.kv.delete(&key)?; if prev.is_some() { deleted += 1; // Emit watch event if let (Some(tx), Some(ref entry)) = (&self.watch_tx, &prev) { let event = WatchEvent::delete(entry.clone(), None); if tx.send(event).is_err() { warn!("Watch event channel closed"); } } } } Ok(deleted) } else { Ok(0) } } } #[cfg(test)] mod tests { use super::*; use tempfile::tempdir; fn create_test_state_machine() -> StateMachine { let dir = tempdir().unwrap(); let store = RocksStore::new(dir.path()).unwrap(); StateMachine::new(store).unwrap() } #[test] fn test_apply_put() { let sm = create_test_state_machine(); let cmd = RaftCommand::Put { key: b"key1".to_vec(), value: b"value1".to_vec(), lease_id: None, prev_kv: false, }; let response = sm.apply(cmd).unwrap(); assert_eq!(response.revision, 1); assert!(response.prev_kv.is_none()); let entry = sm.kv().get(b"key1").unwrap().unwrap(); assert_eq!(entry.value, b"value1"); } #[test] fn test_apply_put_with_prev() { let sm = create_test_state_machine(); sm.apply(RaftCommand::Put { key: b"key1".to_vec(), value: b"value1".to_vec(), lease_id: None, prev_kv: false, }) .unwrap(); let response = sm .apply(RaftCommand::Put { key: b"key1".to_vec(), value: b"value2".to_vec(), lease_id: None, prev_kv: true, }) .unwrap(); assert_eq!(response.revision, 2); assert!(response.prev_kv.is_some()); assert_eq!(response.prev_kv.unwrap().value, b"value1"); } #[test] fn test_apply_delete() { let sm = create_test_state_machine(); sm.apply(RaftCommand::Put { key: b"key1".to_vec(), value: b"value1".to_vec(), lease_id: None, prev_kv: false, }) .unwrap(); let response = sm .apply(RaftCommand::Delete { key: b"key1".to_vec(), prev_kv: true, }) .unwrap(); assert_eq!(response.deleted, 1); assert!(response.prev_kv.is_some()); assert!(sm.kv().get(b"key1").unwrap().is_none()); } #[test] fn test_apply_txn_success() { let sm = create_test_state_machine(); // Create initial key sm.apply(RaftCommand::Put { key: b"counter".to_vec(), value: b"1".to_vec(), lease_id: None, prev_kv: false, }) .unwrap(); // Transaction: if version == 1, increment let cmd = RaftCommand::Txn { compare: vec![Compare { key: b"counter".to_vec(), target: CompareTarget::Version(1), result: CompareResult::Equal, }], success: vec![chainfire_types::command::TxnOp::Put { key: b"counter".to_vec(), value: b"2".to_vec(), lease_id: None, }], failure: vec![], }; let response = sm.apply(cmd).unwrap(); assert!(response.succeeded); let entry = sm.kv().get(b"counter").unwrap().unwrap(); assert_eq!(entry.value, b"2"); } #[test] fn test_apply_txn_failure() { let sm = create_test_state_machine(); // Create initial key sm.apply(RaftCommand::Put { key: b"counter".to_vec(), value: b"1".to_vec(), lease_id: None, prev_kv: false, }) .unwrap(); // Transaction: if version == 5, increment (should fail) let cmd = RaftCommand::Txn { compare: vec![Compare { key: b"counter".to_vec(), target: CompareTarget::Version(5), result: CompareResult::Equal, }], success: vec![chainfire_types::command::TxnOp::Put { key: b"counter".to_vec(), value: b"2".to_vec(), lease_id: None, }], failure: vec![chainfire_types::command::TxnOp::Put { key: b"counter".to_vec(), value: b"failed".to_vec(), lease_id: None, }], }; let response = sm.apply(cmd).unwrap(); assert!(!response.succeeded); let entry = sm.kv().get(b"counter").unwrap().unwrap(); assert_eq!(entry.value, b"failed"); } #[tokio::test] async fn test_watch_events() { let mut sm = create_test_state_machine(); let (tx, mut rx) = mpsc::unbounded_channel(); sm.set_watch_sender(tx); // Apply a put sm.apply(RaftCommand::Put { key: b"key1".to_vec(), value: b"value1".to_vec(), lease_id: None, prev_kv: false, }) .unwrap(); // Check event was sent let event = rx.recv().await.unwrap(); assert!(event.is_put()); assert_eq!(event.kv.key, b"key1"); assert_eq!(event.kv.value, b"value1"); } }