//! Policy cache for authorization //! //! Caches policy bindings and roles to reduce storage lookups. //! Supports both TTL-based expiration and event-driven invalidation. use std::sync::Arc; use std::time::{Duration, Instant}; use dashmap::DashMap; use tokio::sync::broadcast; use iam_types::{PolicyBinding, PrincipalRef, Role, Scope}; /// Cache invalidation event #[derive(Debug, Clone)] pub enum CacheInvalidation { /// Invalidate bindings for a specific principal Principal(String), /// Invalidate a specific role Role(String), /// Invalidate all bindings within a scope Scope(Scope), /// Invalidate all cached data All, } /// Source of cache invalidation events (e.g., from Chainfire watch) pub trait CacheInvalidationSource: Send + Sync { /// Subscribe to invalidation events fn subscribe(&self) -> broadcast::Receiver; } /// Configuration for the policy cache #[derive(Debug, Clone)] pub struct PolicyCacheConfig { /// TTL for cached bindings pub binding_ttl: Duration, /// TTL for cached roles pub role_ttl: Duration, /// Maximum number of binding entries pub max_binding_entries: usize, /// Maximum number of role entries pub max_role_entries: usize, } impl Default for PolicyCacheConfig { fn default() -> Self { Self { binding_ttl: Duration::from_secs(300), // 5 minutes role_ttl: Duration::from_secs(600), // 10 minutes max_binding_entries: 10000, max_role_entries: 1000, } } } /// Cached bindings for a principal struct CachedBindings { bindings: Vec, fetched_at: Instant, /// Scopes covered by these bindings (for scope-based invalidation) scopes: Vec, } /// Cached role struct CachedRole { role: Role, fetched_at: Instant, } /// Policy cache pub struct PolicyCache { config: PolicyCacheConfig, /// Bindings cache keyed by principal reference bindings: DashMap, /// Roles cache keyed by role name roles: DashMap, } impl PolicyCache { /// Create a new policy cache pub fn new(config: PolicyCacheConfig) -> Self { Self { config, bindings: DashMap::new(), roles: DashMap::new(), } } /// Create with default config pub fn default_config() -> Self { Self::new(PolicyCacheConfig::default()) } /// Get cached bindings for a principal pub fn get_bindings(&self, principal: &PrincipalRef) -> Option> { let key = principal.to_string(); self.bindings.get(&key).and_then(|entry| { if entry.fetched_at.elapsed() < self.config.binding_ttl { Some(entry.bindings.clone()) } else { None } }) } /// Cache bindings for a principal pub fn put_bindings(&self, principal: &PrincipalRef, bindings: Vec) { // Avoid negative caching for principals with no bindings. Without a // cross-node invalidation bus, caching an empty result can hold a // stale deny long after a new binding has been created elsewhere. if bindings.is_empty() { self.invalidate_bindings(principal); return; } // Evict if at capacity if self.bindings.len() >= self.config.max_binding_entries { self.evict_expired_bindings(); } // Extract scopes from bindings for scope-based invalidation let scopes: Vec = bindings.iter().map(|b| b.scope.clone()).collect(); let key = principal.to_string(); self.bindings.insert( key, CachedBindings { bindings, fetched_at: Instant::now(), scopes, }, ); } /// Invalidate bindings for a principal pub fn invalidate_bindings(&self, principal: &PrincipalRef) { let key = principal.to_string(); self.bindings.remove(&key); } /// Invalidate all bindings pub fn invalidate_all_bindings(&self) { self.bindings.clear(); } /// Invalidate bindings that cover a specific scope pub fn invalidate_bindings_by_scope(&self, scope: &Scope) { self.bindings.retain(|_, v| { // Remove if any binding's scope matches or is a child of the given scope !v.scopes.iter().any(|s| scope.contains(s)) }); } /// Handle an invalidation event pub fn handle_invalidation(&self, event: &CacheInvalidation) { match event { CacheInvalidation::Principal(id) => { // Try to parse as different principal types // This is a simplified approach - in production you might want // to include the kind in the event self.bindings.remove(id); } CacheInvalidation::Role(name) => { self.invalidate_role(name); } CacheInvalidation::Scope(scope) => { self.invalidate_bindings_by_scope(scope); } CacheInvalidation::All => { self.invalidate_all(); } } } /// Get a cached role pub fn get_role(&self, name: &str) -> Option { self.roles.get(name).and_then(|entry| { if entry.fetched_at.elapsed() < self.config.role_ttl { Some(entry.role.clone()) } else { None } }) } /// Cache a role pub fn put_role(&self, role: Role) { // Evict if at capacity if self.roles.len() >= self.config.max_role_entries { self.evict_expired_roles(); } let name = role.name.clone(); self.roles.insert( name, CachedRole { role, fetched_at: Instant::now(), }, ); } /// Invalidate a cached role pub fn invalidate_role(&self, name: &str) { self.roles.remove(name); } /// Invalidate all roles pub fn invalidate_all_roles(&self) { self.roles.clear(); } /// Invalidate all cached data pub fn invalidate_all(&self) { self.bindings.clear(); self.roles.clear(); } /// Get cache statistics pub fn stats(&self) -> CacheStats { CacheStats { binding_entries: self.bindings.len(), role_entries: self.roles.len(), } } /// Evict expired binding entries fn evict_expired_bindings(&self) { let ttl = self.config.binding_ttl; self.bindings.retain(|_, v| v.fetched_at.elapsed() < ttl); } /// Evict expired role entries fn evict_expired_roles(&self) { let ttl = self.config.role_ttl; self.roles.retain(|_, v| v.fetched_at.elapsed() < ttl); } } /// Start listening for invalidation events from an external source /// /// This function spawns a background task that listens for invalidation /// events and applies them to the cache. pub fn start_invalidation_listener( cache: Arc, source: Arc, ) -> tokio::task::JoinHandle<()> { let mut rx = source.subscribe(); tokio::spawn(async move { loop { match rx.recv().await { Ok(event) => { tracing::debug!("Received cache invalidation event: {:?}", event); cache.handle_invalidation(&event); } Err(broadcast::error::RecvError::Closed) => { tracing::info!("Cache invalidation source closed"); break; } Err(broadcast::error::RecvError::Lagged(n)) => { tracing::warn!("Lagged {} invalidation events, invalidating all", n); cache.invalidate_all(); } } } }) } /// A simple in-memory invalidation source for testing or local use pub struct LocalInvalidationSource { tx: broadcast::Sender, } impl LocalInvalidationSource { /// Create a new local invalidation source pub fn new() -> Self { let (tx, _) = broadcast::channel(256); Self { tx } } /// Send an invalidation event pub fn invalidate(&self, event: CacheInvalidation) { let _ = self.tx.send(event); } /// Invalidate a specific principal pub fn invalidate_principal(&self, principal_id: &str) { self.invalidate(CacheInvalidation::Principal(principal_id.to_string())); } /// Invalidate a specific role pub fn invalidate_role(&self, role_name: &str) { self.invalidate(CacheInvalidation::Role(role_name.to_string())); } /// Invalidate a scope pub fn invalidate_scope(&self, scope: Scope) { self.invalidate(CacheInvalidation::Scope(scope)); } /// Invalidate all pub fn invalidate_all(&self) { self.invalidate(CacheInvalidation::All); } } impl Default for LocalInvalidationSource { fn default() -> Self { Self::new() } } impl CacheInvalidationSource for LocalInvalidationSource { fn subscribe(&self) -> broadcast::Receiver { self.tx.subscribe() } } /// Cache statistics #[derive(Debug, Clone)] pub struct CacheStats { pub binding_entries: usize, pub role_entries: usize, } #[cfg(test)] mod tests { use super::*; use iam_types::{Permission, Scope}; #[test] fn test_binding_cache() { let cache = PolicyCache::default_config(); let principal = PrincipalRef::user("alice"); // Initially empty assert!(cache.get_bindings(&principal).is_none()); // Add bindings let bindings = vec![PolicyBinding::new( "b1", principal.clone(), "roles/Admin", Scope::System, )]; cache.put_bindings(&principal, bindings.clone()); // Should be cached let cached = cache.get_bindings(&principal).unwrap(); assert_eq!(cached.len(), 1); // Invalidate cache.invalidate_bindings(&principal); assert!(cache.get_bindings(&principal).is_none()); } #[test] fn test_role_cache() { let cache = PolicyCache::default_config(); let role = Role::new("TestRole", Scope::System, vec![Permission::wildcard()]); // Initially empty assert!(cache.get_role("TestRole").is_none()); // Add role cache.put_role(role.clone()); // Should be cached let cached = cache.get_role("TestRole").unwrap(); assert_eq!(cached.name, "TestRole"); // Invalidate cache.invalidate_role("TestRole"); assert!(cache.get_role("TestRole").is_none()); } #[test] fn test_cache_stats() { let cache = PolicyCache::default_config(); cache.put_bindings(&PrincipalRef::user("alice"), vec![]); cache.put_role(Role::new("Role1", Scope::System, vec![])); let stats = cache.stats(); assert_eq!(stats.binding_entries, 0); assert_eq!(stats.role_entries, 1); } #[test] fn test_empty_bindings_are_not_cached() { let cache = PolicyCache::default_config(); let alice = PrincipalRef::user("alice"); cache.put_bindings(&alice, vec![]); assert!(cache.get_bindings(&alice).is_none()); assert_eq!(cache.stats().binding_entries, 0); } #[test] fn test_scope_invalidation() { let cache = PolicyCache::default_config(); // Add bindings for different scopes let alice = PrincipalRef::user("alice"); let bob = PrincipalRef::user("bob"); let alice_bindings = vec![PolicyBinding::new( "b1", alice.clone(), "roles/Admin", Scope::org("org-1"), )]; let bob_bindings = vec![PolicyBinding::new( "b2", bob.clone(), "roles/Viewer", Scope::project("proj-1", "org-2"), )]; cache.put_bindings(&alice, alice_bindings); cache.put_bindings(&bob, bob_bindings); assert_eq!(cache.stats().binding_entries, 2); // Invalidate org-1 scope should remove alice's bindings cache.invalidate_bindings_by_scope(&Scope::org("org-1")); // Alice's bindings should be gone assert!(cache.get_bindings(&alice).is_none()); // Bob's bindings should still be there assert!(cache.get_bindings(&bob).is_some()); } #[test] fn test_handle_invalidation_event() { let cache = PolicyCache::default_config(); let alice = PrincipalRef::user("alice"); let bindings = vec![PolicyBinding::new( "b1", alice.clone(), "roles/Admin", Scope::System, )]; cache.put_bindings(&alice, bindings); cache.put_role(Role::new("TestRole", Scope::System, vec![])); assert_eq!(cache.stats().binding_entries, 1); assert_eq!(cache.stats().role_entries, 1); // Handle principal invalidation cache.handle_invalidation(&CacheInvalidation::Principal(alice.to_string())); assert!(cache.get_bindings(&alice).is_none()); assert!(cache.get_role("TestRole").is_some()); // Handle role invalidation cache.handle_invalidation(&CacheInvalidation::Role("TestRole".to_string())); assert!(cache.get_role("TestRole").is_none()); } #[test] fn test_local_invalidation_source() { let source = LocalInvalidationSource::new(); let mut rx = source.subscribe(); source.invalidate_principal("alice"); // Should receive the event (non-blocking check) let event = rx.try_recv(); assert!(event.is_ok()); match event.unwrap() { CacheInvalidation::Principal(id) => assert_eq!(id, "alice"), _ => panic!("Expected Principal invalidation"), } } #[tokio::test] async fn test_invalidation_listener() { let cache = Arc::new(PolicyCache::default_config()); let source = Arc::new(LocalInvalidationSource::new()); // Add some data to cache let alice = PrincipalRef::user("alice"); cache.put_bindings(&alice, vec![]); cache.put_role(Role::new("TestRole", Scope::System, vec![])); assert_eq!(cache.stats().binding_entries, 0); assert_eq!(cache.stats().role_entries, 1); // Start listener let _handle = start_invalidation_listener(cache.clone(), source.clone()); // Give the listener time to start tokio::time::sleep(std::time::Duration::from_millis(10)).await; // Send invalidation source.invalidate_all(); // Give time for event to be processed tokio::time::sleep(std::time::Duration::from_millis(50)).await; // Cache should be empty assert_eq!(cache.stats().binding_entries, 0); assert_eq!(cache.stats().role_entries, 0); } }