use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; use aes_gcm::{aead::Aead, Aes256Gcm, Key, KeyInit, Nonce}; use argon2::{password_hash::{PasswordHasher, SaltString}, Argon2}; use base64::{engine::general_purpose::STANDARD, Engine}; use rand_core::{OsRng, RngCore}; use tonic::{Request, Response, Status}; use iam_store::CredentialStore; use iam_types::{Argon2Params, CredentialRecord}; use crate::proto::{ iam_credential_server::IamCredential, CreateS3CredentialRequest, CreateS3CredentialResponse, Credential, GetSecretKeyRequest, GetSecretKeyResponse, ListCredentialsRequest, ListCredentialsResponse, RevokeCredentialRequest, RevokeCredentialResponse, }; fn now_ts() -> u64 { SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap_or_default() .as_secs() } pub struct IamCredentialService { store: Arc, cipher: Aes256Gcm, key_id: String, } impl IamCredentialService { pub fn new(store: Arc, master_key: &[u8], key_id: &str) -> Result { if master_key.len() != 32 { return Err(Status::failed_precondition( "IAM_CRED_MASTER_KEY must be 32 bytes", )); } let cipher = Aes256Gcm::new(Key::::from_slice(master_key)); Ok(Self { store, cipher, key_id: key_id.to_string(), }) } fn generate_secret() -> (String, Vec) { let raw = uuid::Uuid::new_v4().as_bytes().to_vec(); let secret_b64 = STANDARD.encode(&raw); (secret_b64, raw) } fn hash_secret(raw: &[u8]) -> (String, Argon2Params) { let salt = SaltString::generate(&mut OsRng); let argon2 = Argon2::default(); let hash = argon2 .hash_password(raw, &salt) .expect("argon2 hash") .to_string(); let params = Argon2Params { m_cost_kib: argon2.params().m_cost(), t_cost: argon2.params().t_cost(), p_cost: argon2.params().p_cost(), salt_b64: salt.to_string(), }; (hash, params) } fn encrypt_secret(&self, raw: &[u8]) -> Result { let mut nonce_bytes = [0u8; 12]; OsRng.fill_bytes(&mut nonce_bytes); let nonce = Nonce::from_slice(&nonce_bytes); let ciphertext = self .cipher .encrypt(nonce, raw) .map_err(|e| Status::internal(format!("encrypt secret: {}", e)))?; let mut combined = nonce_bytes.to_vec(); combined.extend_from_slice(&ciphertext); Ok(STANDARD.encode(combined)) } fn decrypt_secret(&self, enc_b64: &str) -> Result, Status> { let data = STANDARD .decode(enc_b64) .map_err(|e| Status::internal(format!("invalid b64: {}", e)))?; if data.len() < 12 { return Err(Status::internal("ciphertext too short")); } let (nonce_bytes, ct) = data.split_at(12); let nonce = Nonce::from_slice(nonce_bytes); self.cipher .decrypt(nonce, ct) .map_err(|e| Status::internal(format!("decrypt failed: {}", e))) } } #[tonic::async_trait] impl IamCredential for IamCredentialService { async fn create_s3_credential( &self, request: Request, ) -> Result, Status> { let req = request.into_inner(); let now = now_ts(); let (secret_b64, raw_secret) = Self::generate_secret(); let (hash, kdf) = Self::hash_secret(&raw_secret); let secret_enc = self.encrypt_secret(&raw_secret)?; let access_key_id = format!("ak_{}", uuid::Uuid::new_v4()); let record = CredentialRecord { access_key_id: access_key_id.clone(), principal_id: req.principal_id.clone(), created_at: now, expires_at: req.expires_at, revoked: false, description: if req.description.is_empty() { None } else { Some(req.description) }, secret_hash: hash, secret_enc, key_id: self.key_id.clone(), version: 1, kdf, }; self.store .put(&record) .await .map_err(|e| Status::internal(format!("store credential: {}", e)))?; Ok(Response::new(CreateS3CredentialResponse { access_key_id, secret_key: secret_b64, created_at: now, expires_at: req.expires_at, })) } async fn get_secret_key( &self, request: Request, ) -> Result, Status> { let req = request.into_inner(); let record = match self.store.get(&req.access_key_id).await { Ok(Some((rec, _))) => rec, Ok(None) => return Err(Status::not_found("access key not found")), Err(e) => { return Err(Status::internal(format!( "failed to load credential: {}", e ))) } }; if record.revoked { return Err(Status::permission_denied("access key revoked")); } if let Some(exp) = record.expires_at { if now_ts() > exp { return Err(Status::permission_denied("access key expired")); } } let secret = self.decrypt_secret(&record.secret_enc)?; Ok(Response::new(GetSecretKeyResponse { secret_key: STANDARD.encode(secret), principal_id: record.principal_id, expires_at: record.expires_at, })) } async fn list_credentials( &self, request: Request, ) -> Result, Status> { let req = request.into_inner(); let items = self .store .list_for_principal(&req.principal_id, 1000) .await .map_err(|e| Status::internal(format!("list credentials: {}", e)))?; let creds: Vec = items .into_iter() .map(|c| Credential { access_key_id: c.access_key_id, principal_id: c.principal_id, created_at: c.created_at, expires_at: c.expires_at, revoked: c.revoked, description: c.description.unwrap_or_default(), }) .collect(); Ok(Response::new(ListCredentialsResponse { credentials: creds })) } async fn revoke_credential( &self, request: Request, ) -> Result, Status> { let req = request.into_inner(); let revoked = self .store .revoke(&req.access_key_id) .await .map_err(|e| Status::internal(format!("revoke: {}", e)))?; Ok(Response::new(RevokeCredentialResponse { success: revoked })) } } #[cfg(test)] mod tests { use super::*; use base64::engine::general_purpose::STANDARD; use iam_store::Backend; fn test_service() -> IamCredentialService { let backend = Arc::new(Backend::memory()); let store = Arc::new(CredentialStore::new(backend)); let master_key = [0x42u8; 32]; IamCredentialService::new(store, &master_key, "test-key").unwrap() } #[tokio::test] async fn create_and_get_roundtrip() { let svc = test_service(); let create = svc .create_s3_credential(Request::new(CreateS3CredentialRequest { principal_id: "p1".into(), description: "".into(), expires_at: None, })) .await .unwrap() .into_inner(); let get = svc .get_secret_key(Request::new(GetSecretKeyRequest { access_key_id: create.access_key_id.clone(), })) .await .unwrap() .into_inner(); let orig = STANDARD.decode(create.secret_key).unwrap(); let fetched = STANDARD.decode(get.secret_key).unwrap(); assert_eq!(orig, fetched); assert_eq!(get.principal_id, "p1"); } #[tokio::test] async fn list_filters_by_principal() { let svc = test_service(); let a = svc .create_s3_credential(Request::new(CreateS3CredentialRequest { principal_id: "pA".into(), description: "".into(), expires_at: None, })) .await .unwrap() .into_inner(); let _b = svc .create_s3_credential(Request::new(CreateS3CredentialRequest { principal_id: "pB".into(), description: "".into(), expires_at: None, })) .await .unwrap(); let list_a = svc .list_credentials(Request::new(ListCredentialsRequest { principal_id: "pA".into(), })) .await .unwrap() .into_inner(); assert_eq!(list_a.credentials.len(), 1); assert_eq!(list_a.credentials[0].access_key_id, a.access_key_id); } #[tokio::test] async fn revoke_blocks_get() { let svc = test_service(); let created = svc .create_s3_credential(Request::new(CreateS3CredentialRequest { principal_id: "p1".into(), description: "".into(), expires_at: None, })) .await .unwrap() .into_inner(); let revoke1 = svc .revoke_credential(Request::new(RevokeCredentialRequest { access_key_id: created.access_key_id.clone(), reason: "test".into(), })) .await .unwrap() .into_inner(); assert!(revoke1.success); let revoke2 = svc .revoke_credential(Request::new(RevokeCredentialRequest { access_key_id: created.access_key_id.clone(), reason: "again".into(), })) .await .unwrap() .into_inner(); assert!(!revoke2.success); let err = svc .get_secret_key(Request::new(GetSecretKeyRequest { access_key_id: created.access_key_id, })) .await .unwrap_err(); assert_eq!(err.code(), Status::permission_denied("").code()); } #[tokio::test] async fn expired_key_is_denied() { let svc = test_service(); // Manually insert an expired record let expired = CredentialRecord { access_key_id: "expired-ak".into(), principal_id: "p1".into(), created_at: now_ts(), expires_at: Some(now_ts() - 10), revoked: false, description: None, secret_hash: "hash".into(), secret_enc: STANDARD.encode(b"dead"), key_id: "k".into(), version: 1, kdf: Argon2Params { m_cost_kib: 19456, t_cost: 2, p_cost: 1, salt_b64: "c2FsdA==".into(), }, }; svc.store.put(&expired).await.unwrap(); let err = svc .get_secret_key(Request::new(GetSecretKeyRequest { access_key_id: "expired-ak".into(), })) .await .unwrap_err(); assert_eq!(err.code(), Status::permission_denied("").code()); } #[test] fn master_key_length_enforced() { let backend = Arc::new(Backend::memory()); let store = Arc::new(CredentialStore::new(backend)); let bad = IamCredentialService::new(store.clone(), &[0u8; 16], "k"); assert!(bad.is_err()); } }