photoncloud-monorepo/iam/crates/iam-api/src/credential_service.rs

365 lines
12 KiB
Rust

use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use aes_gcm::{aead::Aead, Aes256Gcm, Key, KeyInit, Nonce};
use argon2::{password_hash::{PasswordHasher, SaltString}, Argon2};
use base64::{engine::general_purpose::STANDARD, Engine};
use rand_core::{OsRng, RngCore};
use tonic::{Request, Response, Status};
use iam_store::CredentialStore;
use iam_types::{Argon2Params, CredentialRecord};
use crate::proto::{
iam_credential_server::IamCredential, CreateS3CredentialRequest,
CreateS3CredentialResponse, Credential, GetSecretKeyRequest, GetSecretKeyResponse,
ListCredentialsRequest, ListCredentialsResponse, RevokeCredentialRequest,
RevokeCredentialResponse,
};
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// If the system clock reports a time before 1970 the result is clamped to `0`
/// instead of surfacing an error.
fn now_ts() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        // Clock is behind the epoch: treat it as time zero.
        Err(_) => 0,
    }
}
/// gRPC service that issues, looks up, lists and revokes S3-style credentials.
///
/// Each generated secret is persisted in its [`CredentialRecord`] twice: as an
/// Argon2 hash and as an AES-256-GCM ciphertext, which `decrypt_secret` can turn
/// back into the plaintext.
pub struct IamCredentialService {
    /// Backing store for credential records.
    store: Arc<CredentialStore>,
    /// AES-256-GCM cipher built from the master key given to [`IamCredentialService::new`].
    cipher: Aes256Gcm,
    /// Label of the master key, written into every record this service creates.
    key_id: String,
}

impl IamCredentialService {
    /// Size in bytes of the AES-GCM nonce (96 bits) stored in front of every ciphertext.
    const NONCE_LEN: usize = 12;

    /// Size in bytes of newly generated secret keys (256 bits).
    const SECRET_LEN: usize = 32;

    /// Size in bytes required for the AES-256 master key.
    const MASTER_KEY_LEN: usize = 32;

    /// Creates the service.
    ///
    /// * `store` - credential persistence layer.
    /// * `master_key` - raw AES-256 key; must be exactly 32 bytes.
    /// * `key_id` - label recorded alongside each encrypted secret.
    ///
    /// # Errors
    /// Returns `failed_precondition` if `master_key` is not 32 bytes long.
    pub fn new(
        store: Arc<CredentialStore>,
        master_key: &[u8],
        key_id: &str,
    ) -> Result<Self, Status> {
        if master_key.len() != Self::MASTER_KEY_LEN {
            return Err(Status::failed_precondition(
                "IAM_CRED_MASTER_KEY must be 32 bytes",
            ));
        }
        // The length was checked above, so `from_slice` cannot panic here.
        let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(master_key));
        Ok(Self {
            store,
            cipher,
            key_id: key_id.to_string(),
        })
    }

    /// Generates a new random secret key.
    ///
    /// Returns `(secret_b64, raw)`: the standard-base64 encoding handed back to the
    /// client, and the raw bytes that are subsequently hashed and encrypted.
    fn generate_secret() -> (String, Vec<u8>) {
        // Secret material is drawn directly from the OS CSPRNG. A UUIDv4 is a poor
        // source for key material: only 122 of its 128 bits are random (the
        // version and variant bits are fixed).
        let mut raw = vec![0u8; Self::SECRET_LEN];
        OsRng.fill_bytes(&mut raw);
        let secret_b64 = STANDARD.encode(&raw);
        (secret_b64, raw)
    }

    /// Hashes `raw` with `Argon2::default()` and a freshly generated random salt.
    ///
    /// Returns the encoded password-hash string together with the cost parameters
    /// and salt that were used, so they can be stored with the record.
    ///
    /// # Panics
    /// Panics if Argon2 rejects the input. With default parameters and a freshly
    /// generated salt this is not expected to happen.
    fn hash_secret(raw: &[u8]) -> (String, Argon2Params) {
        let salt = SaltString::generate(&mut OsRng);
        let argon2 = Argon2::default();
        let hash = argon2
            .hash_password(raw, &salt)
            .expect("argon2 hash")
            .to_string();
        let params = argon2.params();
        let kdf = Argon2Params {
            m_cost_kib: params.m_cost(),
            t_cost: params.t_cost(),
            p_cost: params.p_cost(),
            salt_b64: salt.to_string(),
        };
        (hash, kdf)
    }

    /// Encrypts `raw` with AES-256-GCM under a fresh random nonce.
    ///
    /// The result is the standard-base64 encoding of `nonce || ciphertext`, the
    /// layout that [`IamCredentialService::decrypt_secret`] expects.
    ///
    /// # Errors
    /// Returns `internal` if encryption fails.
    fn encrypt_secret(&self, raw: &[u8]) -> Result<String, Status> {
        let mut nonce_bytes = [0u8; Self::NONCE_LEN];
        OsRng.fill_bytes(&mut nonce_bytes);
        let nonce = Nonce::from_slice(&nonce_bytes);
        let ciphertext = self
            .cipher
            .encrypt(nonce, raw)
            .map_err(|e| Status::internal(format!("encrypt secret: {}", e)))?;
        let mut combined = Vec::with_capacity(Self::NONCE_LEN + ciphertext.len());
        combined.extend_from_slice(&nonce_bytes);
        combined.extend_from_slice(&ciphertext);
        Ok(STANDARD.encode(combined))
    }

    /// Reverses [`IamCredentialService::encrypt_secret`].
    ///
    /// Decodes the base64 payload, splits off the leading nonce and decrypts the rest.
    ///
    /// # Errors
    /// Returns `internal` for invalid base64, a payload shorter than a nonce, or an
    /// authentication/decryption failure (for example, a different master key).
    fn decrypt_secret(&self, enc_b64: &str) -> Result<Vec<u8>, Status> {
        let data = STANDARD
            .decode(enc_b64)
            .map_err(|e| Status::internal(format!("invalid b64: {}", e)))?;
        if data.len() < Self::NONCE_LEN {
            return Err(Status::internal("ciphertext too short"));
        }
        let (nonce_bytes, ct) = data.split_at(Self::NONCE_LEN);
        let nonce = Nonce::from_slice(nonce_bytes);
        self.cipher
            .decrypt(nonce, ct)
            .map_err(|e| Status::internal(format!("decrypt failed: {}", e)))
    }
}
#[tonic::async_trait]
impl IamCredential for IamCredentialService {
    /// Issues a new access key / secret key pair for `principal_id`.
    ///
    /// The generated secret is persisted as an Argon2 hash plus an AES-GCM
    /// ciphertext. Its base64 plaintext is returned in the response.
    ///
    /// # Errors
    /// * `invalid_argument` - `principal_id` is empty, or `expires_at` is not later
    ///   than the current time.
    /// * `internal` - encrypting the secret or writing the record failed.
    async fn create_s3_credential(
        &self,
        request: Request<CreateS3CredentialRequest>,
    ) -> Result<Response<CreateS3CredentialResponse>, Status> {
        let req = request.into_inner();
        if req.principal_id.is_empty() {
            return Err(Status::invalid_argument("principal_id must not be empty"));
        }
        let now = now_ts();
        let expires_at = req.expires_at;
        // A credential whose expiry is not after "now" would be unusable at (or
        // within a second of) issuance, so reject it instead of storing it.
        if matches!(expires_at, Some(exp) if exp <= now) {
            return Err(Status::invalid_argument("expires_at must be in the future"));
        }

        let (secret_b64, raw_secret) = Self::generate_secret();
        let (secret_hash, kdf) = Self::hash_secret(&raw_secret);
        let secret_enc = self.encrypt_secret(&raw_secret)?;
        let access_key_id = format!("ak_{}", uuid::Uuid::new_v4());

        let record = CredentialRecord {
            access_key_id: access_key_id.clone(),
            principal_id: req.principal_id,
            created_at: now,
            expires_at,
            revoked: false,
            // An empty string on the wire means "no description".
            description: Some(req.description).filter(|d| !d.is_empty()),
            secret_hash,
            secret_enc,
            key_id: self.key_id.clone(),
            version: 1,
            kdf,
        };
        self.store
            .put(&record)
            .await
            .map_err(|e| Status::internal(format!("store credential: {}", e)))?;

        Ok(Response::new(CreateS3CredentialResponse {
            access_key_id,
            secret_key: secret_b64,
            created_at: now,
            expires_at,
        }))
    }

    /// Returns the decrypted secret for `access_key_id` together with its owner.
    ///
    /// The secret is re-encoded as standard base64.
    ///
    /// # Errors
    /// * `not_found` - no record exists for the access key.
    /// * `permission_denied` - the key is revoked, or the current time is past `expires_at`.
    /// * `internal` - the store lookup or decryption failed.
    async fn get_secret_key(
        &self,
        request: Request<GetSecretKeyRequest>,
    ) -> Result<Response<GetSecretKeyResponse>, Status> {
        let req = request.into_inner();
        let record = match self.store.get(&req.access_key_id).await {
            Ok(Some((rec, _))) => rec,
            Ok(None) => return Err(Status::not_found("access key not found")),
            Err(e) => {
                return Err(Status::internal(format!(
                    "failed to load credential: {}",
                    e
                )))
            }
        };
        if record.revoked {
            return Err(Status::permission_denied("access key revoked"));
        }
        // The key is still valid during the `expires_at` second itself.
        if let Some(exp) = record.expires_at {
            if now_ts() > exp {
                return Err(Status::permission_denied("access key expired"));
            }
        }
        let secret = self.decrypt_secret(&record.secret_enc)?;
        Ok(Response::new(GetSecretKeyResponse {
            secret_key: STANDARD.encode(secret),
            principal_id: record.principal_id,
            expires_at: record.expires_at,
        }))
    }

    /// Lists the credentials owned by `principal_id`, capped at 1000 entries.
    ///
    /// Only metadata is returned. Secrets (hashed or encrypted) are not included.
    ///
    /// # Errors
    /// * `invalid_argument` - `principal_id` is empty.
    /// * `internal` - the store query failed.
    async fn list_credentials(
        &self,
        request: Request<ListCredentialsRequest>,
    ) -> Result<Response<ListCredentialsResponse>, Status> {
        let req = request.into_inner();
        // Do not hand an empty principal to the per-principal query.
        // NOTE(review): depending on how the store keys records, "" might act
        // like a match-all prefix; confirm against iam_store.
        if req.principal_id.is_empty() {
            return Err(Status::invalid_argument("principal_id must not be empty"));
        }
        let items = self
            .store
            .list_for_principal(&req.principal_id, 1000)
            .await
            .map_err(|e| Status::internal(format!("list credentials: {}", e)))?;
        let credentials = items
            .into_iter()
            .map(|c| Credential {
                access_key_id: c.access_key_id,
                principal_id: c.principal_id,
                created_at: c.created_at,
                expires_at: c.expires_at,
                revoked: c.revoked,
                description: c.description.unwrap_or_default(),
            })
            .collect();
        Ok(Response::new(ListCredentialsResponse { credentials }))
    }

    /// Marks `access_key_id` as revoked.
    ///
    /// `success` is whatever the store's `revoke` returns. The tests in this file
    /// show it is `false` when the key was already revoked.
    /// NOTE(review): the request's `reason` field is not passed to the store.
    ///
    /// # Errors
    /// Returns `internal` if the store operation fails.
    async fn revoke_credential(
        &self,
        request: Request<RevokeCredentialRequest>,
    ) -> Result<Response<RevokeCredentialResponse>, Status> {
        let req = request.into_inner();
        let revoked = self
            .store
            .revoke(&req.access_key_id)
            .await
            .map_err(|e| Status::internal(format!("revoke: {}", e)))?;
        Ok(Response::new(RevokeCredentialResponse { success: revoked }))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use base64::engine::general_purpose::STANDARD;
    use iam_store::Backend;

    /// Builds an empty credential store on top of the in-memory backend.
    fn memory_store() -> Arc<CredentialStore> {
        Arc::new(CredentialStore::new(Arc::new(Backend::memory())))
    }

    /// Service under test: in-memory store and a constant 32-byte master key.
    fn test_service() -> IamCredentialService {
        IamCredentialService::new(memory_store(), &[0x42u8; 32], "test-key").unwrap()
    }

    /// Issues a credential for `principal` with no description and no expiry.
    async fn issue(svc: &IamCredentialService, principal: &str) -> CreateS3CredentialResponse {
        let request = Request::new(CreateS3CredentialRequest {
            principal_id: principal.into(),
            description: "".into(),
            expires_at: None,
        });
        svc.create_s3_credential(request).await.unwrap().into_inner()
    }

    /// Calls `get_secret_key` for `access_key_id`, unwrapping the response on success.
    async fn fetch(
        svc: &IamCredentialService,
        access_key_id: &str,
    ) -> Result<GetSecretKeyResponse, Status> {
        let request = Request::new(GetSecretKeyRequest {
            access_key_id: access_key_id.into(),
        });
        svc.get_secret_key(request).await.map(Response::into_inner)
    }

    /// Calls `revoke_credential` and reports its `success` flag.
    async fn revoke(svc: &IamCredentialService, access_key_id: &str, reason: &str) -> bool {
        let request = Request::new(RevokeCredentialRequest {
            access_key_id: access_key_id.into(),
            reason: reason.into(),
        });
        svc.revoke_credential(request)
            .await
            .unwrap()
            .into_inner()
            .success
    }

    #[tokio::test]
    async fn create_and_get_roundtrip() {
        let svc = test_service();
        let created = issue(&svc, "p1").await;
        let fetched = fetch(&svc, &created.access_key_id).await.unwrap();

        let issued_secret = STANDARD.decode(created.secret_key).unwrap();
        let fetched_secret = STANDARD.decode(fetched.secret_key).unwrap();
        assert_eq!(issued_secret, fetched_secret);
        assert_eq!(fetched.principal_id, "p1");
    }

    #[tokio::test]
    async fn list_filters_by_principal() {
        let svc = test_service();
        let mine = issue(&svc, "pA").await;
        let _ = issue(&svc, "pB").await;

        let listed = svc
            .list_credentials(Request::new(ListCredentialsRequest {
                principal_id: "pA".into(),
            }))
            .await
            .unwrap()
            .into_inner()
            .credentials;
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].access_key_id, mine.access_key_id);
    }

    #[tokio::test]
    async fn revoke_blocks_get() {
        let svc = test_service();
        let access_key = issue(&svc, "p1").await.access_key_id;

        assert!(revoke(&svc, &access_key, "test").await);
        // Revoking the same key a second time reports no change.
        assert!(!revoke(&svc, &access_key, "again").await);

        let err = fetch(&svc, &access_key).await.unwrap_err();
        assert_eq!(err.code(), Status::permission_denied("").code());
    }

    #[tokio::test]
    async fn expired_key_is_denied() {
        let svc = test_service();
        // Write an already-expired record straight into the store.
        let now = now_ts();
        let expired = CredentialRecord {
            access_key_id: "expired-ak".into(),
            principal_id: "p1".into(),
            created_at: now,
            expires_at: Some(now - 10),
            revoked: false,
            description: None,
            secret_hash: "hash".into(),
            secret_enc: STANDARD.encode(b"dead"),
            key_id: "k".into(),
            version: 1,
            kdf: Argon2Params {
                m_cost_kib: 19456,
                t_cost: 2,
                p_cost: 1,
                salt_b64: "c2FsdA==".into(),
            },
        };
        svc.store.put(&expired).await.unwrap();

        let err = fetch(&svc, "expired-ak").await.unwrap_err();
        assert_eq!(err.code(), Status::permission_denied("").code());
    }

    #[test]
    fn master_key_length_enforced() {
        let short_key = [0u8; 16];
        assert!(IamCredentialService::new(memory_store(), &short_key, "k").is_err());
    }
}