// photoncloud-monorepo/lightningstor/crates/lightningstor-server/src/s3/auth.rs
// (Rust, 1500 lines, 50 KiB)

//! AWS Signature Version 4 authentication for S3 API
//!
//! Implements simplified SigV4 authentication compatible with AWS S3 SDKs and aws-cli.
//! Integrates with IAM for access key validation.
use crate::config::S3AuthConfig;
use crate::tenant::TenantContext;
use axum::{
body::{Body, Bytes},
extract::Request,
http::{HeaderMap, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use hmac::{Hmac, Mac};
use iam_api::proto::{iam_credential_client::IamCredentialClient, GetSecretKeyRequest};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration as StdDuration, Instant};
use tokio::sync::{Mutex, RwLock};
use tonic::{transport::Channel, Request as TonicRequest};
use tracing::{debug, warn};
use url::form_urlencoded;
/// HMAC keyed with SHA-256: the MAC primitive used for SigV4 key derivation
/// and for the final request signature.
type HmacSha256 = Hmac<Sha256>;

/// Request extension holding the body bytes that were buffered and hashed
/// while verifying the request signature.
#[derive(Debug, Clone)]
pub(crate) struct VerifiedBodyBytes(pub Bytes);

/// Request extension holding the client-supplied `x-amz-content-sha256`
/// value that stood in for hashing the body.
#[derive(Debug, Clone)]
pub(crate) struct VerifiedPayloadHash(pub String);

/// Request extension holding the tenant (org/project) of the credential that
/// signed the request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct VerifiedTenantContext(pub TenantContext);
/// Decide whether the request body must be buffered so it can be hashed.
///
/// Returns `true` only when no payload-hash value was supplied; any supplied
/// value (including sentinels such as `UNSIGNED-PAYLOAD`) means the body can
/// be streamed through untouched.
fn should_buffer_auth_body(payload_hash_header: Option<&str>) -> bool {
    match payload_hash_header {
        Some(_) => false,
        None => true,
    }
}
/// Shared configuration and credential resolver for the SigV4 middleware.
#[derive(Clone)]
pub struct AuthState {
    /// Credential resolver; `AuthState::disabled` leaves it unset.
    iam_client: Option<Arc<RwLock<IamClient>>>,
    /// When `false`, requests are forwarded without any signature check.
    pub enabled: bool,
    /// Region used when deriving the SigV4 signing key (e.g. `us-east-1`).
    aws_region: String,
    /// Service name used when deriving the signing key; the constructors set `s3`.
    aws_service: String,
    /// Upper bound, in bytes, on a request body buffered for payload hashing.
    max_auth_body_bytes: usize,
}
/// Resolves S3 access key IDs to their secret and tenant scope, either from
/// environment variables or from the IAM credential gRPC service.
pub struct IamClient {
    /// Which credential source this client uses.
    mode: IamClientMode,
    /// Successful gRPC lookups keyed by access key ID.
    credential_cache: Arc<RwLock<HashMap<String, CachedCredential>>>,
    /// How long a cached gRPC lookup is served before asking IAM again.
    cache_ttl: StdDuration,
}
enum IamClientMode {
Env {
credentials: std::collections::HashMap<String, String>,
default_org_id: Option<String>,
default_project_id: Option<String>,
},
Grpc {
endpoint: String,
channel: Arc<Mutex<Option<Channel>>>,
},
}
/// Credential material resolved for an access key ID.
///
/// `Debug` is implemented by hand so the secret key is never rendered into
/// logs or panic messages (for example from a failing `assert_eq!`).
#[derive(Clone, PartialEq, Eq)]
pub(crate) struct ResolvedCredential {
    /// Secret access key used to derive the SigV4 signing key.
    pub secret_key: String,
    /// Principal that owns the access key.
    pub principal_id: String,
    /// Owning organization, when the credential is tenant-scoped.
    pub org_id: Option<String>,
    /// Owning project, when the credential is tenant-scoped.
    pub project_id: Option<String>,
}

impl std::fmt::Debug for ResolvedCredential {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResolvedCredential")
            .field("secret_key", &"<redacted>")
            .field("principal_id", &self.principal_id)
            .field("org_id", &self.org_id)
            .field("project_id", &self.project_id)
            .finish()
    }
}
/// A gRPC credential lookup result together with the moment it was stored,
/// used to expire entries after the client's cache TTL.
struct CachedCredential {
    /// The credential returned by IAM.
    credential: ResolvedCredential,
    /// When the entry was inserted into the cache.
    cached_at: Instant,
}
impl IamClient {
    /// Create a new IAM client using the default [`S3AuthConfig`].
    ///
    /// A non-blank `iam_endpoint` selects the IAM gRPC credential service;
    /// otherwise credentials are loaded from environment variables.
    pub fn new(iam_endpoint: Option<String>) -> Self {
        Self::new_with_config(iam_endpoint, &S3AuthConfig::default())
    }

    /// Create a new IAM client from `config`.
    ///
    /// `config.iam_cache_ttl_secs` sets how long gRPC lookups are cached. In
    /// environment mode, `config.default_org_id` / `config.default_project_id`
    /// become the tenant of every credential.
    pub fn new_with_config(iam_endpoint: Option<String>, config: &S3AuthConfig) -> Self {
        let cache_ttl = StdDuration::from_secs(config.iam_cache_ttl_secs);
        let endpoint = iam_endpoint
            .map(|value| normalize_iam_endpoint(&value))
            .filter(|value| !value.is_empty());
        let mode = match endpoint {
            Some(endpoint) => IamClientMode::Grpc {
                endpoint,
                channel: Arc::new(Mutex::new(None)),
            },
            None => IamClientMode::Env {
                credentials: Self::load_env_credentials(),
                default_org_id: config.default_org_id.clone(),
                default_project_id: config.default_project_id.clone(),
            },
        };
        Self {
            mode,
            credential_cache: Arc::new(RwLock::new(HashMap::new())),
            cache_ttl,
        }
    }

    /// Load credentials from environment variables for fallback/testing.
    ///
    /// Supports two formats, checked in this order:
    /// 1. Multiple credentials: `S3_CREDENTIALS="key1:secret1,key2:secret2,..."`
    ///    (the secret is everything after the first `:`).
    /// 2. Single credential: `S3_ACCESS_KEY_ID` + `S3_SECRET_KEY` (legacy), used
    ///    only when format 1 yields nothing.
    ///
    /// Keys and secrets are trimmed. Entries with an empty key or secret are
    /// skipped, because an empty secret would let anyone who knows the key ID
    /// produce valid signatures. Warnings identify entries by position only and
    /// never echo their text, since a malformed entry may still hold a secret.
    fn load_env_credentials() -> HashMap<String, String> {
        let mut credentials = HashMap::new();

        // Option 1: multiple credentials via S3_CREDENTIALS.
        if let Ok(creds_str) = std::env::var("S3_CREDENTIALS") {
            for (index, pair) in creds_str.split(',').enumerate() {
                let position = index + 1;
                let pair = pair.trim();
                if pair.is_empty() {
                    // Tolerate trailing or doubled commas.
                    continue;
                }
                match pair.split_once(':') {
                    Some((access_key, secret_key)) => {
                        let access_key = access_key.trim();
                        let secret_key = secret_key.trim();
                        if access_key.is_empty() || secret_key.is_empty() {
                            warn!(
                                "Ignoring S3_CREDENTIALS entry #{}: access key and secret must be non-empty",
                                position
                            );
                            continue;
                        }
                        credentials.insert(access_key.to_string(), secret_key.to_string());
                    }
                    None => {
                        warn!(
                            "Ignoring S3_CREDENTIALS entry #{}: expected access_key:secret_key",
                            position
                        );
                    }
                }
            }
            if !credentials.is_empty() {
                debug!(
                    "Loaded {} S3 credential(s) from S3_CREDENTIALS",
                    credentials.len()
                );
            }
        }

        // Option 2: single credential via separate env vars (legacy support).
        if credentials.is_empty() {
            if let (Ok(access_key_id), Ok(secret_key)) = (
                std::env::var("S3_ACCESS_KEY_ID"),
                std::env::var("S3_SECRET_KEY"),
            ) {
                let (access_key_id, secret_key) = (access_key_id.trim(), secret_key.trim());
                if access_key_id.is_empty() || secret_key.is_empty() {
                    warn!("Ignoring S3_ACCESS_KEY_ID/S3_SECRET_KEY: access key and secret must be non-empty");
                } else {
                    credentials.insert(access_key_id.to_string(), secret_key.to_string());
                    debug!("Loaded S3 credentials from S3_ACCESS_KEY_ID/S3_SECRET_KEY");
                }
            }
        }

        if credentials.is_empty() {
            warn!("No S3 credentials configured. Auth will reject all requests.");
            warn!("Set S3_CREDENTIALS or S3_ACCESS_KEY_ID/S3_SECRET_KEY to enable access.");
        }
        credentials
    }

    /// Test hook: the static credential map, or `None` in gRPC mode.
    #[cfg(test)]
    fn env_credentials(&self) -> Option<&HashMap<String, String>> {
        match &self.mode {
            IamClientMode::Env { credentials, .. } => Some(credentials),
            IamClientMode::Grpc { .. } => None,
        }
    }

    /// Tenant assigned to environment credentials: the configured defaults.
    fn env_default_tenant(
        default_org_id: Option<String>,
        default_project_id: Option<String>,
    ) -> (Option<String>, Option<String>) {
        (default_org_id, default_project_id)
    }

    /// Validate `access_key_id` and resolve its secret and tenant scope.
    ///
    /// Environment mode answers from the static map with the configured
    /// default tenant. gRPC mode serves an unexpired cache entry when present,
    /// and otherwise asks the IAM credential service and caches the result.
    ///
    /// # Errors
    /// Returns a message when the key is unknown, when the IAM lookup fails,
    /// or when IAM returns an empty secret.
    pub async fn get_credential(&self, access_key_id: &str) -> Result<ResolvedCredential, String> {
        match &self.mode {
            IamClientMode::Env {
                credentials,
                default_org_id,
                default_project_id,
            } => {
                let secret_key = credentials
                    .get(access_key_id)
                    .cloned()
                    .ok_or_else(|| "Access key ID not found".to_string())?;
                let (org_id, project_id) =
                    Self::env_default_tenant(default_org_id.clone(), default_project_id.clone());
                Ok(ResolvedCredential {
                    secret_key,
                    principal_id: access_key_id.to_string(),
                    org_id,
                    project_id,
                })
            }
            IamClientMode::Grpc { endpoint, channel } => {
                if let Some(credential) = self.cached_credential(access_key_id).await {
                    return Ok(credential);
                }
                let response = self
                    .grpc_get_secret_key(endpoint, channel, access_key_id)
                    .await?
                    .into_inner();
                // Fail closed: with an empty secret, anyone who knows the key
                // ID could derive the signing key.
                if response.secret_key.is_empty() {
                    return Err("IAM returned an empty secret key".to_string());
                }
                let credential = ResolvedCredential {
                    secret_key: response.secret_key,
                    principal_id: response.principal_id,
                    org_id: response.org_id,
                    project_id: response.project_id,
                };
                self.cache_credential(access_key_id, &credential).await;
                Ok(credential)
            }
        }
    }

    /// Return the cached credential for `access_key_id` if it is younger than
    /// the cache TTL. Expired entries are left in place; they are overwritten
    /// on the next successful lookup.
    async fn cached_credential(&self, access_key_id: &str) -> Option<ResolvedCredential> {
        let cache = self.credential_cache.read().await;
        cache
            .get(access_key_id)
            .filter(|entry| entry.cached_at.elapsed() <= self.cache_ttl)
            .map(|entry| entry.credential.clone())
    }

    /// Store (or refresh) the cache entry for `access_key_id`.
    async fn cache_credential(&self, access_key_id: &str, credential: &ResolvedCredential) {
        let mut cache = self.credential_cache.write().await;
        cache.insert(
            access_key_id.to_string(),
            CachedCredential {
                credential: credential.clone(),
                cached_at: Instant::now(),
            },
        );
    }

    /// Return the shared gRPC channel, connecting on first use.
    ///
    /// The mutex is held across the connect, so concurrent callers wait for a
    /// single connection attempt instead of each dialing IAM.
    async fn grpc_channel(
        endpoint: &str,
        channel: &Arc<Mutex<Option<Channel>>>,
    ) -> Result<Channel, String> {
        let mut cached = channel.lock().await;
        if let Some(existing) = cached.as_ref() {
            return Ok(existing.clone());
        }
        let created = Channel::from_shared(endpoint.to_string())
            .map_err(|e| format!("failed to parse IAM credential endpoint: {}", e))?
            .connect()
            .await
            .map_err(|e| format!("failed to connect to IAM credential service: {}", e))?;
        *cached = Some(created.clone());
        Ok(created)
    }

    /// Drop the shared channel so the next lookup reconnects.
    async fn invalidate_grpc_channel(channel: &Arc<Mutex<Option<Channel>>>) {
        let mut cached = channel.lock().await;
        *cached = None;
    }

    /// Ask the IAM credential service for the secret of `access_key_id`.
    ///
    /// Sends `x-iam-admin-token` when an admin token is configured. On a
    /// transport-style failure (`Unavailable`, `Cancelled`, `Unknown`,
    /// `DeadlineExceeded`, `Internal`) the cached channel may be broken, so it
    /// is dropped and the call is retried once on a fresh connection. Other
    /// errors, and any error on the retry, return the status message.
    async fn grpc_get_secret_key(
        &self,
        endpoint: &str,
        channel: &Arc<Mutex<Option<Channel>>>,
        access_key_id: &str,
    ) -> Result<tonic::Response<iam_api::proto::GetSecretKeyResponse>, String> {
        for attempt in 0..2 {
            let grpc_channel = Self::grpc_channel(endpoint, channel).await?;
            let mut client = IamCredentialClient::new(grpc_channel);
            let mut request = TonicRequest::new(GetSecretKeyRequest {
                access_key_id: access_key_id.to_string(),
            });
            if let Some(token) = iam_admin_token() {
                if let Ok(value) = token.parse() {
                    request.metadata_mut().insert("x-iam-admin-token", value);
                }
            }
            match client.get_secret_key(request).await {
                Ok(response) => return Ok(response),
                Err(status)
                    if attempt == 0
                        && matches!(
                            status.code(),
                            tonic::Code::Unavailable
                                | tonic::Code::Cancelled
                                | tonic::Code::Unknown
                                | tonic::Code::DeadlineExceeded
                                | tonic::Code::Internal
                        ) =>
                {
                    Self::invalidate_grpc_channel(channel).await;
                }
                Err(status) => return Err(status.message().to_string()),
            }
        }
        Err("IAM credential lookup exhausted retries".to_string())
    }
}
/// Normalize a configured IAM endpoint into a URI that tonic can dial.
///
/// Surrounding whitespace is trimmed, and `http://` is prepended when no
/// `http://`/`https://` scheme is present. A blank input yields an empty string
/// rather than a bare `"http://"`, so callers that filter out empty endpoints
/// correctly fall back to "no endpoint configured".
fn normalize_iam_endpoint(endpoint: &str) -> String {
    let endpoint = endpoint.trim();
    if endpoint.is_empty() {
        String::new()
    } else if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
        endpoint.to_string()
    } else {
        format!("http://{}", endpoint)
    }
}
/// Admin token sent to the IAM credential service as `x-iam-admin-token`.
///
/// `IAM_ADMIN_TOKEN` takes precedence over `PHOTON_IAM_ADMIN_TOKEN`. A variable
/// that is unset, not valid Unicode, or blank after trimming is skipped, so a
/// blank primary variable no longer hides a configured fallback. Returns the
/// trimmed token, or `None` when neither variable provides one.
fn iam_admin_token() -> Option<String> {
    ["IAM_ADMIN_TOKEN", "PHOTON_IAM_ADMIN_TOKEN"]
        .iter()
        .filter_map(|&name| std::env::var(name).ok())
        .map(|value| value.trim().to_string())
        .find(|value| !value.is_empty())
}
impl AuthState {
    /// Build auth state from the default [`S3AuthConfig`], resolving
    /// credentials through `iam_endpoint` when one is given.
    pub fn new(iam_endpoint: Option<String>) -> Self {
        Self::new_with_config(iam_endpoint, &S3AuthConfig::default())
    }

    /// Build auth state from `config`: enablement, region and body limit come
    /// from the config, and an [`IamClient`] is always attached.
    pub fn new_with_config(iam_endpoint: Option<String>, config: &S3AuthConfig) -> Self {
        let client = IamClient::new_with_config(iam_endpoint, config);
        Self {
            iam_client: Some(Arc::new(RwLock::new(client))),
            enabled: config.enabled,
            aws_region: config.aws_region.clone(),
            aws_service: String::from("s3"),
            max_auth_body_bytes: config.max_auth_body_bytes,
        }
    }

    /// Auth state that forwards every request unchecked (for testing); it has
    /// no credential resolver.
    pub fn disabled() -> Self {
        let defaults = S3AuthConfig::default();
        Self {
            iam_client: None,
            enabled: false,
            aws_region: String::from("us-east-1"),
            aws_service: String::from("s3"),
            max_auth_body_bytes: defaults.max_auth_body_bytes,
        }
    }
}
/// SigV4 authentication middleware.
///
/// When `auth_state.enabled` is false every request is forwarded untouched.
/// Otherwise the request must carry an `Authorization` header and an
/// `x-amz-date` header; the signature is recomputed with the secret resolved
/// for the access key and compared without early exit. On success the request
/// is forwarded with these extensions:
/// - [`VerifiedBodyBytes`] when the body was buffered to hash it;
/// - [`VerifiedPayloadHash`] when the body was not buffered because a *signed*
///   `x-amz-content-sha256` value was supplied;
/// - [`VerifiedTenantContext`] with the credential's org/project.
///
/// Any failure short-circuits with an S3-style XML error response.
pub async fn sigv4_auth_middleware(
    auth_state: Arc<AuthState>,
    // This is required to read the request body
    mut request: Request,
    next: Next,
) -> Response {
    // Skip auth if disabled
    if !auth_state.enabled {
        debug!("S3 auth disabled, allowing request");
        return next.run(request).await;
    }

    let headers = request.headers().clone();
    let method = request.method().to_string();
    // Only the path and query take part in the canonical request. The URI may
    // also carry a scheme and authority (e.g. HTTP/2 or absolute-form
    // requests), which `Uri::to_string` would include.
    let uri = request
        .uri()
        .path_and_query()
        .map(|path_and_query| path_and_query.as_str().to_string())
        .unwrap_or_else(|| "/".to_string());

    // Extract Authorization header
    let auth_header = match headers.get("authorization") {
        Some(header) => match header.to_str() {
            Ok(s) => s,
            Err(_) => {
                return error_response(
                    StatusCode::BAD_REQUEST,
                    "InvalidArgument",
                    "Invalid Authorization header encoding",
                );
            }
        },
        None => {
            return error_response(
                StatusCode::FORBIDDEN,
                "AccessDenied",
                "Authorization header required",
            );
        }
    };
    let (access_key_id, credential_scope, signed_headers_str, provided_signature) =
        match parse_auth_header(auth_header) {
            Ok(val) => val,
            Err(e) => return error_response(StatusCode::BAD_REQUEST, "InvalidArgument", &e),
        };
    let amz_date = match headers.get("x-amz-date") {
        Some(header) => match header.to_str() {
            Ok(s) => s,
            Err(_) => {
                return error_response(
                    StatusCode::BAD_REQUEST,
                    "InvalidArgument",
                    "Invalid x-amz-date header encoding",
                );
            }
        },
        None => {
            return error_response(
                StatusCode::BAD_REQUEST,
                "InvalidArgument",
                "x-amz-date header required",
            );
        }
    };
    // The signing key is derived from the leading YYYYMMDD of x-amz-date, so a
    // value without that prefix is a client error; reject it before any
    // credential lookup.
    let has_date_prefix = amz_date
        .as_bytes()
        .get(..8)
        .map_or(false, |date| date.iter().all(u8::is_ascii_digit));
    if !has_date_prefix {
        return error_response(
            StatusCode::BAD_REQUEST,
            "InvalidArgument",
            "Invalid x-amz-date header format",
        );
    }

    // Resolve the secret key (and tenant) for the access key.
    let credential = if let Some(ref iam) = auth_state.iam_client {
        match iam.read().await.get_credential(&access_key_id).await {
            Ok(credential) => credential,
            Err(e) => {
                warn!("IAM credential validation failed: {}", e);
                return error_response(
                    StatusCode::FORBIDDEN,
                    "InvalidAccessKeyId",
                    "The AWS Access Key Id you provided does not exist in our records",
                );
            }
        }
    } else {
        // NOTE(review): this accepts a fixed secret that is visible in the
        // source. In this file only `AuthState::disabled()` leaves the IAM
        // client unset, so this runs only if `enabled` is switched on for such
        // a state; consider rejecting the request instead.
        warn!("No IAM integration, using dummy secret key if IamClient wasn't initialized.");
        ResolvedCredential {
            secret_key: "dummy_secret_key_for_mvp".to_string(),
            principal_id: access_key_id.clone(),
            org_id: Some("default".to_string()),
            project_id: Some("default".to_string()),
        }
    };
    let secret_key = credential.secret_key.as_str();

    // A client-supplied payload hash may replace hashing the body only when the
    // signature covers it. An unsigned x-amz-content-sha256 could be added or
    // changed without invalidating the signature, so it is ignored: the body is
    // buffered and hashed, and no VerifiedPayloadHash is attached.
    let payload_hash_is_signed = signed_headers_str
        .split(';')
        .any(|name| name.trim().eq_ignore_ascii_case("x-amz-content-sha256"));
    let payload_hash_header = if payload_hash_is_signed {
        headers
            .get("x-amz-content-sha256")
            .and_then(|value| value.to_str().ok())
            .filter(|value| !value.is_empty())
            .map(str::to_string)
    } else {
        None
    };
    let should_buffer_body = should_buffer_auth_body(payload_hash_header.as_deref());
    let body_bytes = if should_buffer_body {
        let (parts, body) = request.into_parts();
        let body_bytes = match axum::body::to_bytes(body, auth_state.max_auth_body_bytes).await {
            Ok(b) => b,
            Err(e) => {
                return error_response(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    "InternalError",
                    &e.to_string(),
                )
            }
        };
        // Rebuild the request so downstream handlers still see the body.
        request = Request::from_parts(parts, Body::from(body_bytes.clone()));
        request
            .extensions_mut()
            .insert(VerifiedBodyBytes(body_bytes.clone()));
        body_bytes
    } else {
        if let Some(payload_hash) = payload_hash_header {
            request
                .extensions_mut()
                .insert(VerifiedPayloadHash(payload_hash));
        }
        Bytes::new()
    };

    let (canonical_request, hashed_payload) =
        match build_canonical_request(&method, &uri, &headers, &body_bytes, &signed_headers_str) {
            Ok(val) => val,
            Err(e) => {
                return error_response(StatusCode::INTERNAL_SERVER_ERROR, "SignatureError", &e)
            }
        };
    debug!(
        method = %method,
        uri = %uri,
        headers = ?headers,
        signed_headers_str = %signed_headers_str,
        hashed_payload = %hashed_payload,
        canonical_request = %canonical_request,
        "SigV4 Canonical Request generated"
    );
    let string_to_sign = build_string_to_sign(amz_date, &credential_scope, &canonical_request);
    debug!(
        amz_date = %amz_date,
        credential_scope = %credential_scope,
        string_to_sign = %string_to_sign,
        "SigV4 String to Sign generated"
    );
    let expected_signature = match compute_sigv4_signature(
        secret_key,
        &method,
        &uri,
        &headers,
        amz_date,
        &body_bytes,
        &credential_scope,
        &signed_headers_str,
        &auth_state.aws_region,
        &auth_state.aws_service,
    ) {
        Ok(sig) => sig,
        Err(e) => return error_response(StatusCode::INTERNAL_SERVER_ERROR, "SignatureError", &e),
    };

    // Compare without early exit so response timing does not reveal how much
    // of a forged signature matched.
    if !constant_time_eq(provided_signature.as_bytes(), expected_signature.as_bytes()) {
        // Only the provided signature is logged: the expected one is a valid
        // signature for this exact request and must not end up in logs.
        debug!(
            "Signature mismatch for access_key={}: provided={}",
            access_key_id, provided_signature
        );
        return error_response(
            StatusCode::FORBIDDEN,
            "SignatureDoesNotMatch",
            "The request signature we calculated does not match the signature you provided",
        );
    }

    match (credential.org_id, credential.project_id) {
        (Some(org_id), Some(project_id)) => {
            request
                .extensions_mut()
                .insert(VerifiedTenantContext(TenantContext { org_id, project_id }));
        }
        _ => {
            return error_response(
                StatusCode::FORBIDDEN,
                "AccessDenied",
                "S3 credential is missing tenant scope",
            );
        }
    }

    // Auth successful
    debug!("SigV4 auth successful for access_key={}", access_key_id);
    next.run(request).await
}

/// Byte-string equality that does not stop at the first differing byte:
/// every byte pair is XOR-ed into an accumulator before the result is checked.
/// Lengths are compared first; they are not secret for hex signatures.
fn constant_time_eq(left: &[u8], right: &[u8]) -> bool {
    if left.len() != right.len() {
        return false;
    }
    left.iter()
        .zip(right)
        .fold(0u8, |diff, (a, b)| diff | (a ^ b))
        == 0
}
/// Parses a SigV4 `Authorization` header.
///
/// Expected shape (whitespace after each comma is optional):
/// `AWS4-HMAC-SHA256 Credential=<key>/<date>/<region>/<service>/aws4_request, SignedHeaders=<h1;h2>, Signature=<hex>`
///
/// Returns `(AccessKeyId, CredentialScope, SignedHeaders, Signature)`, where
/// the credential scope is the four components after the access key, joined
/// by `/`.
///
/// # Errors
/// Returns a message when the algorithm is not `AWS4-HMAC-SHA256`, when a
/// component is missing or empty, or when the credential scope is not
/// `<date>/<region>/<service>/aws4_request`.
fn parse_auth_header(auth_header: &str) -> Result<(String, String, String, String), String> {
    // Require the exact algorithm followed by whitespace; other algorithms
    // (and headers with no algorithm at all) are rejected.
    let components = auth_header
        .trim()
        .strip_prefix("AWS4-HMAC-SHA256")
        .filter(|rest| rest.starts_with(char::is_whitespace))
        .ok_or("Authorization header must use the AWS4-HMAC-SHA256 algorithm")?;

    let fields: HashMap<&str, &str> = components
        .split(',')
        .filter_map(|field| field.trim().split_once('='))
        .collect();

    let credential = fields
        .get("Credential")
        .copied()
        .ok_or("Credential not found in Authorization header")?;
    let (access_key_id, credential_scope) = credential
        .split_once('/')
        .filter(|pair| !pair.0.is_empty())
        .ok_or("Access Key ID not found in Credential")?;

    let scope_parts: Vec<&str> = credential_scope.split('/').collect();
    let scope_is_valid = scope_parts.len() == 4
        && scope_parts.iter().all(|part| !part.is_empty())
        && scope_parts[3] == "aws4_request";
    if !scope_is_valid {
        return Err(
            "Credential scope must be <date>/<region>/<service>/aws4_request".to_string(),
        );
    }

    let signed_headers = fields
        .get("SignedHeaders")
        .copied()
        .filter(|value| !value.is_empty())
        .ok_or("SignedHeaders not found in Authorization header")?;
    let signature = fields
        .get("Signature")
        .copied()
        .filter(|value| !value.is_empty())
        .ok_or("Signature not found in Authorization header")?;

    Ok((
        access_key_id.to_string(),
        credential_scope.to_string(),
        signed_headers.to_string(),
        signature.to_string(),
    ))
}
/// Compute the full AWS Signature Version 4.
#[allow(clippy::too_many_arguments)]
fn compute_sigv4_signature(
secret_key: &str,
method: &str,
uri: &str,
headers: &HeaderMap,
amz_date: &str,
body_bytes: &Bytes,
credential_scope: &str,
signed_headers_str: &str,
aws_region: &str,
aws_service: &str,
) -> Result<String, String> {
let (canonical_request, _hashed_payload) =
build_canonical_request(method, uri, headers, body_bytes, signed_headers_str)?;
let string_to_sign = build_string_to_sign(amz_date, credential_scope, &canonical_request);
let signing_key = get_signing_key(secret_key, amz_date, aws_region, aws_service)?;
let mut mac = HmacSha256::new_from_slice(&signing_key).map_err(|e| e.to_string())?;
mac.update(string_to_sign.as_bytes());
let result = mac.finalize();
Ok(hex::encode(result.into_bytes()))
}
/// Builds the Canonical Request as per the AWS SigV4 specification.
///
/// `uri` is the request target as received: a path plus an optional `?query`.
/// Returns `(canonical_request, hashed_payload)`.
///
/// - Canonical URI: `%XX` escapes already present in the path are decoded and
///   the result is encoded exactly once (S3 paths are not double-encoded), so
///   `/bucket/my%20file` stays `/bucket/my%20file` instead of becoming
///   `/bucket/my%2520file`.
/// - Canonical query string: names and values are URI-encoded, then sorted by
///   encoded name and, for repeated names, by encoded value.
/// - Canonical headers: for each signed header (lowercased, sorted), all of its
///   values are joined with `,`; each value is trimmed and inner whitespace
///   runs are collapsed to a single space.
/// - Hashed payload: the `x-amz-content-sha256` value when that header is
///   signed and non-empty, otherwise the hex SHA-256 of `body_bytes`.
///
/// # Errors
/// Returns a message when a signed header is absent or its value cannot be
/// read as text.
fn build_canonical_request(
    method: &str,
    uri: &str,
    headers: &HeaderMap,
    body_bytes: &Bytes,
    signed_headers_str: &str,
) -> Result<(String, String), String> {
    // Everything after the first '?' is the query (it may itself contain '?').
    let (raw_path, raw_query) = match uri.split_once('?') {
        Some((path, query)) => (path, Some(query)),
        None => (uri, None),
    };

    // Canonical URI
    let canonical_uri = match String::from_utf8(percent_decode_path(raw_path)) {
        Ok(decoded) => url_encode_path(&decoded),
        // Escapes that decode to invalid UTF-8 cannot go through the
        // `&str`-based encoder; fall back to encoding the path as received.
        Err(_) => url_encode_path(raw_path),
    };

    // Canonical Query String
    let canonical_query_string = raw_query.map_or_else(String::new, |query| {
        let mut pairs: Vec<(String, String)> = form_urlencoded::parse(query.as_bytes())
            .map(|(name, value)| (url_encode(&name), url_encode(&value)))
            .collect();
        pairs.sort_unstable();
        pairs
            .iter()
            .map(|(name, value)| format!("{}={}", name, value))
            .collect::<Vec<String>>()
            .join("&")
    });

    // Canonical Headers
    let mut header_names: Vec<String> = signed_headers_str
        .split(';')
        .map(|name| name.trim().to_lowercase())
        .collect();
    header_names.sort();
    let mut canonical_headers = String::new();
    for name in &header_names {
        let mut values: Vec<String> = Vec::new();
        for value in headers.get_all(name.as_str()) {
            let text = value
                .to_str()
                .map_err(|_| format!("Invalid header value for {}", name))?;
            values.push(text.split_whitespace().collect::<Vec<&str>>().join(" "));
        }
        if values.is_empty() {
            return Err(format!("Signed header '{}' not found in request", name));
        }
        canonical_headers.push_str(name);
        canonical_headers.push(':');
        canonical_headers.push_str(&values.join(","));
        canonical_headers.push('\n');
    }

    // Hashed Payload
    let hashed_payload = if signed_headers_str
        .split(';')
        .any(|header| header.trim().eq_ignore_ascii_case("x-amz-content-sha256"))
    {
        headers
            .get("x-amz-content-sha256")
            .and_then(|value| value.to_str().ok())
            .filter(|value| !value.is_empty())
            .map(str::to_string)
            .unwrap_or_else(|| hex::encode(Sha256::digest(body_bytes)))
    } else {
        hex::encode(Sha256::digest(body_bytes))
    };

    // `canonical_headers` already ends with '\n', which yields the blank line
    // the specification requires before the signed-headers list.
    let canonical_request = format!(
        "{}\n{}\n{}\n{}\n{}\n{}",
        method,
        canonical_uri,
        canonical_query_string,
        canonical_headers,
        signed_headers_str,
        hashed_payload
    );
    Ok((canonical_request, hashed_payload))
}

/// Decodes `%XX` escapes (hex digits in either case) in a request path.
/// A `%` not followed by two hex digits is kept literally, and `+` is left
/// as-is (it only means space in form-encoded query strings).
fn percent_decode_path(raw: &str) -> Vec<u8> {
    fn hex_value(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    let bytes = raw.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut index = 0;
    while index < bytes.len() {
        if bytes[index] == b'%' {
            let high = bytes.get(index + 1).copied().and_then(hex_value);
            let low = bytes.get(index + 2).copied().and_then(hex_value);
            if let (Some(high), Some(low)) = (high, low) {
                decoded.push((high << 4) | low);
                index += 3;
                continue;
            }
        }
        decoded.push(bytes[index]);
        index += 1;
    }
    decoded
}
/// Builds the SigV4 string-to-sign: the algorithm name, the request
/// timestamp, the credential scope and the hex SHA-256 of the canonical
/// request, separated by newlines.
fn build_string_to_sign(amz_date: &str, credential_scope: &str, canonical_request: &str) -> String {
    let canonical_request_hash = hex::encode(Sha256::digest(canonical_request.as_bytes()));
    [
        "AWS4-HMAC-SHA256",
        amz_date,
        credential_scope,
        canonical_request_hash.as_str(),
    ]
    .join("\n")
}
/// Derives the SigV4 signing key:
/// `HMAC(HMAC(HMAC(HMAC("AWS4" + secret, date), region), service), "aws4_request")`,
/// where `date` is the first eight characters (`YYYYMMDD`) of `amz_date`.
///
/// # Errors
/// Returns an error, instead of panicking on the slice, when `amz_date` is
/// shorter than eight bytes or byte 8 is not a character boundary.
/// `amz_date` comes straight from a client header, so this must not panic.
fn get_signing_key(
    secret_key: &str,
    amz_date: &str,
    aws_region: &str,
    aws_service: &str,
) -> Result<Vec<u8>, String> {
    let date = amz_date
        .get(..8)
        .ok_or_else(|| "x-amz-date must start with a YYYYMMDD date".to_string())?;
    let k_secret = format!("AWS4{}", secret_key);
    let k_date = hmac_sha256(k_secret.as_bytes(), date)?;
    let k_region = hmac_sha256(&k_date, aws_region)?;
    let k_service = hmac_sha256(&k_region, aws_service)?;
    hmac_sha256(&k_service, "aws4_request")
}
/// Computes HMAC-SHA256 of `data` under `key` and returns the raw MAC bytes.
///
/// # Errors
/// Propagates the HMAC constructor's key error as a string (HMAC itself
/// accepts keys of any length, so this is not expected in practice).
fn hmac_sha256(key: &[u8], data: &str) -> Result<Vec<u8>, String> {
    HmacSha256::new_from_slice(key)
        .map_err(|err| err.to_string())
        .map(|mut mac| {
            mac.update(data.as_bytes());
            mac.finalize().into_bytes().to_vec()
        })
}
/// Percent-encodes `s` following the SigV4 `UriEncode` rules (RFC 3986).
///
/// Bytes in `A-Z a-z 0-9 - _ . ~` pass through unchanged. `/` passes through
/// unless `encode_slash` is set. Every other byte, including each byte of a
/// multi-byte UTF-8 sequence, becomes `%XX` with uppercase hex digits.
fn aws_uri_encode(s: &str, encode_slash: bool) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    s.bytes().fold(String::with_capacity(s.len()), |mut out, byte| {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved || (byte == b'/' && !encode_slash) {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
        out
    })
}
/// Encodes a query-string name or value for the canonical query string.
/// Unlike path encoding, `/` is escaped here as `%2F`.
fn url_encode(s: &str) -> String {
    let encode_slash = true;
    aws_uri_encode(s, encode_slash)
}
/// Encodes a URI path for the SigV4 canonical request.
///
/// `/` separators are kept and every other reserved byte is percent-encoded.
/// An empty path and `/` both canonicalize to `/`.
fn url_encode_path(s: &str) -> String {
    match s {
        "" | "/" => String::from("/"),
        path => aws_uri_encode(path, false),
    }
}
/// Create an S3-style XML error response with the given status.
///
/// `code` and `message` are XML-escaped before being embedded. Some messages
/// include client-controlled text (for example a header name taken from
/// `SignedHeaders`), which must not be able to inject markup into the body.
fn error_response(status: StatusCode, code: &str, message: &str) -> Response {
    /// Escapes the characters that are significant in XML element content.
    fn escape_xml(text: &str) -> String {
        let mut escaped = String::with_capacity(text.len());
        for ch in text.chars() {
            match ch {
                '&' => escaped.push_str("&amp;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                other => escaped.push(other),
            }
        }
        escaped
    }

    let xml = format!(
        "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<Error>\n<Code>{}</Code>\n<Message>{}</Message>\n</Error>",
        escape_xml(code),
        escape_xml(message)
    );
    Response::builder()
        .status(status)
        .header("Content-Type", "application/xml")
        .body(Body::from(xml))
        .expect("a fixed status code and static Content-Type header always build")
        .into_response()
}
#[cfg(test)]
mod tests {
use super::*;
use axum::http::HeaderValue;
use iam_api::proto::{
iam_credential_server::{IamCredential, IamCredentialServer},
CreateS3CredentialRequest, CreateS3CredentialResponse, Credential, GetSecretKeyResponse,
ListCredentialsRequest, ListCredentialsResponse, RevokeCredentialRequest,
RevokeCredentialResponse,
};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Mutex,
};
use tokio::net::TcpListener;
use tokio::time::{sleep, Duration};
use tonic::transport::Server;
use tonic::{Request as TonicRequest, Response as TonicResponse, Status};
static ENV_LOCK: Mutex<()> = Mutex::new(());
#[derive(Clone, Default)]
struct MockIamCredentialService {
secrets: Arc<HashMap<String, String>>,
get_secret_calls: Arc<AtomicUsize>,
}
#[tonic::async_trait]
impl IamCredential for MockIamCredentialService {
async fn create_s3_credential(
&self,
_request: TonicRequest<CreateS3CredentialRequest>,
) -> Result<TonicResponse<CreateS3CredentialResponse>, Status> {
Err(Status::unimplemented("not needed in test"))
}
async fn get_secret_key(
&self,
request: TonicRequest<GetSecretKeyRequest>,
) -> Result<TonicResponse<GetSecretKeyResponse>, Status> {
let access_key_id = request.into_inner().access_key_id;
self.get_secret_calls.fetch_add(1, Ordering::SeqCst);
let Some(secret_key) = self.secrets.get(&access_key_id) else {
return Err(Status::not_found("access key not found"));
};
Ok(TonicResponse::new(GetSecretKeyResponse {
secret_key: secret_key.clone(),
principal_id: "test-principal".to_string(),
expires_at: None,
org_id: Some("test-org".to_string()),
project_id: Some("test-project".to_string()),
principal_kind: iam_api::proto::PrincipalKind::ServiceAccount as i32,
}))
}
async fn list_credentials(
&self,
_request: TonicRequest<ListCredentialsRequest>,
) -> Result<TonicResponse<ListCredentialsResponse>, Status> {
Ok(TonicResponse::new(ListCredentialsResponse {
credentials: Vec::<Credential>::new(),
}))
}
async fn revoke_credential(
&self,
_request: TonicRequest<RevokeCredentialRequest>,
) -> Result<TonicResponse<RevokeCredentialResponse>, Status> {
Ok(TonicResponse::new(RevokeCredentialResponse {
success: true,
}))
}
}
async fn start_mock_iam(secrets: HashMap<String, String>) -> (SocketAddr, Arc<AtomicUsize>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let get_secret_calls = Arc::new(AtomicUsize::new(0));
let service = MockIamCredentialService {
secrets: Arc::new(secrets),
get_secret_calls: get_secret_calls.clone(),
};
drop(listener);
tokio::spawn(async move {
Server::builder()
.add_service(IamCredentialServer::new(service))
.serve(addr)
.await
.unwrap();
});
for _ in 0..20 {
if tokio::net::TcpStream::connect(addr).await.is_ok() {
return (addr, get_secret_calls);
}
sleep(Duration::from_millis(25)).await;
}
panic!("mock IAM server did not start on {}", addr);
}
#[tokio::test]
async fn test_parse_auth_header() {
let auth_header = "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20231201/us-east-1/s3/aws4_request, SignedHeaders=host;x-amz-date, Signature=abc123def456";
let (access_key, credential_scope, signed_headers, signature) =
parse_auth_header(auth_header).unwrap();
assert_eq!(access_key, "AKIAIOSFODNN7EXAMPLE");
assert_eq!(credential_scope, "20231201/us-east-1/s3/aws4_request");
assert_eq!(signed_headers, "host;x-amz-date");
assert_eq!(signature, "abc123def456");
}
#[test]
fn test_hmac_sha256() {
let key = b"key";
let data = "data";
// Verified with: echo -n "data" | openssl dgst -sha256 -mac hmac -macopt key:"key"
let expected =
hex::decode("5031fe3d989c6d1537a013fa6e739da23463fdaec3b70137d828e36ace221bd0")
.unwrap();
assert_eq!(hmac_sha256(key, data).unwrap(), expected);
}
#[test]
fn test_get_signing_key() {
let secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY";
let amz_date = "20231201T000000Z";
let aws_region = "us-east-1";
let aws_service = "s3";
let key = get_signing_key(secret_key, amz_date, aws_region, aws_service).unwrap();
// Expected signing key for these inputs can be found by external tools or AWS docs
// For now, just check it's not empty and has correct length
assert!(!key.is_empty());
assert_eq!(key.len(), 32); // SHA256 output length
}
#[test]
fn test_url_encode() {
assert_eq!(url_encode("foo"), "foo");
assert_eq!(url_encode("foo bar"), "foo%20bar");
assert_eq!(url_encode("foo/bar"), "foo%2Fbar");
}
#[test]
fn test_url_encode_path() {
assert_eq!(url_encode_path("/foo/bar"), "/foo/bar");
assert_eq!(url_encode_path("/foo bar/baz"), "/foo%20bar/baz");
assert_eq!(url_encode_path("/"), "/");
assert_eq!(url_encode_path(""), "/"); // Empty path should be normalized to /
// Test special characters that should be encoded
assert_eq!(url_encode_path("/my+bucket"), "/my%2Bbucket");
assert_eq!(url_encode_path("/my=bucket"), "/my%3Dbucket");
// Test unreserved characters that should NOT be encoded
assert_eq!(
url_encode_path("/my-bucket_test.file~123"),
"/my-bucket_test.file~123"
);
}
#[tokio::test]
async fn test_build_canonical_request() {
let method = "PUT";
let uri = "/mybucket/myobject?param1=value1&param2=value2";
let mut headers = HeaderMap::new();
headers.insert("Host", HeaderValue::from_static("example.com"));
headers.insert("Content-Type", HeaderValue::from_static("application/xml"));
headers.insert("x-amz-date", HeaderValue::from_static("20231201T000000Z"));
let body = Bytes::from("some_body");
let signed_headers = "content-type;host;x-amz-date";
let (canonical_request, hashed_payload) =
build_canonical_request(method, uri, &headers, &body, signed_headers).unwrap();
// Body hash verified with: echo -n "some_body" | sha256sum
let expected_body_hash = "fed42376ceefa4bb65ead687ec9738f6b2329fd78870aaf797bd7194da4228d3";
let expected_canonical_request = format!(
"PUT\n/mybucket/myobject\nparam1=value1&param2=value2\ncontent-type:application/xml\nhost:example.com\nx-amz-date:20231201T000000Z\n\ncontent-type;host;x-amz-date\n{}",
expected_body_hash
);
assert_eq!(canonical_request, expected_canonical_request);
assert_eq!(hashed_payload, expected_body_hash);
}
#[tokio::test]
async fn test_build_canonical_request_prefers_signed_payload_hash_header() {
let method = "PUT";
let uri = "/mybucket/myobject";
let mut headers = HeaderMap::new();
headers.insert("host", HeaderValue::from_static("example.com"));
headers.insert("x-amz-date", HeaderValue::from_static("20231201T000000Z"));
headers.insert(
"x-amz-content-sha256",
HeaderValue::from_static("signed-payload-hash"),
);
let body = Bytes::from("different-body");
let signed_headers = "host;x-amz-content-sha256;x-amz-date";
let (canonical_request, hashed_payload) =
build_canonical_request(method, uri, &headers, &body, signed_headers).unwrap();
assert!(canonical_request.ends_with("\nsigned-payload-hash"));
assert_eq!(hashed_payload, "signed-payload-hash");
}
#[test]
fn test_should_buffer_auth_body_only_when_hash_header_missing() {
assert!(should_buffer_auth_body(None));
assert!(!should_buffer_auth_body(Some("signed-payload-hash")));
assert!(!should_buffer_auth_body(Some("UNSIGNED-PAYLOAD")));
}
#[test]
fn test_build_string_to_sign() {
let amz_date = "20231201T000000Z";
let credential_scope = "20231201/us-east-1/s3/aws4_request";
let canonical_request = "some_canonical_request"; // Hashed below
let string_to_sign = build_string_to_sign(amz_date, credential_scope, canonical_request);
let hashed_canonical_request = hex::encode(Sha256::digest(canonical_request.as_bytes()));
let expected_string_to_sign = format!(
"AWS4-HMAC-SHA256\n{}\n{}\n{}",
amz_date, credential_scope, hashed_canonical_request
);
assert_eq!(string_to_sign, expected_string_to_sign);
}
#[test]
fn test_iam_client_multi_credentials() {
let _guard = ENV_LOCK.lock().unwrap();
// Test parsing S3_CREDENTIALS format
std::env::set_var("S3_CREDENTIALS", "key1:secret1,key2:secret2,key3:secret3");
let client = IamClient::new(None);
let credentials = client.env_credentials().unwrap();
assert_eq!(credentials.len(), 3);
assert_eq!(credentials.get("key1"), Some(&"secret1".to_string()));
assert_eq!(credentials.get("key2"), Some(&"secret2".to_string()));
assert_eq!(credentials.get("key3"), Some(&"secret3".to_string()));
std::env::remove_var("S3_CREDENTIALS");
}
#[test]
fn test_iam_client_single_credentials() {
let _guard = ENV_LOCK.lock().unwrap();
// Test legacy S3_ACCESS_KEY_ID/S3_SECRET_KEY format
std::env::remove_var("S3_CREDENTIALS");
std::env::set_var("S3_ACCESS_KEY_ID", "test_key");
std::env::set_var("S3_SECRET_KEY", "test_secret");
let client = IamClient::new(None);
let credentials = client.env_credentials().unwrap();
assert_eq!(credentials.len(), 1);
assert_eq!(
credentials.get("test_key"),
Some(&"test_secret".to_string())
);
std::env::remove_var("S3_ACCESS_KEY_ID");
std::env::remove_var("S3_SECRET_KEY");
}
#[tokio::test]
async fn test_iam_client_grpc_lookup() {
let (addr, _calls) = start_mock_iam(HashMap::from([(
"grpc_key".to_string(),
"grpc_secret".to_string(),
)]))
.await;
let client = IamClient::new(Some(addr.to_string()));
let credential = client.get_credential("grpc_key").await.unwrap();
assert_eq!(credential.secret_key, "grpc_secret");
assert_eq!(credential.org_id.as_deref(), Some("test-org"));
assert_eq!(credential.project_id.as_deref(), Some("test-project"));
assert_eq!(
client.get_credential("missing").await.unwrap_err(),
"access key not found"
);
}
#[tokio::test]
async fn test_iam_client_grpc_cache_reuses_secret() {
let (addr, calls) = start_mock_iam(HashMap::from([(
"grpc_key".to_string(),
"grpc_secret".to_string(),
)]))
.await;
let client = IamClient::new(Some(addr.to_string()));
assert_eq!(
client.get_credential("grpc_key").await.unwrap().secret_key,
"grpc_secret"
);
assert_eq!(
client.get_credential("grpc_key").await.unwrap().secret_key,
"grpc_secret"
);
assert_eq!(calls.load(Ordering::SeqCst), 1);
}
#[test]
fn test_complete_sigv4_signature() {
// Test with AWS example credentials (from AWS docs)
let secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY";
let method = "GET";
let uri = "/";
let amz_date = "20150830T123600Z";
let credential_scope = "20150830/us-east-1/s3/aws4_request";
let signed_headers = "host;x-amz-date";
let mut headers = HeaderMap::new();
headers.insert(
"host",
HeaderValue::from_static("examplebucket.s3.amazonaws.com"),
);
headers.insert("x-amz-date", HeaderValue::from_static("20150830T123600Z"));
let body = Bytes::new(); // Empty body for GET
// Build canonical request
let (canonical_request, _) =
build_canonical_request(method, uri, &headers, &body, signed_headers).unwrap();
// Build string to sign
let _string_to_sign = build_string_to_sign(amz_date, credential_scope, &canonical_request);
// Compute signature
let signature = compute_sigv4_signature(
secret_key,
method,
uri,
&headers,
amz_date,
&body,
credential_scope,
signed_headers,
"us-east-1",
"s3",
)
.unwrap();
// Verify signature is deterministic (same inputs = same output)
let signature2 = compute_sigv4_signature(
secret_key,
method,
uri,
&headers,
amz_date,
&body,
credential_scope,
signed_headers,
"us-east-1",
"s3",
)
.unwrap();
assert_eq!(signature, signature2);
assert_eq!(signature.len(), 64); // SHA256 hex = 64 chars
}
// =============================================================================
// Security Tests
// =============================================================================
#[test]
fn test_security_invalid_auth_header_format() {
    // Every entry is an Authorization value the parser must reject,
    // paired with the reason it is malformed.
    let rejected = [
        (
            "AWS4-HMAC-SHA256 SignedHeaders=host, Signature=abc123",
            "missing Credential field",
        ),
        (
            "AWS4-HMAC-SHA256 Credential=KEY/scope, Signature=abc123",
            "missing SignedHeaders field",
        ),
        (
            "AWS4-HMAC-SHA256 Credential=KEY/scope, SignedHeaders=host",
            "missing Signature field",
        ),
        (
            "AWS4-HMAC-SHA512 Credential=KEY/scope, SignedHeaders=host, Signature=abc",
            "wrong algorithm",
        ),
        ("", "empty string"),
        ("not-an-auth-header", "random garbage"),
    ];

    for (header, reason) in rejected {
        assert!(
            parse_auth_header(header).is_err(),
            "expected rejection ({reason}): {header:?}"
        );
    }
}
#[test]
fn test_security_signature_changes_with_secret_key() {
    let mut headers = HeaderMap::new();
    headers.insert("host", HeaderValue::from_static("s3.amazonaws.com"));
    headers.insert("x-amz-date", HeaderValue::from_static("20231201T000000Z"));
    let empty_body = Bytes::new();

    // Sign an otherwise identical request, varying only the secret key.
    let sign_with = |secret: &str| {
        compute_sigv4_signature(
            secret,
            "GET",
            "/test-bucket/object",
            &headers,
            "20231201T000000Z",
            &empty_body,
            "20231201/us-east-1/s3/aws4_request",
            "host;x-amz-date",
            "us-east-1",
            "s3",
        )
        .unwrap()
    };

    let first = sign_with("secret1");
    let second = sign_with("secret2");

    // A different secret must never reproduce the same signature.
    assert_ne!(
        first, second,
        "Signatures should differ with different secret keys"
    );
}
#[test]
fn test_security_signature_changes_with_body() {
    let mut headers = HeaderMap::new();
    headers.insert("host", HeaderValue::from_static("s3.amazonaws.com"));
    headers.insert("x-amz-date", HeaderValue::from_static("20231201T000000Z"));

    // Sign a PUT whose only varying input is the payload.
    let sign_payload = |payload: Bytes| {
        compute_sigv4_signature(
            "test-secret-key",
            "PUT",
            "/test-bucket/object",
            &headers,
            "20231201T000000Z",
            &payload,
            "20231201/us-east-1/s3/aws4_request",
            "host;x-amz-date",
            "us-east-1",
            "s3",
        )
        .unwrap()
    };

    let original = sign_payload(Bytes::from("original content"));
    let tampered = sign_payload(Bytes::from("modified content"));

    // Tampering with the body must invalidate the signature.
    assert_ne!(
        original, tampered,
        "Signatures should differ with different bodies"
    );
}
#[test]
fn test_security_signature_changes_with_uri() {
    let mut headers = HeaderMap::new();
    headers.insert("host", HeaderValue::from_static("s3.amazonaws.com"));
    headers.insert("x-amz-date", HeaderValue::from_static("20231201T000000Z"));
    let empty_body = Bytes::new();

    // Sign a GET where the request path is the only thing that changes.
    let sign_path = |path: &str| {
        compute_sigv4_signature(
            "test-secret-key",
            "GET",
            path,
            &headers,
            "20231201T000000Z",
            &empty_body,
            "20231201/us-east-1/s3/aws4_request",
            "host;x-amz-date",
            "us-east-1",
            "s3",
        )
        .unwrap()
    };

    let for_object1 = sign_path("/test-bucket/object1");
    let for_object2 = sign_path("/test-bucket/object2");

    // A signature for one object must not be replayable against another.
    assert_ne!(
        for_object1, for_object2,
        "Signatures should differ with different URIs"
    );
}
#[test]
fn test_security_signature_changes_with_headers() {
    // Builds the same header set, differing only in the signed
    // `x-amz-content-sha256` value.
    let headers_with_hash = |content_hash: &'static str| {
        let mut map = HeaderMap::new();
        map.insert("host", HeaderValue::from_static("s3.amazonaws.com"));
        map.insert("x-amz-date", HeaderValue::from_static("20231201T000000Z"));
        map.insert("x-amz-content-sha256", HeaderValue::from_static(content_hash));
        map
    };
    let first_headers = headers_with_hash("hash1");
    let second_headers = headers_with_hash("hash2");
    let empty_body = Bytes::new();

    let sign_headers = |hdrs: &HeaderMap| {
        compute_sigv4_signature(
            "test-secret-key",
            "GET",
            "/test-bucket/object",
            hdrs,
            "20231201T000000Z",
            &empty_body,
            "20231201/us-east-1/s3/aws4_request",
            "host;x-amz-content-sha256;x-amz-date",
            "us-east-1",
            "s3",
        )
        .unwrap()
    };

    let first = sign_headers(&first_headers);
    let second = sign_headers(&second_headers);

    // Changing the value of a signed header must change the signature.
    assert_ne!(
        first, second,
        "Signatures should differ with different header values"
    );
}
#[test]
fn test_security_signature_changes_with_query_params() {
    let mut headers = HeaderMap::new();
    headers.insert("host", HeaderValue::from_static("s3.amazonaws.com"));
    headers.insert("x-amz-date", HeaderValue::from_static("20231201T000000Z"));
    let empty_body = Bytes::new();

    // Same path, different query string; everything else held constant.
    let sign_request_target = |target: &str| {
        compute_sigv4_signature(
            "test-secret-key",
            "GET",
            target,
            &headers,
            "20231201T000000Z",
            &empty_body,
            "20231201/us-east-1/s3/aws4_request",
            "host;x-amz-date",
            "us-east-1",
            "s3",
        )
        .unwrap()
    };

    let prefix_foo = sign_request_target("/test-bucket/object?prefix=foo");
    let prefix_bar = sign_request_target("/test-bucket/object?prefix=bar");

    // Query parameters are part of what gets signed.
    assert_ne!(
        prefix_foo, prefix_bar,
        "Signatures should differ with different query parameters"
    );
}
#[test]
fn test_security_credential_lookup_unknown_key() {
    /// Removes the listed environment variables when dropped, so the test
    /// cleans up after itself even when an assertion panics mid-way.
    struct EnvCleanup(&'static [&'static str]);
    impl Drop for EnvCleanup {
        fn drop(&mut self) {
            for &key in self.0 {
                std::env::remove_var(key);
            }
        }
    }

    // Serialize env mutation with the other env-based tests. Tolerate a
    // poisoned lock (left by an earlier failing test) so failures don't cascade.
    let _guard = ENV_LOCK
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    // Declared after `_guard`, so it drops first, while the lock is still held.
    let _cleanup = EnvCleanup(&["S3_ACCESS_KEY_ID", "S3_SECRET_KEY"]);

    // Configure a single credential through the single-key env vars.
    std::env::remove_var("S3_CREDENTIALS");
    std::env::set_var("S3_ACCESS_KEY_ID", "known_key");
    std::env::set_var("S3_SECRET_KEY", "known_secret");

    let client = IamClient::new(None);
    let credentials = client.env_credentials().unwrap();

    // Known key should be found in credentials map
    assert_eq!(
        credentials.get("known_key"),
        Some(&"known_secret".to_string())
    );
    // Unknown key should not be found
    assert_eq!(credentials.get("unknown_key"), None);
}
#[test]
fn test_security_empty_credentials() {
    let _guard = ENV_LOCK.lock().unwrap();

    // Unset all three credential env vars this test suite configures, so an
    // env-backed client starts with nothing configured.
    for var in ["S3_CREDENTIALS", "S3_ACCESS_KEY_ID", "S3_SECRET_KEY"] {
        std::env::remove_var(var);
    }

    let client = IamClient::new(None);
    let configured = client.env_credentials().unwrap();
    assert!(configured.is_empty());
}
#[test]
fn test_security_malformed_s3_credentials_env() {
    /// Removes the listed environment variables when dropped, so the test
    /// cleans up after itself even when an assertion panics mid-way.
    struct EnvCleanup(&'static [&'static str]);
    impl Drop for EnvCleanup {
        fn drop(&mut self) {
            for &key in self.0 {
                std::env::remove_var(key);
            }
        }
    }

    // Serialize env mutation with the other env-based tests. Tolerate a
    // poisoned lock (left by an earlier failing test) so failures don't cascade.
    let _guard = ENV_LOCK
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    // Declared after `_guard`, so it drops first, while the lock is still held.
    let _cleanup = EnvCleanup(&["S3_CREDENTIALS"]);

    // The client also accepts a credential via S3_ACCESS_KEY_ID/S3_SECRET_KEY.
    // Clear them so a value inherited from the host environment cannot add
    // entries and skew the exact counts asserted below.
    std::env::remove_var("S3_ACCESS_KEY_ID");
    std::env::remove_var("S3_SECRET_KEY");

    // Test that malformed S3_CREDENTIALS are handled gracefully.
    // Missing colon separator in the first pair.
    std::env::set_var("S3_CREDENTIALS", "key1_secret1,key2:secret2");
    let client = IamClient::new(None);
    let credentials = client.env_credentials().unwrap();
    // Should only parse the valid pair (key2:secret2)
    assert_eq!(credentials.len(), 1);
    assert!(credentials.contains_key("key2"));

    // Empty pairs between separators.
    std::env::set_var("S3_CREDENTIALS", "key1:secret1,,key2:secret2");
    let client2 = IamClient::new(None);
    // Should parse both valid pairs, skip empty
    assert_eq!(client2.env_credentials().unwrap().len(), 2);
}
}