photoncloud-monorepo/iam/crates/iam-authz/src/cache.rs

509 lines
15 KiB
Rust

//! Policy cache for authorization
//!
//! Caches policy bindings and roles to reduce storage lookups.
//! Supports both TTL-based expiration and event-driven invalidation.
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use tokio::sync::broadcast;
use iam_types::{PolicyBinding, PrincipalRef, Role, Scope};
/// Cache invalidation event
///
/// Delivered to consumers through a [`CacheInvalidationSource`] receiver and
/// applied to a `PolicyCache` by `PolicyCache::handle_invalidation`.
#[derive(Debug, Clone)]
pub enum CacheInvalidation {
    /// Invalidate bindings for a specific principal.
    ///
    /// The string is used verbatim as the bindings-cache key, and `PolicyCache`
    /// builds that key with `PrincipalRef::to_string()`, so senders must use
    /// the same string form for the event to match anything.
    Principal(String),
    /// Invalidate a specific role, identified by its role name (the roles-cache key).
    Role(String),
    /// Invalidate all cached binding entries that include a binding within this scope.
    Scope(Scope),
    /// Invalidate all cached data (bindings and roles).
    All,
}
/// Source of cache invalidation events (e.g., from Chainfire watch)
///
/// Required to be `Send + Sync` so it can be shared as
/// `Arc<dyn CacheInvalidationSource>`, which is how `start_invalidation_listener`
/// accepts it.
pub trait CacheInvalidationSource: Send + Sync {
    /// Subscribe to invalidation events.
    ///
    /// Returns a `tokio::sync::broadcast` receiver; the listener task consumes
    /// it and applies each event to the cache.
    fn subscribe(&self) -> broadcast::Receiver<CacheInvalidation>;
}
/// Tuning knobs for the policy cache: how long entries stay fresh and how many
/// entries each map may hold before eviction kicks in.
#[derive(Debug, Clone)]
pub struct PolicyCacheConfig {
    /// How long cached bindings are served before they are treated as stale.
    pub binding_ttl: Duration,
    /// How long a cached role is served before it is treated as stale.
    pub role_ttl: Duration,
    /// Binding-entry count at which storing bindings starts evicting entries.
    pub max_binding_entries: usize,
    /// Role-entry count at which storing a role starts evicting entries.
    pub max_role_entries: usize,
}

impl Default for PolicyCacheConfig {
    /// Five-minute binding TTL, ten-minute role TTL, and thresholds of
    /// 10,000 binding entries and 1,000 role entries.
    fn default() -> Self {
        const MINUTE: u64 = 60;
        PolicyCacheConfig {
            binding_ttl: Duration::from_secs(5 * MINUTE),
            role_ttl: Duration::from_secs(10 * MINUTE),
            max_binding_entries: 10_000,
            max_role_entries: 1_000,
        }
    }
}
/// Cached bindings for a principal, stamped with the time they were stored.
struct CachedBindings {
    /// The bindings exactly as passed to `put_bindings` (empty results are
    /// never stored; `put_bindings` removes the entry instead).
    bindings: Vec<PolicyBinding>,
    /// When the entry was inserted; compared against `binding_ttl` on read.
    fetched_at: Instant,
    /// Scopes covered by these bindings (for scope-based invalidation).
    /// One element per binding, so duplicates are possible.
    scopes: Vec<Scope>,
}
/// Cached role, stamped with the time it was stored.
struct CachedRole {
    /// The role exactly as passed to `put_role`.
    role: Role,
    /// When the entry was inserted; compared against `role_ttl` on read.
    fetched_at: Instant,
}
/// Policy cache
///
/// All methods take `&self`; mutation goes through `DashMap`, which lets the
/// cache be shared as `Arc<PolicyCache>` (as `start_invalidation_listener` does).
pub struct PolicyCache {
    /// TTLs and capacity thresholds.
    config: PolicyCacheConfig,
    /// Bindings cache keyed by principal reference
    /// (the key is `PrincipalRef::to_string()`).
    bindings: DashMap<String, CachedBindings>,
    /// Roles cache keyed by role name
    roles: DashMap<String, CachedRole>,
}
impl PolicyCache {
/// Create a new policy cache
pub fn new(config: PolicyCacheConfig) -> Self {
Self {
config,
bindings: DashMap::new(),
roles: DashMap::new(),
}
}
/// Create with default config
pub fn default_config() -> Self {
Self::new(PolicyCacheConfig::default())
}
/// Get cached bindings for a principal
pub fn get_bindings(&self, principal: &PrincipalRef) -> Option<Vec<PolicyBinding>> {
let key = principal.to_string();
self.bindings.get(&key).and_then(|entry| {
if entry.fetched_at.elapsed() < self.config.binding_ttl {
Some(entry.bindings.clone())
} else {
None
}
})
}
/// Cache bindings for a principal
pub fn put_bindings(&self, principal: &PrincipalRef, bindings: Vec<PolicyBinding>) {
// Avoid negative caching for principals with no bindings. Without a
// cross-node invalidation bus, caching an empty result can hold a
// stale deny long after a new binding has been created elsewhere.
if bindings.is_empty() {
self.invalidate_bindings(principal);
return;
}
// Evict if at capacity
if self.bindings.len() >= self.config.max_binding_entries {
self.evict_expired_bindings();
}
// Extract scopes from bindings for scope-based invalidation
let scopes: Vec<Scope> = bindings.iter().map(|b| b.scope.clone()).collect();
let key = principal.to_string();
self.bindings.insert(
key,
CachedBindings {
bindings,
fetched_at: Instant::now(),
scopes,
},
);
}
/// Invalidate bindings for a principal
pub fn invalidate_bindings(&self, principal: &PrincipalRef) {
let key = principal.to_string();
self.bindings.remove(&key);
}
/// Invalidate all bindings
pub fn invalidate_all_bindings(&self) {
self.bindings.clear();
}
/// Invalidate bindings that cover a specific scope
pub fn invalidate_bindings_by_scope(&self, scope: &Scope) {
self.bindings.retain(|_, v| {
// Remove if any binding's scope matches or is a child of the given scope
!v.scopes.iter().any(|s| scope.contains(s))
});
}
/// Handle an invalidation event
pub fn handle_invalidation(&self, event: &CacheInvalidation) {
match event {
CacheInvalidation::Principal(id) => {
// Try to parse as different principal types
// This is a simplified approach - in production you might want
// to include the kind in the event
self.bindings.remove(id);
}
CacheInvalidation::Role(name) => {
self.invalidate_role(name);
}
CacheInvalidation::Scope(scope) => {
self.invalidate_bindings_by_scope(scope);
}
CacheInvalidation::All => {
self.invalidate_all();
}
}
}
/// Get a cached role
pub fn get_role(&self, name: &str) -> Option<Role> {
self.roles.get(name).and_then(|entry| {
if entry.fetched_at.elapsed() < self.config.role_ttl {
Some(entry.role.clone())
} else {
None
}
})
}
/// Cache a role
pub fn put_role(&self, role: Role) {
// Evict if at capacity
if self.roles.len() >= self.config.max_role_entries {
self.evict_expired_roles();
}
let name = role.name.clone();
self.roles.insert(
name,
CachedRole {
role,
fetched_at: Instant::now(),
},
);
}
/// Invalidate a cached role
pub fn invalidate_role(&self, name: &str) {
self.roles.remove(name);
}
/// Invalidate all roles
pub fn invalidate_all_roles(&self) {
self.roles.clear();
}
/// Invalidate all cached data
pub fn invalidate_all(&self) {
self.bindings.clear();
self.roles.clear();
}
/// Get cache statistics
pub fn stats(&self) -> CacheStats {
CacheStats {
binding_entries: self.bindings.len(),
role_entries: self.roles.len(),
}
}
/// Evict expired binding entries
fn evict_expired_bindings(&self) {
let ttl = self.config.binding_ttl;
self.bindings.retain(|_, v| v.fetched_at.elapsed() < ttl);
}
/// Evict expired role entries
fn evict_expired_roles(&self) {
let ttl = self.config.role_ttl;
self.roles.retain(|_, v| v.fetched_at.elapsed() < ttl);
}
}
/// Start listening for invalidation events from an external source
///
/// This function spawns a background task that listens for invalidation
/// events and applies them to the cache.
pub fn start_invalidation_listener(
cache: Arc<PolicyCache>,
source: Arc<dyn CacheInvalidationSource>,
) -> tokio::task::JoinHandle<()> {
let mut rx = source.subscribe();
tokio::spawn(async move {
loop {
match rx.recv().await {
Ok(event) => {
tracing::debug!("Received cache invalidation event: {:?}", event);
cache.handle_invalidation(&event);
}
Err(broadcast::error::RecvError::Closed) => {
tracing::info!("Cache invalidation source closed");
break;
}
Err(broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!("Lagged {} invalidation events, invalidating all", n);
cache.invalidate_all();
}
}
}
})
}
/// A simple in-memory invalidation source for testing or local use
///
/// Backed by one in-process `tokio::sync::broadcast` channel, so events only
/// reach receivers obtained from this value's `subscribe`.
pub struct LocalInvalidationSource {
    /// Sending half of the broadcast channel; receivers are created from it.
    tx: broadcast::Sender<CacheInvalidation>,
}
impl LocalInvalidationSource {
/// Create a new local invalidation source
pub fn new() -> Self {
let (tx, _) = broadcast::channel(256);
Self { tx }
}
/// Send an invalidation event
pub fn invalidate(&self, event: CacheInvalidation) {
let _ = self.tx.send(event);
}
/// Invalidate a specific principal
pub fn invalidate_principal(&self, principal_id: &str) {
self.invalidate(CacheInvalidation::Principal(principal_id.to_string()));
}
/// Invalidate a specific role
pub fn invalidate_role(&self, role_name: &str) {
self.invalidate(CacheInvalidation::Role(role_name.to_string()));
}
/// Invalidate a scope
pub fn invalidate_scope(&self, scope: Scope) {
self.invalidate(CacheInvalidation::Scope(scope));
}
/// Invalidate all
pub fn invalidate_all(&self) {
self.invalidate(CacheInvalidation::All);
}
}
impl Default for LocalInvalidationSource {
fn default() -> Self {
Self::new()
}
}
impl CacheInvalidationSource for LocalInvalidationSource {
    /// Hand out a new receiver attached to this source's broadcast channel.
    fn subscribe(&self) -> broadcast::Receiver<CacheInvalidation> {
        let sender = &self.tx;
        sender.subscribe()
    }
}
/// Cache statistics
///
/// A point-in-time snapshot of map sizes; counts include entries whose TTL
/// has passed but which have not been evicted yet.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of principals with a cached bindings entry.
    pub binding_entries: usize,
    /// Number of cached roles.
    pub role_entries: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use iam_types::{Permission, Scope};

    /// A single `roles/Admin` binding at system scope for `principal`.
    fn system_admin_binding(principal: &PrincipalRef) -> Vec<PolicyBinding> {
        vec![PolicyBinding::new(
            "b1",
            principal.clone(),
            "roles/Admin",
            Scope::System,
        )]
    }

    /// Yield to the runtime until `done` holds, failing after a generous
    /// deadline instead of relying on a fixed sleep.
    async fn wait_until(mut done: impl FnMut() -> bool) {
        let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(2);
        while !done() {
            assert!(
                tokio::time::Instant::now() < deadline,
                "condition not reached before deadline"
            );
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        }
    }

    #[test]
    fn test_binding_cache() {
        let cache = PolicyCache::default_config();
        let principal = PrincipalRef::user("alice");
        // Initially empty
        assert!(cache.get_bindings(&principal).is_none());
        // Add bindings
        let bindings = vec![PolicyBinding::new(
            "b1",
            principal.clone(),
            "roles/Admin",
            Scope::System,
        )];
        cache.put_bindings(&principal, bindings.clone());
        // Should be cached
        let cached = cache.get_bindings(&principal).unwrap();
        assert_eq!(cached.len(), 1);
        // Invalidate
        cache.invalidate_bindings(&principal);
        assert!(cache.get_bindings(&principal).is_none());
    }

    #[test]
    fn test_role_cache() {
        let cache = PolicyCache::default_config();
        let role = Role::new("TestRole", Scope::System, vec![Permission::wildcard()]);
        // Initially empty
        assert!(cache.get_role("TestRole").is_none());
        // Add role
        cache.put_role(role.clone());
        // Should be cached
        let cached = cache.get_role("TestRole").unwrap();
        assert_eq!(cached.name, "TestRole");
        // Invalidate
        cache.invalidate_role("TestRole");
        assert!(cache.get_role("TestRole").is_none());
    }

    #[test]
    fn test_cache_stats() {
        let cache = PolicyCache::default_config();
        let alice = PrincipalRef::user("alice");
        // Use a non-empty binding list: empty lists are intentionally not cached.
        cache.put_bindings(&alice, system_admin_binding(&alice));
        cache.put_role(Role::new("Role1", Scope::System, vec![]));
        let stats = cache.stats();
        assert_eq!(stats.binding_entries, 1);
        assert_eq!(stats.role_entries, 1);
    }

    #[test]
    fn test_empty_bindings_are_not_cached() {
        let cache = PolicyCache::default_config();
        let alice = PrincipalRef::user("alice");
        cache.put_bindings(&alice, vec![]);
        assert!(cache.get_bindings(&alice).is_none());
        assert_eq!(cache.stats().binding_entries, 0);
    }

    #[test]
    fn test_scope_invalidation() {
        let cache = PolicyCache::default_config();
        // Add bindings for different scopes
        let alice = PrincipalRef::user("alice");
        let bob = PrincipalRef::user("bob");
        let alice_bindings = vec![PolicyBinding::new(
            "b1",
            alice.clone(),
            "roles/Admin",
            Scope::org("org-1"),
        )];
        let bob_bindings = vec![PolicyBinding::new(
            "b2",
            bob.clone(),
            "roles/Viewer",
            Scope::project("proj-1", "org-2"),
        )];
        cache.put_bindings(&alice, alice_bindings);
        cache.put_bindings(&bob, bob_bindings);
        assert_eq!(cache.stats().binding_entries, 2);
        // Invalidate org-1 scope should remove alice's bindings
        cache.invalidate_bindings_by_scope(&Scope::org("org-1"));
        // Alice's bindings should be gone
        assert!(cache.get_bindings(&alice).is_none());
        // Bob's bindings should still be there
        assert!(cache.get_bindings(&bob).is_some());
    }

    #[test]
    fn test_handle_invalidation_event() {
        let cache = PolicyCache::default_config();
        let alice = PrincipalRef::user("alice");
        cache.put_bindings(&alice, system_admin_binding(&alice));
        cache.put_role(Role::new("TestRole", Scope::System, vec![]));
        assert_eq!(cache.stats().binding_entries, 1);
        assert_eq!(cache.stats().role_entries, 1);
        // Handle principal invalidation
        cache.handle_invalidation(&CacheInvalidation::Principal(alice.to_string()));
        assert!(cache.get_bindings(&alice).is_none());
        assert!(cache.get_role("TestRole").is_some());
        // Handle role invalidation
        cache.handle_invalidation(&CacheInvalidation::Role("TestRole".to_string()));
        assert!(cache.get_role("TestRole").is_none());
    }

    #[test]
    fn test_local_invalidation_source() {
        let source = LocalInvalidationSource::new();
        let mut rx = source.subscribe();
        source.invalidate_principal("alice");
        // Should receive the event (non-blocking check)
        let event = rx.try_recv();
        assert!(event.is_ok());
        match event.unwrap() {
            CacheInvalidation::Principal(id) => assert_eq!(id, "alice"),
            _ => panic!("Expected Principal invalidation"),
        }
    }

    #[tokio::test]
    async fn test_invalidation_listener() {
        let cache = Arc::new(PolicyCache::default_config());
        let source = Arc::new(LocalInvalidationSource::new());
        // Populate both maps; bindings must be non-empty or they are not cached,
        // which would make the final binding assertion vacuous.
        let alice = PrincipalRef::user("alice");
        cache.put_bindings(&alice, system_admin_binding(&alice));
        cache.put_role(Role::new("TestRole", Scope::System, vec![]));
        assert_eq!(cache.stats().binding_entries, 1);
        assert_eq!(cache.stats().role_entries, 1);
        // `start_invalidation_listener` subscribes before spawning, so the event
        // below is buffered even if the task has not been polled yet.
        let _handle = start_invalidation_listener(cache.clone(), source.clone());
        source.invalidate_all();
        wait_until(|| {
            let stats = cache.stats();
            stats.binding_entries == 0 && stats.role_entries == 0
        })
        .await;
    }

    #[tokio::test]
    async fn test_invalidation_listener_lag_invalidates_all() {
        let cache = Arc::new(PolicyCache::default_config());
        let source = Arc::new(LocalInvalidationSource::new());
        let alice = PrincipalRef::user("alice");
        cache.put_bindings(&alice, system_admin_binding(&alice));
        cache.put_role(Role::new("TestRole", Scope::System, vec![]));
        let _handle = start_invalidation_listener(cache.clone(), source.clone());
        // On the single-threaded test runtime the listener cannot run until we
        // yield, so sending more events than the channel retains makes its first
        // `recv` report `Lagged`. None of these events names alice or TestRole,
        // so only the lag handling can empty the cache.
        for i in 0..300 {
            source.invalidate_role(&format!("unrelated-{}", i));
        }
        wait_until(|| {
            let stats = cache.stats();
            stats.binding_entries == 0 && stats.role_entries == 0
        })
        .await;
    }
}