// photoncloud-monorepo/iam/crates/iam-server/src/main.rs

//! IAM Server
//!
//! The main entry point for the IAM gRPC server.
mod config;
mod rest;
use std::sync::Arc;
use std::time::Duration;
use std::time::{SystemTime, UNIX_EPOCH};
use chainfire_client::Client as ChainFireClient;
use clap::Parser;
use metrics_exporter_prometheus::PrometheusBuilder;
use tonic::service::Interceptor;
use tonic::transport::{Certificate, Identity, Server, ServerTlsConfig};
use tonic::{metadata::MetadataMap, Request, Status};
use tonic_health::server::health_reporter;
use tracing::{info, warn};
use iam_api::{
iam_admin_server::IamAdminServer, iam_authz_server::IamAuthzServer,
iam_token_server::IamTokenServer, GatewayAuthServiceImpl, GatewayAuthServiceServer,
IamAdminService, IamAuthzService, IamTokenService,
};
use iam_authn::{InternalTokenConfig, InternalTokenService, SigningKey};
use iam_authz::{PolicyCache, PolicyCacheConfig, PolicyEvaluator};
use iam_store::{Backend, BackendConfig, BindingStore, PrincipalStore, RoleStore, TokenStore};
use config::{BackendKind, ServerConfig};
/// Tonic interceptor that gates the admin gRPC surface behind a shared secret.
///
/// When `token` is `None`, admin-token authentication is disabled and every
/// request passes through (see the `Interceptor` impl).
#[derive(Clone)]
struct AdminTokenInterceptor {
    // Arc so cloning the interceptor does not copy the secret string.
    token: Option<Arc<String>>,
}
impl Interceptor for AdminTokenInterceptor {
    /// Rejects the request unless it carries the configured admin token.
    /// When no token is configured, authentication is disabled and all
    /// requests are admitted.
    fn call(&mut self, request: Request<()>) -> Result<Request<()>, Status> {
        let Some(expected) = &self.token else {
            // No token configured: admin API runs unauthenticated.
            return Ok(request);
        };
        if admin_token_valid(request.metadata(), expected) {
            Ok(request)
        } else {
            Err(Status::unauthenticated(
                "missing or invalid IAM admin token",
            ))
        }
    }
}
/// Reads the IAM admin API token from the environment.
///
/// `IAM_ADMIN_TOKEN` takes precedence; `PHOTON_IAM_ADMIN_TOKEN` is the
/// fallback. The value is trimmed, and a blank value counts as unset.
fn load_admin_token() -> Option<String> {
    let raw = std::env::var("IAM_ADMIN_TOKEN")
        .or_else(|_| std::env::var("PHOTON_IAM_ADMIN_TOKEN"))
        .ok()?;
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_string())
    }
}
/// Returns true when `metadata` presents the expected admin token, either via
/// the dedicated `x-iam-admin-token` header or as an `authorization` bearer
/// token (accepting both `Bearer` and `bearer` prefixes).
fn admin_token_valid(metadata: &MetadataMap, token: &str) -> bool {
    // The dedicated header is checked first.
    let direct_match = metadata
        .get("x-iam-admin-token")
        .and_then(|value| value.to_str().ok())
        .is_some_and(|raw| raw.trim() == token);
    if direct_match {
        return true;
    }
    // Fall back to `Authorization: Bearer <token>`.
    let Some(header) = metadata
        .get("authorization")
        .and_then(|value| value.to_str().ok())
    else {
        return false;
    };
    let header = header.trim();
    header
        .strip_prefix("Bearer ")
        .or_else(|| header.strip_prefix("bearer "))
        .map(|rest| rest.trim() == token)
        .unwrap_or(false)
}
/// IAM Server
// NOTE(review): the `///` comments on the fields below are surfaced by clap
// as `--help` text; treat them as user-facing strings, not internal docs.
#[derive(Parser, Debug)]
#[command(name = "iam-server")]
#[command(about = "Identity and Access Management Server")]
struct Args {
    /// Configuration file path
    #[arg(short, long)]
    config: Option<String>,
    /// Listen address (overrides config)
    #[arg(long)]
    addr: Option<String>,
    /// Log level (overrides config)
    #[arg(long)]
    log_level: Option<String>,
    /// ChainFire endpoint for cluster coordination (overrides config)
    #[arg(long)]
    chainfire_endpoint: Option<String>,
    /// Metrics port for Prometheus scraping (default: 9093)
    #[arg(long, default_value = "9093")]
    metrics_port: u16,
}
/// Process entry point: loads configuration, wires stores and gRPC services,
/// then runs the gRPC server and the HTTP REST server concurrently until one
/// of them exits or fails.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let args = Args::parse();
    // Load configuration: an explicit --config file wins over env-derived config.
    let mut config = match &args.config {
        Some(path) => ServerConfig::from_file(path)?,
        None => ServerConfig::from_env()?,
    };
    // Apply CLI overrides (CLI flags take precedence over file/env values).
    if let Some(addr) = args.addr {
        config.server.addr = addr.parse()?;
    }
    if let Some(level) = args.log_level {
        config.logging.level = level;
    }
    if let Some(endpoint) = args.chainfire_endpoint {
        config.cluster.chainfire_endpoint = Some(endpoint);
    }
    // Initialize logging
    init_logging(&config.logging.level);
    // Initialize Prometheus metrics exporter
    // Serves metrics at http://0.0.0.0:{metrics_port}/metrics
    let metrics_addr = format!("0.0.0.0:{}", args.metrics_port);
    let builder = PrometheusBuilder::new();
    builder
        .with_http_listener(metrics_addr.parse::<std::net::SocketAddr>()?)
        .install()
        .expect("Failed to install Prometheus metrics exporter");
    info!(
        "Prometheus metrics available at http://{}/metrics",
        metrics_addr
    );
    // Register common metrics up front so they appear in /metrics with help
    // text even before the first increment.
    metrics::describe_counter!(
        "iam_authz_requests_total",
        "Total number of authorization requests"
    );
    metrics::describe_counter!(
        "iam_authz_allowed_total",
        "Total number of allowed authorization requests"
    );
    metrics::describe_counter!(
        "iam_authz_denied_total",
        "Total number of denied authorization requests"
    );
    metrics::describe_counter!("iam_token_issued_total", "Total number of tokens issued");
    metrics::describe_histogram!(
        "iam_request_duration_seconds",
        "Request duration in seconds"
    );
    info!("Starting IAM server on {}", config.server.addr);
    // Best-effort ChainFire membership registration; runs in the background
    // and only logs on failure, so it never blocks startup.
    if let Some(endpoint) = &config.cluster.chainfire_endpoint {
        let normalized = normalize_chainfire_endpoint(endpoint);
        info!(
            "Cluster coordination enabled via ChainFire at {}",
            normalized
        );
        let addr = config.server.addr.to_string();
        tokio::spawn(async move {
            if let Err(error) = register_chainfire_membership(&normalized, "iam", addr).await {
                warn!(error = %error, "ChainFire membership registration failed");
            }
        });
    }
    // Create backend
    let backend = create_backend(&config).await?;
    let backend = Arc::new(backend);
    // Create stores; all four share the same backend handle.
    let principal_store = Arc::new(PrincipalStore::new(backend.clone()));
    let role_store = Arc::new(RoleStore::new(backend.clone()));
    let binding_store = Arc::new(BindingStore::new(backend.clone()));
    let token_store = Arc::new(TokenStore::new(backend.clone()));
    // Initialize builtin roles
    info!("Initializing builtin roles...");
    role_store.init_builtin_roles().await?;
    // Create policy cache; bindings are refreshed faster (5 min) than role
    // definitions (10 min).
    let cache_config = PolicyCacheConfig {
        binding_ttl: Duration::from_secs(300),
        role_ttl: Duration::from_secs(600),
        max_binding_entries: 10000,
        max_role_entries: 1000,
    };
    let cache = Arc::new(PolicyCache::new(cache_config));
    // Create evaluator
    let evaluator = Arc::new(PolicyEvaluator::new(
        binding_store.clone(),
        role_store.clone(),
        cache,
    ));
    // Create token service. A missing signing key is fatal unless the
    // operator explicitly opts in to a random dev-only key — a random key
    // invalidates all outstanding tokens on every restart.
    let signing_key = if config.authn.internal_token.signing_key.is_empty() {
        let allow_random = std::env::var("IAM_ALLOW_RANDOM_SIGNING_KEY")
            .or_else(|_| std::env::var("PHOTON_IAM_ALLOW_RANDOM_SIGNING_KEY"))
            .ok()
            .map(|value| {
                matches!(
                    value.trim().to_lowercase().as_str(),
                    "1" | "true" | "yes" | "y" | "on"
                )
            })
            .unwrap_or(false);
        if !allow_random {
            return Err("No signing key configured. Set IAM_ALLOW_RANDOM_SIGNING_KEY=true for dev or configure authn.internal_token.signing_key.".into());
        }
        warn!("No signing key configured, generating random key (dev-only)");
        SigningKey::generate("iam-key-1")
    } else {
        SigningKey::new(
            "iam-key-1",
            config.authn.internal_token.signing_key.as_bytes().to_vec(),
        )
    };
    let token_config = InternalTokenConfig::new(signing_key, &config.authn.internal_token.issuer)
        .with_default_ttl(Duration::from_secs(
            config.authn.internal_token.default_ttl_seconds,
        ))
        .with_max_ttl(Duration::from_secs(
            config.authn.internal_token.max_ttl_seconds,
        ));
    let token_service = Arc::new(InternalTokenService::new(token_config));
    let admin_token = load_admin_token();
    // Create gRPC services
    let authz_service = IamAuthzService::new(evaluator.clone(), principal_store.clone());
    let token_grpc_service = IamTokenService::new(
        token_service.clone(),
        principal_store.clone(),
        token_store.clone(),
    );
    let gateway_auth_service = GatewayAuthServiceImpl::new(
        token_service.clone(),
        principal_store.clone(),
        token_store.clone(),
        evaluator.clone(),
    );
    let admin_service = IamAdminService::new(
        principal_store.clone(),
        role_store.clone(),
        binding_store.clone(),
    )
    .with_evaluator(evaluator.clone());
    // Only the admin surface is gated by the shared token; the other services
    // are added below without this interceptor.
    let admin_interceptor = AdminTokenInterceptor {
        token: admin_token.map(Arc::new),
    };
    if admin_interceptor.token.is_some() {
        info!("IAM admin token authentication enabled");
    } else {
        warn!("IAM admin token not configured; admin API is unauthenticated");
    }
    let admin_server = IamAdminServer::with_interceptor(admin_service, admin_interceptor);
    info!("IAM server ready, starting gRPC listeners...");
    // Create health service (for K8s liveness/readiness probes)
    // Uses grpc.health.v1.Health standard protocol
    let (mut health_reporter, health_service) = health_reporter();
    // Mark services as serving
    health_reporter
        .set_serving::<IamAuthzServer<IamAuthzService>>()
        .await;
    health_reporter
        .set_serving::<IamTokenServer<IamTokenService>>()
        .await;
    health_reporter
        .set_serving::<IamAdminServer<IamAdminService>>()
        .await;
    health_reporter
        .set_serving::<GatewayAuthServiceServer<GatewayAuthServiceImpl>>()
        .await;
    // Spawn health monitoring task
    let backend_for_health = backend.clone();
    tokio::spawn(async move {
        // Periodically check backend connectivity
        loop {
            tokio::time::sleep(Duration::from_secs(30)).await;
            // Backend health check could be added here if Backend exposes a ping method
            let _ = backend_for_health; // Keep reference alive
        }
    });
    info!("Health check service enabled (grpc.health.v1.Health)");
    // Configure TLS if enabled
    let mut server = Server::builder();
    if let Some(tls_config) = &config.server.tls {
        info!("TLS enabled, loading certificates...");
        let cert = tokio::fs::read(&tls_config.cert_file)
            .await
            .map_err(|e| format!("Failed to read cert file: {}", e))?;
        let key = tokio::fs::read(&tls_config.key_file)
            .await
            .map_err(|e| format!("Failed to read key file: {}", e))?;
        let server_identity = Identity::from_pem(cert, key);
        // mTLS additionally verifies client certificates against the CA.
        let tls = if tls_config.require_client_cert {
            info!("mTLS enabled, requiring client certificates");
            let ca_cert = tokio::fs::read(
                tls_config
                    .ca_file
                    .as_ref()
                    .ok_or("ca_file required when require_client_cert=true")?,
            )
            .await
            .map_err(|e| format!("Failed to read CA file: {}", e))?;
            let ca = Certificate::from_pem(ca_cert);
            ServerTlsConfig::new()
                .identity(server_identity)
                .client_ca_root(ca)
        } else {
            info!("TLS-only mode, client certificates not required");
            ServerTlsConfig::new().identity(server_identity)
        };
        server = server
            .tls_config(tls)
            .map_err(|e| format!("Failed to configure TLS: {}", e))?;
        info!("TLS configuration applied successfully");
    } else {
        info!("TLS disabled, running in plain-text mode");
    }
    // gRPC server
    let grpc_server = server
        .add_service(health_service)
        .add_service(IamAuthzServer::new(authz_service))
        .add_service(IamTokenServer::new(token_grpc_service))
        .add_service(GatewayAuthServiceServer::new(gateway_auth_service))
        .add_service(admin_server)
        .serve(config.server.addr);
    // HTTP REST API server
    let http_addr = config.server.http_addr;
    let rest_state = rest::RestApiState {
        server_addr: config.server.addr.to_string(),
        tls_enabled: config.server.tls.is_some(),
    };
    let rest_app = rest::build_router(rest_state);
    let http_listener = tokio::net::TcpListener::bind(&http_addr).await?;
    info!(http_addr = %http_addr, "HTTP REST API server starting");
    let http_server = async move {
        axum::serve(http_listener, rest_app)
            .await
            .map_err(|e| format!("HTTP server error: {}", e))
    };
    // Run both servers concurrently; if either one exits (or errors), the
    // select! completes, the other future is dropped, and the process exits.
    tokio::select! {
        result = grpc_server => {
            result?;
        }
        result = http_server => {
            result?;
        }
    }
    Ok(())
}
async fn create_backend(
config: &config::ServerConfig,
) -> Result<Backend, Box<dyn std::error::Error>> {
match config.store.backend {
BackendKind::Memory => {
let allow_memory = std::env::var("IAM_ALLOW_MEMORY_BACKEND")
.or_else(|_| std::env::var("PHOTON_IAM_ALLOW_MEMORY_BACKEND"))
.ok()
.map(|value| {
matches!(
value.trim().to_ascii_lowercase().as_str(),
"1" | "true" | "yes" | "on"
)
})
.unwrap_or(false);
if !allow_memory {
return Err(
"In-memory IAM backend is disabled. Use FlareDB backend, or set IAM_ALLOW_MEMORY_BACKEND=true for tests/dev only."
.into(),
);
}
info!("Using in-memory backend (explicitly enabled)");
Backend::new(BackendConfig::Memory)
.await
.map_err(|e| e.into())
}
BackendKind::FlareDb => {
let endpoint = config
.store
.flaredb_endpoint
.clone()
.ok_or("flaredb_endpoint required for flaredb backend")?;
let namespace = config
.store
.flaredb_namespace
.clone()
.unwrap_or_else(|| "iam".into());
let pd_endpoint = config
.cluster
.chainfire_endpoint
.as_deref()
.map(normalize_transport_addr)
.unwrap_or_else(|| endpoint.clone());
info!(
"Using FlareDB backend at {} via PD {} (namespace: {})",
endpoint, pd_endpoint, namespace
);
Backend::new(BackendConfig::FlareDb {
endpoint,
pd_endpoint,
namespace,
})
.await
.map_err(|e| e.into())
}
BackendKind::Postgres | BackendKind::Sqlite => {
let database_url = config
.store
.database_url
.as_deref()
.ok_or_else(|| {
format!(
"database_url is required when store.backend={}",
backend_kind_name(config.store.backend)
)
})?;
ensure_sql_backend_matches_url(config.store.backend, database_url)?;
info!(
"Using {} backend: {}",
backend_kind_name(config.store.backend),
database_url
);
Backend::new(BackendConfig::Sql {
database_url: database_url.to_string(),
single_node: config.store.single_node,
})
.await
.map_err(|e| e.into())
}
}
}
/// Human-readable name for a backend kind, used in startup logs and
/// configuration error messages.
fn backend_kind_name(kind: BackendKind) -> &'static str {
    match kind {
        BackendKind::Postgres => "postgres",
        BackendKind::Sqlite => "sqlite",
        BackendKind::FlareDb => "flaredb",
        BackendKind::Memory => "memory",
    }
}
/// Checks that `database_url`'s scheme agrees with the configured SQL backend
/// kind. Non-SQL backends (FlareDB, memory) are accepted unconditionally.
fn ensure_sql_backend_matches_url(
    kind: BackendKind,
    database_url: &str,
) -> Result<(), Box<dyn std::error::Error>> {
    // Scheme comparison is case-insensitive and ignores surrounding whitespace.
    let normalized = database_url.trim().to_ascii_lowercase();
    let (accepted, message) = match kind {
        BackendKind::Postgres => (
            normalized.starts_with("postgres://") || normalized.starts_with("postgresql://"),
            "store.backend=postgres requires postgres:// or postgresql:// URL",
        ),
        BackendKind::Sqlite => (
            normalized.starts_with("sqlite:"),
            "store.backend=sqlite requires sqlite: URL",
        ),
        // URL scheme is irrelevant for non-SQL backends.
        BackendKind::FlareDb | BackendKind::Memory => (true, ""),
    };
    if accepted {
        Ok(())
    } else {
        Err(message.into())
    }
}
/// Registers this node in ChainFire's cluster membership registry.
///
/// Writes a JSON record `{"addr":...,"ts":...}` under
/// `/cluster/{service}/members/{node_id}`, retrying every 2 seconds for up to
/// ~120 seconds; on exhaustion, returns an error carrying the last failure.
async fn register_chainfire_membership(
    endpoint: &str,
    service: &str,
    addr: String,
) -> Result<(), Box<dyn std::error::Error>> {
    // Node identity: pod hostname when available (e.g. K8s), else service+PID.
    let node_id =
        std::env::var("HOSTNAME").unwrap_or_else(|_| format!("{}-{}", service, std::process::id()));
    let ts = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    let key = format!("/cluster/{}/members/{}", service, node_id);
    let value = format!(r#"{{"addr":"{}","ts":{}}}"#, addr, ts);
    // Overall retry budget; each iteration reconnects from scratch.
    let deadline = tokio::time::Instant::now() + Duration::from_secs(120);
    let mut attempt = 0usize;
    let mut last_error = String::new();
    loop {
        attempt += 1;
        match ChainFireClient::connect(endpoint).await {
            Ok(mut client) => match client.put_str(&key, &value).await {
                Ok(_) => return Ok(()),
                Err(error) => last_error = format!("put failed: {}", error),
            },
            Err(error) => last_error = format!("connect failed: {}", error),
        }
        // Deadline is checked after the attempt, so at least one attempt is
        // always made even with an already-expired deadline.
        if tokio::time::Instant::now() >= deadline {
            break;
        }
        warn!(
            attempt,
            endpoint,
            service,
            error = %last_error,
            "retrying ChainFire membership registration"
        );
        tokio::time::sleep(Duration::from_secs(2)).await;
    }
    Err(std::io::Error::other(format!(
        "failed to register ChainFire membership for {} via {} after {} attempts: {}",
        service, endpoint, attempt, last_error
    ))
    .into())
}
/// Ensures a ChainFire endpoint carries an explicit scheme, defaulting to
/// plain HTTP when none is present.
fn normalize_chainfire_endpoint(endpoint: &str) -> String {
    match endpoint {
        e if e.starts_with("http://") || e.starts_with("https://") => e.to_string(),
        e => format!("http://{}", e),
    }
}
/// Strips scheme prefixes and trailing slashes from an endpoint, yielding a
/// bare `host:port`-style transport address.
///
/// Like the original `trim_start_matches`-based form, repeated scheme
/// prefixes and repeated trailing slashes are all removed.
fn normalize_transport_addr(endpoint: &str) -> String {
    let mut rest = endpoint.trim();
    while let Some(stripped) = rest.strip_prefix("http://") {
        rest = stripped;
    }
    while let Some(stripped) = rest.strip_prefix("https://") {
        rest = stripped;
    }
    rest.trim_end_matches('/').to_string()
}
/// Installs the global tracing subscriber.
///
/// An `EnvFilter` built from the environment (RUST_LOG-style) takes
/// precedence; the configured `level` is used only when no environment
/// filter is set.
fn init_logging(level: &str) {
    use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
    let filter = match EnvFilter::try_from_default_env() {
        Ok(env_filter) => env_filter,
        Err(_) => EnvFilter::new(level),
    };
    tracing_subscriber::registry()
        .with(filter)
        .with(tracing_subscriber::fmt::layer())
        .init();
}