//! Watch subscription registry use crate::matcher::KeyMatcher; use crate::next_watch_id; use chainfire_types::watch::{WatchEvent, WatchRequest, WatchResponse}; use chainfire_types::Revision; use dashmap::DashMap; use parking_lot::RwLock; use std::collections::{BTreeMap, HashSet}; use tokio::sync::mpsc; use tracing::{debug, trace, warn}; /// A registered watch subscription struct WatchSubscription { watch_id: i64, matcher: KeyMatcher, prev_kv: bool, created_revision: Revision, sender: mpsc::Sender, } /// Registry for all active watch subscriptions pub struct WatchRegistry { /// Map of watch_id -> subscription watches: DashMap, /// Index: key prefix -> watch_ids for efficient dispatch /// Uses BTreeMap for prefix range queries prefix_index: RwLock, HashSet>>, /// Current revision for progress notifications current_revision: RwLock, } impl WatchRegistry { /// Create a new watch registry pub fn new() -> Self { Self { watches: DashMap::new(), prefix_index: RwLock::new(BTreeMap::new()), current_revision: RwLock::new(0), } } /// Update current revision pub fn set_revision(&self, revision: Revision) { *self.current_revision.write() = revision; } /// Get current revision pub fn current_revision(&self) -> Revision { *self.current_revision.read() } /// Create a new watch subscription pub fn create_watch( &self, req: WatchRequest, sender: mpsc::Sender, ) -> i64 { let watch_id = if req.watch_id != 0 { req.watch_id } else { next_watch_id() }; let matcher = if let Some(ref end) = req.range_end { KeyMatcher::range(req.key.clone(), end.clone()) } else { KeyMatcher::key(req.key.clone()) }; let subscription = WatchSubscription { watch_id, matcher, prev_kv: req.prev_kv, created_revision: req.start_revision.unwrap_or_else(|| self.current_revision()), sender, }; // Add to watches self.watches.insert(watch_id, subscription); // Add to prefix index { let mut index = self.prefix_index.write(); index .entry(req.key.clone()) .or_default() .insert(watch_id); } debug!(watch_id, key = ?String::from_utf8_lossy(&req.key), "Created watch"); watch_id } /// Cancel a watch pub fn cancel_watch(&self, watch_id: i64) -> bool { if let Some((_, sub)) = self.watches.remove(&watch_id) { // Remove from prefix index let mut index = self.prefix_index.write(); if let Some(ids) = index.get_mut(sub.matcher.start_key()) { ids.remove(&watch_id); if ids.is_empty() { index.remove(sub.matcher.start_key()); } } debug!(watch_id, "Canceled watch"); true } else { false } } /// Get watch count pub fn watch_count(&self) -> usize { self.watches.len() } /// Dispatch an event to matching watches pub async fn dispatch_event(&self, event: WatchEvent) { let key = &event.kv.key; let revision = event.kv.mod_revision; // Update current revision { let mut current = self.current_revision.write(); if revision > *current { *current = revision; } } // Find all matching watches let matching_ids = self.find_matching_watches(key); trace!( key = ?String::from_utf8_lossy(key), matches = matching_ids.len(), "Dispatching event" ); for watch_id in matching_ids { if let Some(sub) = self.watches.get(&watch_id) { // Check if event revision is after watch creation if revision > sub.created_revision { let response = WatchResponse::events( watch_id, vec![if sub.prev_kv { event.clone() } else { WatchEvent { event_type: event.event_type, kv: event.kv.clone(), prev_kv: None, } }], ); // Non-blocking send if sub.sender.try_send(response).is_err() { warn!(watch_id, "Watch channel full or closed"); } } } } } /// Find watches that match a key fn find_matching_watches(&self, key: &[u8]) -> Vec { let mut result = Vec::new(); // Check each subscription for match // This is O(n) but can be optimized with better indexing for entry in self.watches.iter() { if entry.matcher.matches(key) { result.push(*entry.key()); } } result } /// Send progress notification to all watches pub async fn send_progress(&self) { let revision = self.current_revision(); for entry in self.watches.iter() { let response = WatchResponse::progress(entry.watch_id, revision); if entry.sender.try_send(response).is_err() { trace!(watch_id = entry.watch_id, "Progress notification dropped"); } } } /// Remove watches with closed channels pub fn cleanup_closed(&self) { let closed_ids: Vec = self .watches .iter() .filter(|entry| entry.sender.is_closed()) .map(|entry| *entry.key()) .collect(); for id in closed_ids { self.cancel_watch(id); } } } impl Default for WatchRegistry { fn default() -> Self { Self::new() } } #[cfg(test)] mod tests { use super::*; use chainfire_types::kv::KvEntry; use chainfire_types::watch::WatchEventType; fn create_test_event(key: &[u8], value: &[u8], revision: u64) -> WatchEvent { WatchEvent { event_type: WatchEventType::Put, kv: KvEntry::new(key.to_vec(), value.to_vec(), revision), prev_kv: None, } } #[tokio::test] async fn test_create_and_cancel_watch() { let registry = WatchRegistry::new(); let (tx, _rx) = mpsc::channel(10); let req = WatchRequest::key(1, b"/test/key"); let watch_id = registry.create_watch(req, tx); assert_eq!(watch_id, 1); assert_eq!(registry.watch_count(), 1); assert!(registry.cancel_watch(watch_id)); assert_eq!(registry.watch_count(), 0); } #[tokio::test] async fn test_dispatch_to_single_key_watch() { let registry = WatchRegistry::new(); let (tx, mut rx) = mpsc::channel(10); let req = WatchRequest::key(1, b"/test/key"); registry.create_watch(req, tx); // Dispatch matching event let event = create_test_event(b"/test/key", b"value", 1); registry.dispatch_event(event).await; // Should receive event let response = rx.try_recv().unwrap(); assert_eq!(response.watch_id, 1); assert_eq!(response.events.len(), 1); assert_eq!(response.events[0].kv.key, b"/test/key"); } #[tokio::test] async fn test_dispatch_to_prefix_watch() { let registry = WatchRegistry::new(); let (tx, mut rx) = mpsc::channel(10); let req = WatchRequest::prefix(1, b"/nodes/"); registry.create_watch(req, tx); // Dispatch matching events registry .dispatch_event(create_test_event(b"/nodes/1", b"data1", 1)) .await; registry .dispatch_event(create_test_event(b"/nodes/2", b"data2", 2)) .await; registry .dispatch_event(create_test_event(b"/tasks/1", b"other", 3)) .await; // Should receive 2 events (not /tasks/1) let resp1 = rx.try_recv().unwrap(); let resp2 = rx.try_recv().unwrap(); assert!(rx.try_recv().is_err()); assert_eq!(resp1.events[0].kv.key, b"/nodes/1"); assert_eq!(resp2.events[0].kv.key, b"/nodes/2"); } #[tokio::test] async fn test_revision_filtering() { let registry = WatchRegistry::new(); registry.set_revision(5); let (tx, mut rx) = mpsc::channel(10); // Watch starting from revision 10 let req = WatchRequest::key(1, b"/key").from_revision(10); registry.create_watch(req, tx); // Event at revision 8 (before watch start) registry .dispatch_event(create_test_event(b"/key", b"old", 8)) .await; // Event at revision 12 (after watch start) registry .dispatch_event(create_test_event(b"/key", b"new", 12)) .await; // Should only receive the second event let response = rx.try_recv().unwrap(); assert_eq!(response.events[0].kv.mod_revision, 12); assert!(rx.try_recv().is_err()); } #[tokio::test] async fn test_multiple_watches() { let registry = WatchRegistry::new(); let (tx1, mut rx1) = mpsc::channel(10); let (tx2, mut rx2) = mpsc::channel(10); registry.create_watch(WatchRequest::prefix(1, b"/a/"), tx1); registry.create_watch(WatchRequest::prefix(2, b"/a/b/"), tx2); // Event matching both watches registry .dispatch_event(create_test_event(b"/a/b/c", b"value", 1)) .await; // Both should receive the event assert!(rx1.try_recv().is_ok()); assert!(rx2.try_recv().is_ok()); } #[tokio::test] async fn test_cleanup_closed() { let registry = WatchRegistry::new(); let (tx, rx) = mpsc::channel(10); registry.create_watch(WatchRequest::key(1, b"/test"), tx); assert_eq!(registry.watch_count(), 1); // Drop the receiver to close the channel drop(rx); // Cleanup should remove the watch registry.cleanup_closed(); assert_eq!(registry.watch_count(), 0); } }