//! Internal token service //! //! Issues and verifies internal tokens for service-to-service authentication. //! Can optionally use TSO (Timestamp Oracle) for consistent timestamps. //! //! ## Key Rotation //! //! The service supports multiple signing keys for seamless rotation: //! - New tokens are always signed with the "active" key //! - Old tokens can be verified with "deprecated" keys during grace period //! - "Retired" keys are no longer used for verification use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use serde::{Deserialize, Serialize}; use tokio::sync::RwLock; use iam_types::{ AuthMethod, Error, IamError, InternalTokenClaims, Principal, Result, Scope, }; /// Key status for rotation management #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum KeyStatus { /// Key is actively used for signing new tokens Active, /// Key is deprecated but still valid for verification (grace period) Deprecated, /// Key is retired and no longer valid Retired, } /// Managed signing key with status and timestamps #[derive(Clone)] pub struct ManagedKey { /// The underlying signing key pub key: SigningKey, /// Key status pub status: KeyStatus, /// When the key was created pub created_at: u64, /// When the key was deprecated (if applicable) pub deprecated_at: Option, /// When the key was retired (if applicable) pub retired_at: Option, } /// Token signing key #[derive(Clone)] pub struct SigningKey { /// Key ID pub kid: String, /// Secret key bytes secret: Vec, } impl SigningKey { /// Create a new signing key pub fn new(kid: impl Into, secret: impl Into>) -> Self { Self { kid: kid.into(), secret: secret.into(), } } /// Generate a random signing key pub fn generate(kid: impl Into) -> Self { use rand::RngCore; let mut secret = vec![0u8; 32]; rand::thread_rng().fill_bytes(&mut secret); Self::new(kid, secret) } /// Sign data using HMAC-SHA256 pub fn sign(&self, data: &[u8]) -> Vec { use hmac::{Hmac, Mac}; use sha2::Sha256; type HmacSha256 = Hmac; let mut mac = HmacSha256::new_from_slice(&self.secret).expect("HMAC can take key of any size"); mac.update(data); mac.finalize().into_bytes().to_vec() } /// Verify a signature pub fn verify(&self, data: &[u8], signature: &[u8]) -> bool { use hmac::{Hmac, Mac}; use sha2::Sha256; type HmacSha256 = Hmac; let mut mac = HmacSha256::new_from_slice(&self.secret).expect("HMAC can take key of any size"); mac.update(data); mac.verify_slice(signature).is_ok() } } impl ManagedKey { /// Create a new active managed key pub fn new_active(key: SigningKey, created_at: u64) -> Self { Self { key, status: KeyStatus::Active, created_at, deprecated_at: None, retired_at: None, } } /// Deprecate this key pub fn deprecate(&mut self, at: u64) { if self.status == KeyStatus::Active { self.status = KeyStatus::Deprecated; self.deprecated_at = Some(at); } } /// Retire this key pub fn retire(&mut self, at: u64) { self.status = KeyStatus::Retired; self.retired_at = Some(at); } /// Check if key can be used for signing (must be Active) pub fn can_sign(&self) -> bool { self.status == KeyStatus::Active } /// Check if key can be used for verification (Active or Deprecated) pub fn can_verify(&self) -> bool { matches!(self.status, KeyStatus::Active | KeyStatus::Deprecated) } } /// Configuration for key rotation #[derive(Debug, Clone)] pub struct KeyRotationConfig { /// How long to keep deprecated keys valid for verification pub grace_period: Duration, /// How often to check for keys to retire pub cleanup_interval: Duration, /// Prefix for generated key IDs pub key_id_prefix: String, } impl Default for KeyRotationConfig { fn default() -> Self { Self { grace_period: Duration::from_secs(86400 * 7), // 7 days cleanup_interval: Duration::from_secs(3600), // 1 hour key_id_prefix: "key".into(), } } } /// Key rotation manager /// /// Manages the lifecycle of signing keys: /// 1. Generate new keys /// 2. Deprecate old keys (enter grace period) /// 3. Retire deprecated keys after grace period pub struct KeyRotationManager { config: KeyRotationConfig, /// All managed keys, keyed by kid keys: RwLock>, /// ID of the currently active key active_key_id: RwLock>, } impl KeyRotationManager { /// Create a new key rotation manager pub fn new(config: KeyRotationConfig) -> Self { Self { config, keys: RwLock::new(HashMap::new()), active_key_id: RwLock::new(None), } } /// Create with default config pub fn with_defaults() -> Self { Self::new(KeyRotationConfig::default()) } /// Initialize with an existing key pub async fn init_with_key(&self, key: SigningKey, created_at: u64) { let kid = key.kid.clone(); let managed = ManagedKey::new_active(key, created_at); let mut keys = self.keys.write().await; keys.insert(kid.clone(), managed); let mut active_id = self.active_key_id.write().await; *active_id = Some(kid); } /// Generate and activate a new key, deprecating the current active key pub async fn rotate(&self) -> Result { let now = current_timestamp(); let new_kid = self.generate_key_id(); let new_key = SigningKey::generate(&new_kid); let new_managed = ManagedKey::new_active(new_key, now); let mut keys = self.keys.write().await; let mut active_id = self.active_key_id.write().await; // Deprecate the current active key if let Some(old_kid) = active_id.as_ref() { if let Some(old_key) = keys.get_mut(old_kid) { old_key.deprecate(now); } } // Insert and activate new key keys.insert(new_kid.clone(), new_managed); *active_id = Some(new_kid.clone()); Ok(new_kid) } /// Get the active signing key pub async fn get_active_key(&self) -> Option { let active_id = self.active_key_id.read().await; let keys = self.keys.read().await; active_id .as_ref() .and_then(|kid| keys.get(kid)) .filter(|k| k.can_sign()) .map(|k| k.key.clone()) } /// Get a key by ID for verification (if it can verify) pub async fn get_key_for_verify(&self, kid: &str) -> Option { let keys = self.keys.read().await; keys.get(kid) .filter(|k| k.can_verify()) .map(|k| k.key.clone()) } /// Get all keys that can be used for verification pub async fn get_verification_keys(&self) -> Vec { let keys = self.keys.read().await; keys.values() .filter(|k| k.can_verify()) .map(|k| k.key.clone()) .collect() } /// Retire keys that have exceeded the grace period pub async fn cleanup_expired(&self) -> usize { let now = current_timestamp(); let grace_secs = self.config.grace_period.as_secs(); let mut retired = 0; let mut keys = self.keys.write().await; for managed in keys.values_mut() { if managed.status == KeyStatus::Deprecated { if let Some(deprecated_at) = managed.deprecated_at { // Use >= to allow immediate expiry when grace_period is 0 if now >= deprecated_at + grace_secs { managed.retire(now); retired += 1; } } } } retired } /// Remove retired keys from memory pub async fn purge_retired(&self) -> usize { let mut keys = self.keys.write().await; let before = keys.len(); keys.retain(|_, v| v.status != KeyStatus::Retired); before - keys.len() } /// Get key statistics pub async fn stats(&self) -> KeyRotationStats { let keys = self.keys.read().await; let mut active = 0; let mut deprecated = 0; let mut retired = 0; for key in keys.values() { match key.status { KeyStatus::Active => active += 1, KeyStatus::Deprecated => deprecated += 1, KeyStatus::Retired => retired += 1, } } KeyRotationStats { active, deprecated, retired, total: keys.len(), } } /// Generate a unique key ID fn generate_key_id(&self) -> String { let timestamp = current_timestamp(); let mut random = [0u8; 4]; rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut random); format!( "{}-{}-{}", self.config.key_id_prefix, timestamp, URL_SAFE_NO_PAD.encode(random) ) } } /// Key rotation statistics #[derive(Debug, Clone)] pub struct KeyRotationStats { pub active: usize, pub deprecated: usize, pub retired: usize, pub total: usize, } /// Get current Unix timestamp fn current_timestamp() -> u64 { SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs() } /// Configuration for internal token service #[derive(Clone)] pub struct InternalTokenConfig { /// Signing keys (for rotation support) pub signing_keys: Vec, /// Default token TTL pub default_ttl: Duration, /// Maximum token TTL pub max_ttl: Duration, /// Token issuer identifier pub issuer: String, } impl InternalTokenConfig { /// Create a new config with a single key pub fn new(signing_key: SigningKey, issuer: impl Into) -> Self { Self { signing_keys: vec![signing_key], default_ttl: Duration::from_secs(3600), // 1 hour max_ttl: Duration::from_secs(86400 * 7), // 7 days issuer: issuer.into(), } } /// Add a signing key (for rotation) pub fn add_key(mut self, key: SigningKey) -> Self { self.signing_keys.push(key); self } /// Set default TTL pub fn with_default_ttl(mut self, ttl: Duration) -> Self { self.default_ttl = ttl; self } /// Set maximum TTL pub fn with_max_ttl(mut self, ttl: Duration) -> Self { self.max_ttl = ttl; self } } /// TSO client trait for timestamp generation #[async_trait::async_trait] pub trait TsoClient: Send + Sync { /// Get current timestamp async fn get_timestamp(&self) -> Result; } /// Internal token service pub struct InternalTokenService { config: InternalTokenConfig, tso_client: Option>, } impl InternalTokenService { /// Create a new internal token service pub fn new(config: InternalTokenConfig) -> Self { Self { config, tso_client: None, } } /// Create with TSO client for consistent timestamps pub fn with_tso(mut self, client: Arc) -> Self { self.tso_client = Some(client); self } /// Issue a new internal token pub async fn issue( &self, principal: &Principal, roles: Vec, scope: Scope, ttl: Option, ) -> Result { let ttl = ttl.unwrap_or(self.config.default_ttl); if ttl > self.config.max_ttl { return Err(Error::Iam(IamError::InvalidToken(format!( "TTL exceeds maximum: {:?} > {:?}", ttl, self.config.max_ttl )))); } let now = self.get_timestamp().await?; let exp = now + ttl.as_secs(); let session_id = generate_session_id(); let claims = InternalTokenClaims::new( &principal.id, principal.kind.clone(), &principal.name, scope, &session_id, ) .with_roles(roles) .with_timestamps(now, exp) .with_auth_method(AuthMethod::Internal); // Add optional fields let claims = match &principal.org_id { Some(org) => claims.with_org_id(org), None => claims, }; let claims = match &principal.project_id { Some(proj) => claims.with_project_id(proj), None => claims, }; let claims = match &principal.node_id { Some(node) => claims.with_node_id(node), None => claims, }; let token = self.encode_token(&claims)?; Ok(IssuedToken { token, claims, expires_at: exp, }) } /// Verify an internal token pub async fn verify(&self, token: &str) -> Result { let claims = self.decode_token(token)?; // Check expiration let now = self.get_timestamp().await?; if claims.is_expired(now) { return Err(Error::Iam(IamError::TokenExpired)); } Ok(claims) } /// Get current timestamp (from TSO or system time) async fn get_timestamp(&self) -> Result { match &self.tso_client { Some(tso) => tso.get_timestamp().await, None => Ok(SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs()), } } /// Encode claims into a signed token fn encode_token(&self, claims: &InternalTokenClaims) -> Result { let signing_key = self .config .signing_keys .first() .ok_or_else(|| Error::Internal("No signing key configured".into()))?; // Serialize claims let claims_json = serde_json::to_vec(claims).map_err(|e| Error::Serialization(e.to_string()))?; // Create header let header = TokenHeader { alg: "HS256".into(), kid: signing_key.kid.clone(), iss: self.config.issuer.clone(), }; let header_json = serde_json::to_vec(&header).map_err(|e| Error::Serialization(e.to_string()))?; // Encode parts let header_b64 = URL_SAFE_NO_PAD.encode(&header_json); let claims_b64 = URL_SAFE_NO_PAD.encode(&claims_json); // Sign let signing_input = format!("{}.{}", header_b64, claims_b64); let signature = signing_key.sign(signing_input.as_bytes()); let signature_b64 = URL_SAFE_NO_PAD.encode(&signature); Ok(format!("{}.{}.{}", header_b64, claims_b64, signature_b64)) } /// Decode and verify a token fn decode_token(&self, token: &str) -> Result { let parts: Vec<&str> = token.split('.').collect(); if parts.len() != 3 { return Err(Error::Iam(IamError::InvalidToken( "Invalid token format".into(), ))); } let header_json = URL_SAFE_NO_PAD .decode(parts[0]) .map_err(|e| Error::Iam(IamError::InvalidToken(e.to_string())))?; let header: TokenHeader = serde_json::from_slice(&header_json) .map_err(|e| Error::Iam(IamError::InvalidToken(e.to_string())))?; // Find signing key let signing_key = self .config .signing_keys .iter() .find(|k| k.kid == header.kid) .ok_or_else(|| { Error::Iam(IamError::InvalidToken(format!( "Unknown key ID: {}", header.kid ))) })?; // Verify signature let signature = URL_SAFE_NO_PAD .decode(parts[2]) .map_err(|e| Error::Iam(IamError::InvalidToken(e.to_string())))?; let signing_input = format!("{}.{}", parts[0], parts[1]); if !signing_key.verify(signing_input.as_bytes(), &signature) { return Err(Error::Iam(IamError::InvalidToken( "Invalid signature".into(), ))); } // Decode claims let claims_json = URL_SAFE_NO_PAD .decode(parts[1]) .map_err(|e| Error::Iam(IamError::InvalidToken(e.to_string())))?; let claims: InternalTokenClaims = serde_json::from_slice(&claims_json) .map_err(|e| Error::Iam(IamError::InvalidToken(e.to_string())))?; Ok(claims) } } /// Token header #[derive(Debug, Serialize, Deserialize)] struct TokenHeader { alg: String, kid: String, iss: String, } /// Result of issuing a token #[derive(Debug, Clone)] pub struct IssuedToken { /// The token string pub token: String, /// The token claims pub claims: InternalTokenClaims, /// Expiration timestamp (Unix seconds) pub expires_at: u64, } /// Generate a random session ID fn generate_session_id() -> String { use rand::RngCore; let mut bytes = [0u8; 16]; rand::thread_rng().fill_bytes(&mut bytes); URL_SAFE_NO_PAD.encode(bytes) } #[cfg(test)] mod tests { use super::*; fn test_config() -> InternalTokenConfig { let key = SigningKey::generate("test-key-1"); InternalTokenConfig::new(key, "iam-test") } #[tokio::test] async fn test_issue_and_verify() { let service = InternalTokenService::new(test_config()); let principal = Principal::new_user("alice", "Alice Smith"); let roles = vec!["roles/ProjectAdmin".into()]; let scope = Scope::project("proj-1", "org-1"); let issued = service.issue(&principal, roles, scope, None).await.unwrap(); assert!(!issued.token.is_empty()); assert_eq!(issued.claims.principal_id, "alice"); // Verify let verified = service.verify(&issued.token).await.unwrap(); assert_eq!(verified.principal_id, "alice"); assert_eq!(verified.roles.len(), 1); } #[tokio::test] async fn test_invalid_signature() { let service = InternalTokenService::new(test_config()); let principal = Principal::new_user("alice", "Alice"); let issued = service .issue(&principal, vec![], Scope::System, None) .await .unwrap(); // Tamper with token let parts: Vec<&str> = issued.token.split('.').collect(); let tampered = format!("{}.{}.invalid", parts[0], parts[1]); let result = service.verify(&tampered).await; assert!(result.is_err()); } #[tokio::test] async fn test_ttl_limit() { let config = test_config().with_max_ttl(Duration::from_secs(3600)); let service = InternalTokenService::new(config); let principal = Principal::new_user("alice", "Alice"); let result = service .issue( &principal, vec![], Scope::System, Some(Duration::from_secs(86400)), // 24 hours - exceeds max ) .await; assert!(result.is_err()); } #[test] fn test_signing_key() { let key = SigningKey::generate("test"); let data = b"hello world"; let signature = key.sign(data); assert!(key.verify(data, &signature)); assert!(!key.verify(b"tampered", &signature)); } // Key rotation tests #[tokio::test] async fn test_key_rotation_init() { let manager = KeyRotationManager::with_defaults(); let key = SigningKey::generate("initial-key"); let now = current_timestamp(); manager.init_with_key(key.clone(), now).await; let active = manager.get_active_key().await; assert!(active.is_some()); assert_eq!(active.unwrap().kid, "initial-key"); let stats = manager.stats().await; assert_eq!(stats.active, 1); assert_eq!(stats.deprecated, 0); assert_eq!(stats.total, 1); } #[tokio::test] async fn test_key_rotation_rotate() { let manager = KeyRotationManager::with_defaults(); let key = SigningKey::generate("initial-key"); let now = current_timestamp(); manager.init_with_key(key.clone(), now).await; // Rotate to new key let new_kid = manager.rotate().await.unwrap(); assert!(new_kid.starts_with("key-")); // Active key should be the new one let active = manager.get_active_key().await; assert!(active.is_some()); assert_eq!(active.unwrap().kid, new_kid); // Old key should still be available for verification let old_key = manager.get_key_for_verify("initial-key").await; assert!(old_key.is_some()); let stats = manager.stats().await; assert_eq!(stats.active, 1); assert_eq!(stats.deprecated, 1); assert_eq!(stats.total, 2); } #[tokio::test] async fn test_key_rotation_verify_with_old_key() { let manager = KeyRotationManager::with_defaults(); let key = SigningKey::generate("initial-key"); let now = current_timestamp(); manager.init_with_key(key.clone(), now).await; // Sign some data with the active key let data = b"test data"; let signature = manager.get_active_key().await.unwrap().sign(data); // Rotate to new key manager.rotate().await.unwrap(); // Old key should still be able to verify let old_key = manager.get_key_for_verify("initial-key").await.unwrap(); assert!(old_key.verify(data, &signature)); } #[tokio::test] async fn test_key_rotation_multiple_rotations() { let manager = KeyRotationManager::with_defaults(); let key = SigningKey::generate("key-0"); let now = current_timestamp(); manager.init_with_key(key, now).await; // Rotate 3 times for _ in 0..3 { manager.rotate().await.unwrap(); } let stats = manager.stats().await; assert_eq!(stats.active, 1); assert_eq!(stats.deprecated, 3); // key-0, key-1, key-2 are deprecated assert_eq!(stats.total, 4); // All deprecated keys should still be verifiable assert!(manager.get_key_for_verify("key-0").await.is_some()); } #[tokio::test] async fn test_key_status_transitions() { let key = SigningKey::generate("test-key"); let now = current_timestamp(); let mut managed = ManagedKey::new_active(key, now); // Initially active assert!(managed.can_sign()); assert!(managed.can_verify()); // Deprecate managed.deprecate(now); assert!(!managed.can_sign()); assert!(managed.can_verify()); assert_eq!(managed.status, KeyStatus::Deprecated); // Retire managed.retire(now); assert!(!managed.can_sign()); assert!(!managed.can_verify()); assert_eq!(managed.status, KeyStatus::Retired); } #[tokio::test] async fn test_key_rotation_cleanup_expired() { // Use very short grace period for testing let config = KeyRotationConfig { grace_period: Duration::from_secs(0), // immediate expiry cleanup_interval: Duration::from_secs(60), key_id_prefix: "test".into(), }; let manager = KeyRotationManager::new(config); let key = SigningKey::generate("initial"); let now = current_timestamp(); manager.init_with_key(key, now).await; manager.rotate().await.unwrap(); // Wait a moment for grace period to elapse tokio::time::sleep(Duration::from_millis(10)).await; let retired = manager.cleanup_expired().await; assert_eq!(retired, 1); let stats = manager.stats().await; assert_eq!(stats.deprecated, 0); assert_eq!(stats.retired, 1); } #[tokio::test] async fn test_key_rotation_purge_retired() { let config = KeyRotationConfig { grace_period: Duration::from_secs(0), cleanup_interval: Duration::from_secs(60), key_id_prefix: "test".into(), }; let manager = KeyRotationManager::new(config); let key = SigningKey::generate("initial"); let now = current_timestamp(); manager.init_with_key(key, now).await; manager.rotate().await.unwrap(); tokio::time::sleep(Duration::from_millis(10)).await; manager.cleanup_expired().await; let purged = manager.purge_retired().await; assert_eq!(purged, 1); let stats = manager.stats().await; assert_eq!(stats.total, 1); assert_eq!(stats.active, 1); } #[tokio::test] async fn test_get_verification_keys() { let manager = KeyRotationManager::with_defaults(); let key = SigningKey::generate("initial"); let now = current_timestamp(); manager.init_with_key(key, now).await; manager.rotate().await.unwrap(); manager.rotate().await.unwrap(); let verification_keys = manager.get_verification_keys().await; assert_eq!(verification_keys.len(), 3); } }