use std::collections::HashMap; use std::io; use std::net::SocketAddr; use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; use apigateway_api::proto::{ AuthorizeRequest, CreditCommitRequest, CreditReserveRequest, CreditRollbackRequest, }; use apigateway_api::{GatewayAuthServiceClient, GatewayCreditServiceClient}; use axum::{ body::{to_bytes, Body}, extract::{ConnectInfo, State}, http::{HeaderMap, Request, StatusCode, Uri}, response::Response, routing::{any, get}, Json, Router, }; use clap::Parser; use reqwest::{Client, Url}; use serde::{Deserialize, Serialize}; use tonic::transport::{Certificate, Channel, ClientTlsConfig, Endpoint, Identity}; use tonic::Request as TonicRequest; use tracing::{info, warn}; use tracing_subscriber::EnvFilter; use uuid::Uuid; const DEFAULT_REQUEST_ID_HEADER: &str = "x-request-id"; const PHOTON_AUTH_TOKEN_HEADER: &str = "x-photon-auth-token"; const DEFAULT_AUTH_TIMEOUT_MS: u64 = 500; const DEFAULT_CREDIT_TIMEOUT_MS: u64 = 500; const DEFAULT_UPSTREAM_TIMEOUT_MS: u64 = 10_000; const RESERVED_AUTH_HEADERS: [&str; 10] = [ "authorization", "x-photon-auth-token", "x-subject-id", "x-org-id", "x-project-id", "x-roles", "x-scopes", "x-iam-session-id", "x-iam-principal-kind", "x-iam-auth-method", ]; const AUTH_PROVIDER_BLOCK_HEADERS: [&str; 17] = [ "authorization", "x-photon-auth-token", "x-subject-id", "x-org-id", "x-project-id", "x-roles", "x-scopes", "proxy-authorization", "cookie", "set-cookie", "host", "connection", "upgrade", "keep-alive", "te", "trailer", "transfer-encoding", ]; #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] enum PolicyMode { Disabled, Optional, Required, } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] enum CommitPolicy { Success, Always, Never, } fn default_policy_mode() -> PolicyMode { PolicyMode::Required } fn default_commit_policy() -> CommitPolicy { CommitPolicy::Success } fn 
default_credit_units() -> u64 { 1 } fn default_upstream_timeout_ms() -> u64 { DEFAULT_UPSTREAM_TIMEOUT_MS } #[derive(Debug, Clone, Serialize, Deserialize)] struct TlsConfig { #[serde(default)] ca_file: Option, #[serde(default)] cert_file: Option, #[serde(default)] key_file: Option, #[serde(default)] domain_name: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] struct AuthProviderConfig { name: String, #[serde(rename = "type")] provider_type: String, endpoint: String, #[serde(default)] timeout_ms: Option, #[serde(default)] tls: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] struct CreditProviderConfig { name: String, #[serde(rename = "type")] provider_type: String, endpoint: String, #[serde(default)] timeout_ms: Option, #[serde(default)] tls: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] struct RouteAuthConfig { provider: String, #[serde(default = "default_policy_mode")] mode: PolicyMode, #[serde(default)] fail_open: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] struct RouteCreditConfig { provider: String, #[serde(default = "default_policy_mode")] mode: PolicyMode, #[serde(default = "default_credit_units")] units: u64, #[serde(default)] fail_open: bool, #[serde(default = "default_commit_policy")] commit_on: CommitPolicy, #[serde(default)] allow_header_subject: bool, #[serde(default)] attributes: HashMap, } #[derive(Debug, Clone, Serialize, Deserialize)] struct RouteConfig { name: String, path_prefix: String, upstream: String, #[serde(default)] strip_prefix: bool, #[serde(default)] timeout_ms: Option, #[serde(default)] auth: Option, #[serde(default)] credit: Option, } #[derive(Clone)] struct Route { config: RouteConfig, upstream: Url, upstream_base_path: String, } #[derive(Debug, Clone, Serialize, Deserialize)] struct ServerConfig { #[serde(default = "default_http_addr")] http_addr: SocketAddr, #[serde(default = "default_log_level")] log_level: String, #[serde(default = "default_max_body_bytes")] max_body_bytes: usize, 
#[serde(default = "default_max_response_bytes")] max_response_bytes: usize, #[serde(default = "default_upstream_timeout_ms")] upstream_timeout_ms: u64, #[serde(default)] trust_forwarded_headers: bool, #[serde(default)] auth_providers: Vec, #[serde(default)] credit_providers: Vec, #[serde(default)] routes: Vec, } impl Default for ServerConfig { fn default() -> Self { Self { http_addr: default_http_addr(), log_level: default_log_level(), max_body_bytes: default_max_body_bytes(), max_response_bytes: default_max_response_bytes(), upstream_timeout_ms: default_upstream_timeout_ms(), trust_forwarded_headers: false, auth_providers: Vec::new(), credit_providers: Vec::new(), routes: Vec::new(), } } } #[derive(Debug, Parser)] #[command(author, version, about)] struct Args { /// Configuration file path #[arg(short, long, default_value = "apigateway.toml")] config: PathBuf, /// HTTP listen address (overrides config) #[arg(long)] http_addr: Option, /// Log level (overrides config) #[arg(short, long)] log_level: Option, } #[derive(Clone)] struct ServerState { routes: Vec, client: Client, upstream_timeout: Duration, max_body_bytes: usize, max_response_bytes: usize, auth_providers: HashMap, credit_providers: HashMap, trust_forwarded_headers: bool, } #[derive(Clone)] struct GrpcAuthProvider { channel: Channel, timeout: Duration, } #[derive(Clone)] enum AuthProvider { Grpc(GrpcAuthProvider), } #[derive(Clone)] struct GrpcCreditProvider { channel: Channel, timeout: Duration, } #[derive(Clone)] enum CreditProvider { Grpc(GrpcCreditProvider), } #[derive(Clone, Debug)] struct SubjectInfo { subject_id: String, org_id: String, project_id: String, roles: Vec, scopes: Vec, } #[derive(Clone, Debug)] struct CreditSubject { subject_id: String, org_id: String, project_id: String, } #[derive(Clone, Debug)] struct AuthDecision { allow: bool, subject: Option, headers: HashMap, reason: Option, } #[derive(Clone, Debug)] struct CreditDecision { allow: bool, reservation_id: String, reason: Option, } 
/// Result of auth enforcement: the authenticated subject (if any) plus the
/// extra headers the auth provider asked to forward upstream.
#[derive(Clone, Debug, Default)]
struct AuthOutcome {
    subject: Option<SubjectInfo>,
    headers: HashMap<String, String>,
}

/// A credit reservation that is later committed or rolled back.
#[derive(Clone, Debug)]
struct CreditReservation {
    /// Name of the credit provider that issued the reservation.
    provider: String,
    reservation_id: String,
}

/// Per-request metadata sent to auth and credit providers.
#[derive(Clone, Debug)]
struct RequestContext {
    request_id: String,
    method: String,
    path: String,
    raw_query: String,
    headers: HashMap<String, String>,
    client_ip: String,
    route_name: String,
}

/// Hop-by-hop headers (RFC 9110 §7.6.1) that describe a single connection and
/// must not be relayed by a proxy in either direction.
const HOP_BY_HOP_HEADERS: [&str; 9] = [
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];

fn default_http_addr() -> SocketAddr {
    "127.0.0.1:8080"
        .parse()
        .expect("invalid default HTTP address")
}

fn default_log_level() -> String {
    "info".to_string()
}

/// 16 MiB.
fn default_max_body_bytes() -> usize {
    16 * 1024 * 1024
}

fn default_max_response_bytes() -> usize {
    default_max_body_bytes()
}

/// Loads configuration, builds routes and provider connections, and serves
/// the gateway until the listener fails.
///
/// # Errors
/// Returns an error for an unreadable or invalid config file, an invalid
/// `--http-addr`, invalid routes, provider connection/TLS failures, routes
/// referencing unconfigured providers, or a failure to bind the listener.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let args = Args::parse();

    // A missing config file is not fatal: fall back to built-in defaults and
    // report it once logging is initialised.
    let used_default_config = !args.config.exists();
    let mut config: ServerConfig = if used_default_config {
        ServerConfig::default()
    } else {
        let contents = tokio::fs::read_to_string(&args.config).await?;
        toml::from_str(&contents)?
    };
    if let Some(http_addr) = args.http_addr {
        config.http_addr = http_addr.parse()?;
    }
    if let Some(log_level) = args.log_level {
        config.log_level = log_level;
    }

    tracing_subscriber::fmt()
        .with_env_filter(
            EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&config.log_level)),
        )
        .init();

    if used_default_config {
        info!("Config file not found: {}, using defaults", args.config.display());
    }

    let routes = build_routes(config.routes)?;
    let auth_providers = build_auth_providers(config.auth_providers).await?;
    let credit_providers = build_credit_providers(config.credit_providers).await?;
    // Catch dangling provider references at startup instead of failing (or
    // silently failing open) on every request to the affected route.
    validate_route_providers(&routes, &auth_providers, &credit_providers)?;
    let upstream_timeout = Duration::from_millis(config.upstream_timeout_ms);
    let client = Client::builder().build()?;

    info!("Starting API gateway");
    info!("  HTTP: {}", config.http_addr);
    info!("  Max body bytes: {}", config.max_body_bytes);
    info!("  Max response bytes: {}", config.max_response_bytes);

    if !routes.is_empty() {
        info!("Configured {} routes", routes.len());
    } else {
        warn!("No routes configured; proxy will return 404s");
    }
    if auth_providers.is_empty() {
        warn!("No auth providers configured");
    }
    if credit_providers.is_empty() {
        warn!("No credit providers configured");
    }

    let state = Arc::new(ServerState {
        routes,
        client,
        upstream_timeout,
        max_body_bytes: config.max_body_bytes,
        max_response_bytes: config.max_response_bytes,
        auth_providers,
        credit_providers,
        trust_forwarded_headers: config.trust_forwarded_headers,
    });

    let app = Router::new()
        .route("/", get(index))
        .route("/health", get(health))
        .route("/routes", get(list_routes))
        .route("/*path", any(proxy))
        .with_state(state);

    let listener = tokio::net::TcpListener::bind(config.http_addr).await?;
    axum::serve(
        listener,
        app.into_make_service_with_connect_info::<SocketAddr>(),
    )
    .await?;

    Ok(())
}

/// Ensures every enabled route policy names a provider that was configured.
///
/// # Errors
/// Returns an `InvalidInput` error naming the route and the missing provider.
fn validate_route_providers(
    routes: &[Route],
    auth_providers: &HashMap<String, AuthProvider>,
    credit_providers: &HashMap<String, CreditProvider>,
) -> Result<(), io::Error> {
    for route in routes {
        if let Some(auth) = &route.config.auth {
            if auth.mode != PolicyMode::Disabled && !auth_providers.contains_key(&auth.provider) {
                return Err(config_error(format!(
                    "route {} references unknown auth provider {}",
                    route.config.name, auth.provider
                )));
            }
        }
        if let Some(credit) = &route.config.credit {
            if credit.mode != PolicyMode::Disabled
                && !credit_providers.contains_key(&credit.provider)
            {
                return Err(config_error(format!(
                    "route {} references unknown credit provider {}",
                    route.config.name, credit.provider
                )));
            }
        }
    }
    Ok(())
}

async fn index() -> &'static str {
    "apigateway-server running"
}

async fn health() -> Json<serde_json::Value> {
    Json(serde_json::json!({"status": "ok"}))
}

/// Lists the configured routes (with normalised path prefixes), longest prefix first.
async fn list_routes(State(state): State<Arc<ServerState>>) -> Json<Vec<RouteConfig>> {
    Json(state.routes.iter().map(|route| route.config.clone()).collect())
}

/// Catch-all handler: matches a route, enforces auth and credit, forwards the
/// buffered request upstream and relays the buffered response.
///
/// Status codes produced by the gateway itself: 404 (no route), 400 (path
/// with dot segments), 401/403/402/502 from auth and credit enforcement, 413
/// (request or response too large) and 502 (upstream failure). Once credit
/// is reserved, every failure path releases it via `finalize_credit_abort`.
async fn proxy(
    State(state): State<Arc<ServerState>>,
    ConnectInfo(remote_addr): ConnectInfo<SocketAddr>,
    request: Request<Body>,
) -> Result<Response<Body>, StatusCode> {
    let (parts, body) = request.into_parts();

    let route = match_route(&state.routes, parts.uri.path())
        .ok_or(StatusCode::NOT_FOUND)?
        .clone();
    // Resolve the target before calling any provider so malformed paths are
    // rejected without touching auth or credit.
    let target_url = build_upstream_url(&route, &parts.uri)?;

    let request_id = extract_request_id(&parts.headers);
    let context = RequestContext {
        request_id: request_id.clone(),
        method: parts.method.to_string(),
        path: parts.uri.path().to_string(),
        raw_query: parts.uri.query().unwrap_or("").to_string(),
        headers: headers_to_map(&parts.headers),
        client_ip: extract_client_ip(&parts.headers, remote_addr, state.trust_forwarded_headers),
        route_name: route.config.name.clone(),
    };

    let auth_token = extract_auth_token(&parts.headers);
    let auth_outcome = enforce_auth(&state, &route, &context, auth_token).await?;

    // Buffer the body before reserving credit: an oversized or aborted upload
    // then fails without leaving a reservation behind, and no reservation is
    // held while a slow client is still uploading.
    let body_bytes = to_bytes(body, state.max_body_bytes)
        .await
        .map_err(|_| StatusCode::PAYLOAD_TOO_LARGE)?;

    let credit_reservation =
        enforce_credit(&state, &route, &context, auth_outcome.subject.as_ref()).await?;

    let upstream_request = build_upstream_request(
        &state,
        &route,
        &parts.method,
        &parts.headers,
        target_url,
        &request_id,
        &auth_outcome,
    )
    .body(body_bytes);

    let response = match upstream_request.send().await {
        Ok(response) => response,
        Err(_) => {
            finalize_credit_abort(&state, &route, credit_reservation).await;
            return Err(StatusCode::BAD_GATEWAY);
        }
    };

    // `max_response_bytes == 0` disables the limit. The length is compared
    // without a narrowing cast so huge declared sizes cannot wrap on 32-bit.
    let exceeds_response_limit = |len: u64| {
        state.max_response_bytes > 0
            && usize::try_from(len).map_or(true, |len| len > state.max_response_bytes)
    };

    let status = response.status();
    if response.content_length().map_or(false, exceeds_response_limit) {
        finalize_credit_abort(&state, &route, credit_reservation).await;
        return Err(StatusCode::PAYLOAD_TOO_LARGE);
    }

    let response_headers = filter_upstream_response_headers(response.headers());
    let body = match response.bytes().await {
        Ok(body) => body,
        Err(_) => {
            finalize_credit_abort(&state, &route, credit_reservation).await;
            return Err(StatusCode::BAD_GATEWAY);
        }
    };
    if exceeds_response_limit(body.len() as u64) {
        finalize_credit_abort(&state, &route, credit_reservation).await;
        return Err(StatusCode::PAYLOAD_TOO_LARGE);
    }

    finalize_credit(&state, &route, credit_reservation, status).await;

    let mut downstream = Response::new(Body::from(body));
    *downstream.status_mut() = status;
    *downstream.headers_mut() = response_headers;
    Ok(downstream)
}

/// Prepares the upstream request, without a body.
///
/// Client headers are copied except `host`, hop-by-hop headers, the
/// request-id header (set once below with the resolved id) and reserved
/// identity headers. On routes without an auth policy, the client's
/// `authorization` / photon token headers are passed through unchanged.
/// Provider-supplied auth headers are appended last.
fn build_upstream_request(
    state: &ServerState,
    route: &Route,
    method: &axum::http::Method,
    client_headers: &HeaderMap,
    target_url: Url,
    request_id: &str,
    auth_outcome: &AuthOutcome,
) -> reqwest::RequestBuilder {
    let timeout = route
        .config
        .timeout_ms
        .map(Duration::from_millis)
        .unwrap_or(state.upstream_timeout);
    let forward_client_auth_headers = route.config.auth.is_none();
    let connection_tokens = connection_header_tokens(client_headers);

    let mut builder = state
        .client
        .request(method.clone(), target_url)
        .timeout(timeout);
    for (name, value) in client_headers {
        let header = name.as_str();
        if name == axum::http::header::HOST
            || header == DEFAULT_REQUEST_ID_HEADER
            || is_hop_by_hop_header(header, &connection_tokens)
        {
            continue;
        }
        if is_reserved_auth_header(name)
            && !(forward_client_auth_headers && should_preserve_client_auth_header(header))
        {
            continue;
        }
        builder = builder.header(name, value);
    }
    builder = builder.header(DEFAULT_REQUEST_ID_HEADER, request_id);
    apply_auth_headers(builder, auth_outcome)
}

/// Copies upstream response headers for the downstream response, dropping
/// hop-by-hop headers. Uses `append` so repeated headers (e.g. several
/// `set-cookie` values) are all preserved.
fn filter_upstream_response_headers(upstream: &HeaderMap) -> HeaderMap {
    let connection_tokens = connection_header_tokens(upstream);
    let mut headers = HeaderMap::with_capacity(upstream.len());
    for (name, value) in upstream {
        if is_hop_by_hop_header(name.as_str(), &connection_tokens) {
            continue;
        }
        headers.append(name.clone(), value.clone());
    }
    headers
}

/// Lower-cased header names nominated as hop-by-hop by the `Connection` header(s).
fn connection_header_tokens(headers: &HeaderMap) -> Vec<String> {
    headers
        .get_all(axum::http::header::CONNECTION)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .map(|token| token.trim().to_ascii_lowercase())
        .filter(|token| !token.is_empty())
        .collect()
}

/// True for standard hop-by-hop headers and those listed in `connection_tokens`.
/// `name` is expected to be lower-case, as `HeaderName::as_str` returns.
fn is_hop_by_hop_header(name: &str, connection_tokens: &[String]) -> bool {
    HOP_BY_HOP_HEADERS.iter().any(|header| *header == name)
        || connection_tokens.iter().any(|token| token == name)
}

/// Runs the route's auth policy, if any, and returns the resulting outcome.
///
/// # Errors
/// The status chosen by [`apply_auth_mode`] (403 on denial, 502 on provider error).
async fn enforce_auth(
    state: &ServerState,
    route: &Route,
    context: &RequestContext,
    token: Option<String>,
) -> Result<AuthOutcome, StatusCode> {
    let Some(auth_cfg) = &route.config.auth else {
        return Ok(AuthOutcome::default());
    };
    if auth_cfg.mode == PolicyMode::Disabled {
        return Ok(AuthOutcome::default());
    }
    let decision = authorize_request(state, auth_cfg, context, token).await;
    apply_auth_mode(auth_cfg.mode, auth_cfg.fail_open, decision)
}

/// Maps a provider decision onto the route's policy mode.
///
/// `Optional` never fails (denials/errors are logged and yield an empty
/// outcome). `Required` returns 403 on denial and 502 on provider error,
/// unless `fail_open` is set, in which case provider errors yield an empty
/// outcome.
fn apply_auth_mode(
    mode: PolicyMode,
    fail_open: bool,
    decision: Result<AuthDecision, StatusCode>,
) -> Result<AuthOutcome, StatusCode> {
    match mode {
        PolicyMode::Disabled => Ok(AuthOutcome::default()),
        PolicyMode::Optional => match decision {
            Ok(decision) if decision.allow => Ok(AuthOutcome {
                subject: decision.subject,
                headers: decision.headers,
            }),
            Ok(decision) => {
                if let Some(reason) = decision.reason {
                    warn!("Auth denied (optional mode): {}", reason);
                }
                Ok(AuthOutcome::default())
            }
            Err(err) => {
                warn!("Auth provider error (optional mode): {}", err);
                Ok(AuthOutcome::default())
            }
        },
        PolicyMode::Required => match decision {
            Ok(decision) if decision.allow => Ok(AuthOutcome {
                subject: decision.subject,
                headers: decision.headers,
            }),
            Ok(decision) => {
                if let Some(reason) = decision.reason {
                    warn!("Auth denied (required mode): {}", reason);
                }
                Err(StatusCode::FORBIDDEN)
            }
            Err(err) => {
                warn!("Auth provider error (required mode): {}", err);
                if fail_open {
                    Ok(AuthOutcome::default())
                } else {
                    Err(StatusCode::BAD_GATEWAY)
                }
            }
        },
    }
}

/// Runs the route's credit policy, if any, returning the reservation to settle later.
///
/// # Errors
/// 401 when `Required` and no org/project scope can be resolved; otherwise
/// the status chosen by [`apply_credit_mode`].
async fn enforce_credit(
    state: &ServerState,
    route: &Route,
    context: &RequestContext,
    subject: Option<&SubjectInfo>,
) -> Result<Option<CreditReservation>, StatusCode> {
    let Some(credit_cfg) = &route.config.credit else {
        return Ok(None);
    };
    if credit_cfg.mode == PolicyMode::Disabled {
        return Ok(None);
    }

    let Some(credit_subject) =
        resolve_credit_subject(context, subject, credit_cfg.allow_header_subject)
    else {
        if credit_cfg.mode == PolicyMode::Required {
            return Err(StatusCode::UNAUTHORIZED);
        }
        warn!("Credit skipped: missing org/project scope");
        return Ok(None);
    };

    let decision = reserve_credit(state, credit_cfg, context, &credit_subject).await;
    apply_credit_mode(credit_cfg.mode, credit_cfg.fail_open, decision).map(|decision| {
        decision.map(|decision| CreditReservation {
            provider: credit_cfg.provider.clone(),
            reservation_id: decision.reservation_id,
        })
    })
}

/// Maps a credit provider decision onto the route's policy mode.
///
/// `Optional` never fails (denials/errors are logged and yield `None`).
/// `Required` returns 402 on denial and 502 on provider error, unless
/// `fail_open` is set, in which case provider errors yield `None`.
fn apply_credit_mode(
    mode: PolicyMode,
    fail_open: bool,
    decision: Result<CreditDecision, StatusCode>,
) -> Result<Option<CreditDecision>, StatusCode> {
    match mode {
        PolicyMode::Disabled => Ok(None),
        PolicyMode::Optional => match decision {
            Ok(decision) if decision.allow => Ok(Some(decision)),
            Ok(decision) => {
                if let Some(reason) = decision.reason {
                    warn!("Credit denied (optional mode): {}", reason);
                }
                Ok(None)
            }
            Err(err) => {
                warn!("Credit provider error (optional mode): {}", err);
                Ok(None)
            }
        },
        PolicyMode::Required => match decision {
            Ok(decision) if decision.allow => Ok(Some(decision)),
            Ok(decision) => {
                if let Some(reason) = decision.reason {
                    warn!("Credit denied (required mode): {}", reason);
                }
                Err(StatusCode::PAYMENT_REQUIRED)
            }
            Err(err) => {
                warn!("Credit provider error (required mode): {}", err);
                if fail_open {
                    Ok(None)
                } else {
                    Err(StatusCode::BAD_GATEWAY)
                }
            }
        },
    }
}

/// Calls the route's auth provider with the request context and token.
///
/// # Errors
/// 500 when the named provider is not configured; 502 when the gRPC call fails.
async fn authorize_request(
    state: &ServerState,
    auth_cfg: &RouteAuthConfig,
    context: &RequestContext,
    token: Option<String>,
) -> Result<AuthDecision, StatusCode> {
    let provider = state
        .auth_providers
        .get(&auth_cfg.provider)
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    match provider {
        AuthProvider::Grpc(provider) => {
            let mut client = GatewayAuthServiceClient::new(provider.channel.clone());
            let mut request = TonicRequest::new(AuthorizeRequest {
                request_id: context.request_id.clone(),
                token: token.unwrap_or_default(),
                method: context.method.clone(),
                path: context.path.clone(),
                raw_query: context.raw_query.clone(),
                headers: context.headers.clone(),
                client_ip: context.client_ip.clone(),
                route_name: context.route_name.clone(),
            });
            request.set_timeout(provider.timeout);
            let response = client
                .authorize(request)
                .await
                .map_err(|_| StatusCode::BAD_GATEWAY)?
                .into_inner();

            let subject = response.subject.map(|subject| SubjectInfo {
                subject_id: subject.subject_id,
                org_id: subject.org_id,
                project_id: subject.project_id,
                roles: subject.roles,
                scopes: subject.scopes,
            });

            Ok(AuthDecision {
                allow: response.allow,
                subject,
                headers: response.headers,
                reason: if response.reason.is_empty() {
                    None
                } else {
                    Some(response.reason)
                },
            })
        }
    }
}

/// Picks the identity to charge: the authenticated subject if present,
/// otherwise (when allowed) the client's `x-org-id` / `x-project-id` /
/// `x-subject-id` headers. Returns `None` if org or project is missing/blank.
fn resolve_credit_subject(
    context: &RequestContext,
    subject: Option<&SubjectInfo>,
    allow_header_subject: bool,
) -> Option<CreditSubject> {
    if let Some(subject) = subject {
        return Some(CreditSubject {
            subject_id: subject.subject_id.clone(),
            org_id: subject.org_id.clone(),
            project_id: subject.project_id.clone(),
        });
    }

    if !allow_header_subject {
        return None;
    }

    let org_id = context.headers.get("x-org-id")?.trim();
    let project_id = context.headers.get("x-project-id")?.trim();
    if org_id.is_empty() || project_id.is_empty() {
        return None;
    }
    let subject_id = context
        .headers
        .get("x-subject-id")
        .map(|value| value.trim().to_string())
        .unwrap_or_default();

    Some(CreditSubject {
        subject_id,
        org_id: org_id.to_string(),
        project_id: project_id.to_string(),
    })
}

/// Asks the route's credit provider to reserve `units` for the subject.
///
/// # Errors
/// 500 when the named provider is not configured; 502 when the gRPC call fails.
async fn reserve_credit(
    state: &ServerState,
    credit_cfg: &RouteCreditConfig,
    context: &RequestContext,
    credit_subject: &CreditSubject,
) -> Result<CreditDecision, StatusCode> {
    let provider = state
        .credit_providers
        .get(&credit_cfg.provider)
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    match provider {
        CreditProvider::Grpc(provider) => {
            let mut client = GatewayCreditServiceClient::new(provider.channel.clone());
            let mut request = TonicRequest::new(CreditReserveRequest {
                request_id: context.request_id.clone(),
                subject_id: credit_subject.subject_id.clone(),
                org_id: credit_subject.org_id.clone(),
                project_id: credit_subject.project_id.clone(),
                route_name: context.route_name.clone(),
                method: context.method.clone(),
                path: context.path.clone(),
                raw_query: context.raw_query.clone(),
                units: credit_cfg.units,
                attributes: credit_cfg.attributes.clone(),
            });
            request.set_timeout(provider.timeout);
            let response = client
                .reserve(request)
                .await
                .map_err(|_| StatusCode::BAD_GATEWAY)?
                .into_inner();

            Ok(CreditDecision {
                allow: response.allow,
                reservation_id: response.reservation_id,
                reason: if response.reason.is_empty() {
                    None
                } else {
                    Some(response.reason)
                },
            })
        }
    }
}

/// Settles a reservation after an upstream response according to `commit_on`.
/// Failures are logged, not propagated: the response is already decided.
async fn finalize_credit(
    state: &ServerState,
    route: &Route,
    reservation: Option<CreditReservation>,
    status: reqwest::StatusCode,
) {
    let Some(credit_cfg) = &route.config.credit else {
        return;
    };
    let Some(reservation) = reservation else {
        return;
    };

    match credit_cfg.commit_on {
        CommitPolicy::Never => {}
        CommitPolicy::Always => {
            if let Err(err) = commit_credit(state, credit_cfg, &reservation).await {
                warn!("Failed to commit credit reservation {}: {}", reservation.reservation_id, err);
            }
        }
        CommitPolicy::Success => {
            if status.is_success() || status.is_redirection() {
                if let Err(err) = commit_credit(state, credit_cfg, &reservation).await {
                    warn!("Failed to commit credit reservation {}: {}", reservation.reservation_id, err);
                }
            } else if let Err(err) = rollback_credit(state, credit_cfg, &reservation).await {
                warn!(
                    "Failed to rollback credit reservation {}: {}",
                    reservation.reservation_id, err
                );
            }
        }
    }
}

/// Rolls back a reservation when no upstream response could be delivered
/// (unless `commit_on = never`, which leaves reservations untouched).
async fn finalize_credit_abort(
    state: &ServerState,
    route: &Route,
    reservation: Option<CreditReservation>,
) {
    let Some(credit_cfg) = &route.config.credit else {
        return;
    };
    let Some(reservation) = reservation else {
        return;
    };
    if credit_cfg.commit_on == CommitPolicy::Never {
        return;
    }
    if let Err(err) = rollback_credit(state, credit_cfg, &reservation).await {
        warn!(
            "Failed to rollback credit reservation {} after delivery failure: {}",
            reservation.reservation_id, err
        );
    }
}

/// Commits `credit_cfg.units` against a reservation.
///
/// # Errors
/// 500 for an unknown provider; 502 for a failed call or `success == false`.
async fn commit_credit(
    state: &ServerState,
    credit_cfg: &RouteCreditConfig,
    reservation: &CreditReservation,
) -> Result<(), StatusCode> {
    let provider = state
        .credit_providers
        .get(&reservation.provider)
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    match provider {
        CreditProvider::Grpc(provider) => {
            let mut client = GatewayCreditServiceClient::new(provider.channel.clone());
            let mut request = TonicRequest::new(CreditCommitRequest {
                reservation_id: reservation.reservation_id.clone(),
                units: credit_cfg.units,
            });
            request.set_timeout(provider.timeout);
            let response = client
                .commit(request)
                .await
                .map_err(|_| StatusCode::BAD_GATEWAY)?
                .into_inner();
            if response.success {
                Ok(())
            } else {
                Err(StatusCode::BAD_GATEWAY)
            }
        }
    }
}

/// Releases a reservation.
///
/// # Errors
/// 500 for an unknown provider; 502 for a failed call or `success == false`.
async fn rollback_credit(
    state: &ServerState,
    _credit_cfg: &RouteCreditConfig,
    reservation: &CreditReservation,
) -> Result<(), StatusCode> {
    let provider = state
        .credit_providers
        .get(&reservation.provider)
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    match provider {
        CreditProvider::Grpc(provider) => {
            let mut client = GatewayCreditServiceClient::new(provider.channel.clone());
            let mut request = TonicRequest::new(CreditRollbackRequest {
                reservation_id: reservation.reservation_id.clone(),
            });
            request.set_timeout(provider.timeout);
            let response = client
                .rollback(request)
                .await
                .map_err(|_| StatusCode::BAD_GATEWAY)?
                .into_inner();
            if response.success {
                Ok(())
            } else {
                Err(StatusCode::BAD_GATEWAY)
            }
        }
    }
}

/// Adds provider-supplied headers (only `x-*` names not on the block list)
/// and the subject identity headers to the upstream request.
fn apply_auth_headers(
    mut builder: reqwest::RequestBuilder,
    outcome: &AuthOutcome,
) -> reqwest::RequestBuilder {
    for (key, value) in &outcome.headers {
        if !should_forward_auth_header(key) {
            continue;
        }
        builder = builder.header(key, value);
    }

    if let Some(subject) = &outcome.subject {
        builder = builder
            .header("x-subject-id", &subject.subject_id)
            .header("x-org-id", &subject.org_id)
            .header("x-project-id", &subject.project_id);
        if !subject.roles.is_empty() {
            builder = builder.header("x-roles", subject.roles.join(","));
        }
        if !subject.scopes.is_empty() {
            builder = builder.header("x-scopes", subject.scopes.join(","));
        }
    }

    builder
}

/// Builds a tonic client TLS config from the optional file-based settings.
///
/// # Errors
/// I/O errors reading certificate files, or `InvalidInput` when only one of
/// `cert_file` / `key_file` is set.
async fn build_client_tls_config(
    tls: &Option<TlsConfig>,
) -> Result<Option<ClientTlsConfig>, Box<dyn std::error::Error>> {
    let Some(tls) = tls else {
        return Ok(None);
    };

    let mut tls_config = ClientTlsConfig::new();

    if let Some(ca_file) = &tls.ca_file {
        let ca = tokio::fs::read(ca_file).await?;
        tls_config = tls_config.ca_certificate(Certificate::from_pem(ca));
    }

    match (&tls.cert_file, &tls.key_file) {
        (Some(cert_file), Some(key_file)) => {
            let cert = tokio::fs::read(cert_file).await?;
            let key = tokio::fs::read(key_file).await?;
            tls_config = tls_config.identity(Identity::from_pem(cert, key));
        }
        (None, None) => {}
        _ => {
            return Err(config_error("tls requires both cert_file and key_file").into());
        }
    }

    if let Some(domain) = &tls.domain_name {
        tls_config = tls_config.domain_name(domain);
    }

    Ok(Some(tls_config))
}

/// Connects every configured auth provider, keyed by name.
///
/// # Errors
/// Duplicate names, unsupported types, invalid endpoints/TLS, or connect failures.
async fn build_auth_providers(
    configs: Vec<AuthProviderConfig>,
) -> Result<HashMap<String, AuthProvider>, Box<dyn std::error::Error>> {
    let mut providers = HashMap::new();

    for config in configs {
        let provider_type = normalize_name(&config.provider_type);
        if providers.contains_key(&config.name) {
            return Err(config_error(format!(
                "duplicate auth provider name {}",
                config.name
            ))
            .into());
        }

        match provider_type.as_str() {
            "grpc" => {
                // One value drives the connect timeout, the channel timeout
                // and the per-call deadline.
                let timeout =
                    Duration::from_millis(config.timeout_ms.unwrap_or(DEFAULT_AUTH_TIMEOUT_MS));
                let mut endpoint = Endpoint::from_shared(config.endpoint.clone())?
                    .connect_timeout(timeout)
                    .timeout(timeout);
                if let Some(tls) = build_client_tls_config(&config.tls).await? {
                    endpoint = endpoint.tls_config(tls)?;
                }
                let channel = endpoint.connect().await?;
                providers.insert(
                    config.name.clone(),
                    AuthProvider::Grpc(GrpcAuthProvider { channel, timeout }),
                );
            }
            _ => {
                return Err(config_error(format!(
                    "unsupported auth provider type {}",
                    config.provider_type
                ))
                .into());
            }
        }
    }

    Ok(providers)
}

/// Connects every configured credit provider, keyed by name.
///
/// # Errors
/// Duplicate names, unsupported types, invalid endpoints/TLS, or connect failures.
async fn build_credit_providers(
    configs: Vec<CreditProviderConfig>,
) -> Result<HashMap<String, CreditProvider>, Box<dyn std::error::Error>> {
    let mut providers = HashMap::new();

    for config in configs {
        let provider_type = normalize_name(&config.provider_type);
        if providers.contains_key(&config.name) {
            return Err(config_error(format!(
                "duplicate credit provider name {}",
                config.name
            ))
            .into());
        }

        match provider_type.as_str() {
            "grpc" => {
                // One value drives the connect timeout, the channel timeout
                // and the per-call deadline.
                let timeout = Duration::from_millis(
                    config.timeout_ms.unwrap_or(DEFAULT_CREDIT_TIMEOUT_MS),
                );
                let mut endpoint = Endpoint::from_shared(config.endpoint.clone())?
                    .connect_timeout(timeout)
                    .timeout(timeout);
                if let Some(tls) = build_client_tls_config(&config.tls).await? {
                    endpoint = endpoint.tls_config(tls)?;
                }
                let channel = endpoint.connect().await?;
                providers.insert(
                    config.name.clone(),
                    CreditProvider::Grpc(GrpcCreditProvider { channel, timeout }),
                );
            }
            _ => {
                return Err(config_error(format!(
                    "unsupported credit provider type {}",
                    config.provider_type
                ))
                .into());
            }
        }
    }

    Ok(providers)
}

/// Validates route configs and returns them sorted by descending prefix length,
/// so the first match in [`match_route`] is the most specific.
///
/// # Errors
/// Empty names, unparsable upstream URLs, non-http(s) schemes, or missing hosts.
fn build_routes(configs: Vec<RouteConfig>) -> Result<Vec<Route>, Box<dyn std::error::Error>> {
    let mut routes = Vec::new();

    for mut config in configs {
        if config.name.trim().is_empty() {
            return Err(config_error("route name is required").into());
        }
        config.path_prefix = normalize_path_prefix(&config.path_prefix);

        let upstream = Url::parse(&config.upstream)?;
        if upstream.scheme() != "http" && upstream.scheme() != "https" {
            return Err(config_error(format!(
                "route {} upstream must be http or https",
                config.name
            ))
            .into());
        }
        if upstream.host_str().is_none() {
            return Err(config_error(format!(
                "route {} upstream must include host",
                config.name
            ))
            .into());
        }

        let upstream_base_path = normalize_upstream_base_path(upstream.path());

        routes.push(Route {
            config,
            upstream,
            upstream_base_path,
        });
    }

    // Stable sort: equal-length prefixes keep their configured order.
    routes.sort_by(|a, b| b.config.path_prefix.len().cmp(&a.config.path_prefix.len()));
    Ok(routes)
}

fn config_error(message: impl Into<String>) -> io::Error {
    io::Error::new(io::ErrorKind::InvalidInput, message.into())
}

/// Trims, lower-cases and maps `-` to `_` (e.g. `" gRPC "` -> `"grpc"`).
fn normalize_name(value: &str) -> String {
    value.trim().to_lowercase().replace('-', "_")
}

/// Returns the client's `x-request-id` when it is valid text and not blank;
/// otherwise generates a fresh UUIDv4.
fn extract_request_id(headers: &HeaderMap) -> String {
    headers
        .get(DEFAULT_REQUEST_ID_HEADER)
        .and_then(|value| value.to_str().ok())
        .filter(|value| !value.trim().is_empty())
        .map(|value| value.to_string())
        .unwrap_or_else(|| Uuid::new_v4().to_string())
}

/// Client IP for provider context. With `trust_forwarded_headers`, the first
/// `x-forwarded-for` entry, then `x-real-ip`, are preferred over the peer address.
fn extract_client_ip(
    headers: &HeaderMap,
    remote_addr: SocketAddr,
    trust_forwarded_headers: bool,
) -> String {
    if trust_forwarded_headers {
        if let Some(value) = headers
            .get("x-forwarded-for")
            .and_then(|value| value.to_str().ok())
            .and_then(|value| value.split(',').next())
        {
            let trimmed = value.trim();
            if !trimmed.is_empty() {
                return trimmed.to_string();
            }
        }
        if let Some(value) = headers
            .get("x-real-ip")
            .and_then(|value| value.to_str().ok())
        {
            let trimmed = value.trim();
            if !trimmed.is_empty() {
                return trimmed.to_string();
            }
        }
    }

    remote_addr.ip().to_string()
}

/// Flattens headers into a map, joining repeated values with `,` and
/// skipping values that are not visible ASCII.
fn headers_to_map(headers: &HeaderMap) -> HashMap<String, String> {
    let mut map: HashMap<String, String> = HashMap::new();
    for (name, value) in headers.iter() {
        if let Ok(value) = value.to_str() {
            map.entry(name.as_str().to_string())
                .and_modify(|entry| {
                    entry.push(',');
                    entry.push_str(value);
                })
                .or_insert_with(|| value.to_string());
        }
    }
    map
}

/// Token from `authorization`, falling back to the photon token header.
fn extract_auth_token(headers: &HeaderMap) -> Option<String> {
    let auth_header = headers
        .get(axum::http::header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok());
    if let Some(token) = auth_header.and_then(parse_auth_token_value) {
        return Some(token);
    }

    let photon_header = headers
        .get(PHOTON_AUTH_TOKEN_HEADER)
        .and_then(|value| value.to_str().ok());
    photon_header.and_then(parse_auth_token_value)
}

fn is_reserved_auth_header(name: &axum::http::header::HeaderName) -> bool {
    is_reserved_auth_header_str(name.as_str())
}

fn is_reserved_auth_header_str(name: &str) -> bool {
    let header = name.to_ascii_lowercase();
    RESERVED_AUTH_HEADERS.iter().any(|value| *value == header)
}

/// Whether a provider-supplied header may be forwarded: `x-*` names not on
/// the provider block list.
fn should_forward_auth_header(name: &str) -> bool {
    let header = name.to_ascii_lowercase();
    if AUTH_PROVIDER_BLOCK_HEADERS
        .iter()
        .any(|value| *value == header)
    {
        return false;
    }
    header.starts_with("x-")
}

/// Reserved headers that pass through on routes without an auth policy.
fn should_preserve_client_auth_header(name: &str) -> bool {
    let header = name.to_ascii_lowercase();
    header == "authorization" || header == PHOTON_AUTH_TOKEN_HEADER
}

/// Parses `Bearer <token>` (scheme case-insensitive) or, for legacy clients,
/// a single bare token. Blank values, other schemes and a lone `Bearer`
/// keyword yield `None`.
fn parse_auth_token_value(value: &str) -> Option<String> {
    let trimmed = value.trim();
    if trimmed.is_empty() {
        return None;
    }

    if let Some(token) = parse_bearer_token(trimmed) {
        return Some(token);
    }

    // Legacy support: allow raw token values without a scheme.
    if trimmed.split_whitespace().count() != 1 {
        return None;
    }
    // `Authorization: Bearer` with no credential is a missing token, not a
    // raw token literally named "Bearer".
    if trimmed.eq_ignore_ascii_case("bearer") {
        return None;
    }
    Some(trimmed.to_string())
}

/// Exactly `Bearer <token>` (two whitespace-separated parts).
fn parse_bearer_token(value: &str) -> Option<String> {
    let mut parts = value.split_whitespace();
    let scheme = parts.next()?;
    if !scheme.eq_ignore_ascii_case("bearer") {
        return None;
    }
    let token = parts.next()?;
    if parts.next().is_some() {
        return None;
    }
    Some(token.to_string())
}

/// Canonical route prefix: leading `/`, no trailing slashes, `/` for root.
/// All trailing slashes are removed; a residual `/` (e.g. from `/api//`)
/// would otherwise stop the prefix matching any subpath.
fn normalize_path_prefix(prefix: &str) -> String {
    let trimmed = prefix.trim().trim_end_matches('/');
    if trimmed.is_empty() {
        return "/".to_string();
    }
    if trimmed.starts_with('/') {
        trimmed.to_string()
    } else {
        format!("/{}", trimmed)
    }
}

/// Upstream URL path without trailing slashes; `/` for an empty or root path.
fn normalize_upstream_base_path(path: &str) -> String {
    let trimmed = path.trim();
    if trimmed.is_empty() || trimmed == "/" {
        "/".to_string()
    } else {
        trimmed.trim_end_matches('/').to_string()
    }
}

/// First route whose prefix matches; `routes` is expected to be sorted by
/// descending prefix length (as produced by [`build_routes`]).
fn match_route<'a>(routes: &'a [Route], path: &str) -> Option<&'a Route> {
    routes
        .iter()
        .find(|route| path_matches_prefix(path, &route.config.path_prefix))
}

/// Segment-aware prefix match: `/api` matches `/api` and `/api/x`, not `/apix`.
fn path_matches_prefix(path: &str, prefix: &str) -> bool {
    if prefix == "/" {
        return true;
    }
    if path == prefix {
        return true;
    }
    match path.strip_prefix(prefix) {
        Some(stripped) => stripped.starts_with('/'),
        None => false,
    }
}

/// Removes `prefix` from `path`, always returning a path starting with `/`.
fn strip_prefix_path(path: &str, prefix: &str) -> String {
    if prefix == "/" {
        return path.to_string();
    }
    match path.strip_prefix(prefix) {
        Some("") => "/".to_string(),
        Some(stripped) => {
            if stripped.starts_with('/') {
                stripped.to_string()
            } else {
                format!("/{}", stripped)
            }
        }
        None => path.to_string(),
    }
}

/// Joins the upstream base path and the request path with exactly one `/`.
fn join_paths(base: &str, path: &str) -> String {
    if base == "/" {
        return path.to_string();
    }
    if path == "/" {
        let trimmed = base.trim_end_matches('/');
        return if trimmed.is_empty() {
            "/".to_string()
        } else {
            trimmed.to_string()
        };
    }
    format!(
        "{}/{}",
        base.trim_end_matches('/'),
        path.trim_start_matches('/')
    )
}

/// True when any `/`- or `\`-separated segment is `.` or `..`, including
/// percent-encoded forms such as `%2e%2E` or `.%2e`.
fn path_has_dot_segment(path: &str) -> bool {
    path.split(|c| c == '/' || c == '\\').any(|segment| {
        let decoded = segment.to_ascii_lowercase().replace("%2e", ".");
        decoded == "." || decoded == ".."
    })
}

/// Upstream URL for a request: the route's upstream with its base path
/// joined to the (optionally prefix-stripped) request path and the original query.
///
/// # Errors
/// 400 when the request path contains dot segments. `Url::set_path`
/// resolves them, which would let a request escape the upstream base path
/// or reach a path outside the matched route's prefix.
fn build_upstream_url(route: &Route, uri: &Uri) -> Result<Url, StatusCode> {
    if path_has_dot_segment(uri.path()) {
        return Err(StatusCode::BAD_REQUEST);
    }

    let path = if route.config.strip_prefix {
        strip_prefix_path(uri.path(), &route.config.path_prefix)
    } else {
        uri.path().to_string()
    };
    let merged_path = join_paths(&route.upstream_base_path, &path);

    let mut url = route.upstream.clone();
    url.set_path(&merged_path);
    url.set_query(uri.query());
    Ok(url)
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::routing::get;
    use creditservice_api::{CreditServiceImpl, CreditStorage, GatewayCreditServiceImpl};
    use apigateway_api::GatewayCreditServiceServer;
    use creditservice_types::Wallet;
    use iam_api::{GatewayAuthServiceImpl, GatewayAuthServiceServer};
    use iam_authn::{InternalTokenConfig, InternalTokenService, SigningKey};
    use iam_authz::{PolicyCache, PolicyEvaluator};
    use iam_store::{Backend, BackendConfig, BindingStore, PrincipalStore, RoleStore, TokenStore};
    use iam_types::{Permission, PolicyBinding, Principal, PrincipalRef, Role, Scope};
    use tokio_stream::wrappers::TcpListenerStream;
    use tonic::transport::Server;
    use uuid::Uuid;

    async fn wait_for_test_tcp(addr: SocketAddr) {
        let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
        loop {
            if tokio::net::TcpStream::connect(addr).await.is_ok() {
                return;
            }
            assert!(
                tokio::time::Instant::now() < deadline,
                "timed out waiting for test listener {}",
                addr
            );
            tokio::time::sleep(Duration::from_millis(25)).await;
        }
    }

    fn route_config(name: &str, prefix: &str, upstream: &str, strip_prefix: bool) -> RouteConfig {
        RouteConfig {
            name: name.to_string(),
            path_prefix: prefix.to_string(),
            upstream: upstream.to_string(),
            strip_prefix,
            timeout_ms: None,
            auth: None,
            credit: None,
        }
    }

    fn auth_decision(allow: bool) -> AuthDecision {
        AuthDecision {
            allow,
            subject: None,
            headers: HashMap::new(),
            reason: None,
        }
    }

    fn credit_decision(allow: bool) -> CreditDecision {
        CreditDecision {
            allow,
            reservation_id: "resv".to_string(),
            reason: None,
        }
    }

    async fn
start_upstream() -> SocketAddr { let app = Router::new() .route("/v1/echo", get(|| async { "ok" })) .route( "/v1/echo-auth", get(|headers: HeaderMap| async move { Json(serde_json::json!({ "authorization": headers .get(axum::http::header::AUTHORIZATION) .and_then(|value| value.to_str().ok()), "photon_token": headers .get(PHOTON_AUTH_TOKEN_HEADER) .and_then(|value| value.to_str().ok()), })) }), ); let listener = tokio::net::TcpListener::bind("127.0.0.1:0") .await .expect("bind upstream"); let addr = listener.local_addr().expect("upstream addr"); tokio::spawn(async move { axum::serve(listener, app).await.expect("upstream serve"); }); wait_for_test_tcp(addr).await; addr } async fn start_iam_gateway() -> (SocketAddr, String) { let backend = Arc::new(Backend::new(BackendConfig::Memory).await.expect("iam backend")); let principal_store = Arc::new(PrincipalStore::new(backend.clone())); let role_store = Arc::new(RoleStore::new(backend.clone())); let binding_store = Arc::new(BindingStore::new(backend.clone())); let token_store = Arc::new(TokenStore::new(backend)); let signing_key = SigningKey::generate("gateway-test-key"); let token_config = InternalTokenConfig::new(signing_key, "iam-gateway-test"); let token_service = Arc::new(InternalTokenService::new(token_config)); let mut principal = Principal::new_user("user-1", "User One"); principal.org_id = Some("org-1".into()); principal.project_id = Some("proj-1".into()); principal_store .create(&principal) .await .expect("principal create"); let role = Role::new( "GatewayReader", Scope::project("proj-1", "org-1"), vec![Permission::new("gateway:public:read", "*")], ); role_store.create(&role).await.expect("role create"); let binding = PolicyBinding::new( format!("binding-{}", Uuid::new_v4()), PrincipalRef::new(principal.kind.clone(), principal.id.clone()), role.to_ref(), Scope::project("proj-1", "org-1"), ); binding_store .create(&binding) .await .expect("binding create"); let issued = token_service .issue(&principal, vec![], 
Scope::project("proj-1", "org-1"), None) .await .expect("issue token"); let cache = Arc::new(PolicyCache::default_config()); let evaluator = Arc::new(PolicyEvaluator::new( binding_store.clone(), role_store.clone(), cache, )); let gateway_auth = GatewayAuthServiceImpl::new( token_service, principal_store, token_store, evaluator, ); let listener = tokio::net::TcpListener::bind("127.0.0.1:0") .await .expect("bind iam"); let addr = listener.local_addr().expect("iam addr"); tokio::spawn(async move { Server::builder() .add_service(GatewayAuthServiceServer::new(gateway_auth)) .serve_with_incoming(TcpListenerStream::new(listener)) .await .expect("iam gateway serve"); }); wait_for_test_tcp(addr).await; (addr, issued.token) } async fn start_credit_gateway(iam_addr: &SocketAddr) -> SocketAddr { let storage = creditservice_api::InMemoryStorage::new(); let wallet = Wallet::new("proj-1".into(), "org-1".into(), 100); storage .create_wallet(wallet) .await .expect("wallet create"); let auth_service = Arc::new( iam_service_auth::AuthService::new(&format!("http://{}", iam_addr)) .await .expect("auth service"), ); let credit_service = Arc::new(CreditServiceImpl::new(storage, auth_service)); let gateway_credit = GatewayCreditServiceImpl::new(credit_service); let listener = tokio::net::TcpListener::bind("127.0.0.1:0") .await .expect("bind credit"); let addr = listener.local_addr().expect("credit addr"); tokio::spawn(async move { Server::builder() .add_service(GatewayCreditServiceServer::new(gateway_credit)) .serve_with_incoming(TcpListenerStream::new(listener)) .await .expect("credit gateway serve"); }); wait_for_test_tcp(addr).await; addr } #[test] fn test_normalize_path_prefix() { assert_eq!(normalize_path_prefix(""), "/"); assert_eq!(normalize_path_prefix("/"), "/"); assert_eq!(normalize_path_prefix("api"), "/api"); assert_eq!(normalize_path_prefix("/api/"), "/api"); } #[test] fn test_strip_prefix_path() { assert_eq!(strip_prefix_path("/api", "/api"), "/"); 
assert_eq!(strip_prefix_path("/api/v1", "/api"), "/v1"); assert_eq!(strip_prefix_path("/api/v1/", "/api"), "/v1/"); assert_eq!(strip_prefix_path("/v1", "/"), "/v1"); } #[test] fn test_join_paths() { assert_eq!(join_paths("/", "/v1"), "/v1"); assert_eq!(join_paths("/api", "/v1"), "/api/v1"); assert_eq!(join_paths("/api/", "/"), "/api"); } #[test] fn test_match_route_longest_prefix() { let routes = build_routes(vec![ route_config("api", "/api", "http://example.com", false), route_config("api-v1", "/api/v1", "http://example.com", false), ]) .unwrap(); let matched = match_route(&routes, "/api/v1/users").unwrap(); assert_eq!(matched.config.name, "api-v1"); } #[test] fn test_match_route_segment_boundary() { let routes = build_routes(vec![ route_config("api", "/api", "http://example.com", false), route_config("api2", "/api2", "http://example.com", false), ]) .unwrap(); let matched = match_route(&routes, "/api2").unwrap(); assert_eq!(matched.config.name, "api2"); let matched = match_route(&routes, "/api2/health").unwrap(); assert_eq!(matched.config.name, "api2"); assert!(match_route(&routes, "/apiary").is_none()); } #[test] fn test_build_upstream_url_preserves_query() { let routes = build_routes(vec![route_config( "api", "/api", "http://example.com/base", false, )]) .unwrap(); let route = routes.first().unwrap(); let uri: Uri = "/api/v1/users?debug=true".parse().unwrap(); let url = build_upstream_url(route, &uri).unwrap(); assert_eq!(url.as_str(), "http://example.com/base/api/v1/users?debug=true"); } #[test] fn test_build_upstream_url_strip_prefix() { let routes = build_routes(vec![route_config( "api", "/api", "http://example.com/base", true, )]) .unwrap(); let route = routes.first().unwrap(); let uri: Uri = "/api/v1".parse().unwrap(); let url = build_upstream_url(route, &uri).unwrap(); assert_eq!(url.as_str(), "http://example.com/base/v1"); } #[test] fn test_apply_auth_mode_required() { let decision = Ok(auth_decision(true)); let outcome = 
apply_auth_mode(PolicyMode::Required, false, decision).unwrap(); assert!(outcome.subject.is_none()); let decision = Ok(auth_decision(false)); let err = apply_auth_mode(PolicyMode::Required, false, decision).unwrap_err(); assert_eq!(err, StatusCode::FORBIDDEN); } #[test] fn test_apply_auth_mode_optional() { let decision = Ok(auth_decision(false)); let outcome = apply_auth_mode(PolicyMode::Optional, false, decision).unwrap(); assert!(outcome.subject.is_none()); let outcome = apply_auth_mode(PolicyMode::Optional, false, Err(StatusCode::BAD_GATEWAY)).unwrap(); assert!(outcome.subject.is_none()); } #[test] fn test_apply_credit_mode_required() { let decision = Ok(credit_decision(true)); let outcome = apply_credit_mode(PolicyMode::Required, false, decision).unwrap(); assert!(outcome.is_some()); let decision = Ok(credit_decision(false)); let err = apply_credit_mode(PolicyMode::Required, false, decision).unwrap_err(); assert_eq!(err, StatusCode::PAYMENT_REQUIRED); } #[test] fn test_apply_credit_mode_optional() { let decision = Ok(credit_decision(false)); let outcome = apply_credit_mode(PolicyMode::Optional, false, decision).unwrap(); assert!(outcome.is_none()); let outcome = apply_credit_mode(PolicyMode::Optional, false, Err(StatusCode::BAD_GATEWAY)).unwrap(); assert!(outcome.is_none()); } #[tokio::test] async fn test_gateway_auth_and_credit_flow() { let upstream_addr = start_upstream().await; let (iam_addr, token) = start_iam_gateway().await; let credit_addr = start_credit_gateway(&iam_addr).await; let routes = build_routes(vec![RouteConfig { name: "public".to_string(), path_prefix: "/v1".to_string(), upstream: format!("http://{}", upstream_addr), strip_prefix: false, timeout_ms: None, auth: Some(RouteAuthConfig { provider: "iam".to_string(), mode: PolicyMode::Required, fail_open: false, }), credit: Some(RouteCreditConfig { provider: "credit".to_string(), mode: PolicyMode::Required, units: 1, fail_open: false, commit_on: CommitPolicy::Success, allow_header_subject: false, 
attributes: HashMap::new(), }), }]) .unwrap(); let auth_providers = build_auth_providers(vec![AuthProviderConfig { name: "iam".to_string(), provider_type: "grpc".to_string(), endpoint: format!("http://{}", iam_addr), timeout_ms: Some(1000), tls: None, }]) .await .unwrap(); let credit_providers = build_credit_providers(vec![CreditProviderConfig { name: "credit".to_string(), provider_type: "grpc".to_string(), endpoint: format!("http://{}", credit_addr), timeout_ms: Some(1000), tls: None, }]) .await .unwrap(); let state = Arc::new(ServerState { routes, client: Client::new(), upstream_timeout: Duration::from_secs(5), max_body_bytes: 1024 * 1024, max_response_bytes: 1024 * 1024, auth_providers, credit_providers, trust_forwarded_headers: false, }); let deadline = tokio::time::Instant::now() + Duration::from_secs(10); let mut response = None; while tokio::time::Instant::now() < deadline { let request = Request::builder() .method("GET") .uri("/v1/echo") .header(axum::http::header::AUTHORIZATION, &token) .body(Body::empty()) .expect("request build"); match proxy( State(Arc::clone(&state)), ConnectInfo("127.0.0.1:40000".parse().unwrap()), request, ) .await { Ok(ok) => { response = Some(ok); break; } Err(StatusCode::BAD_GATEWAY) => { tokio::time::sleep(Duration::from_millis(25)).await; } Err(status) => panic!("unexpected proxy status: {}", status), } } let response = response.expect("gateway auth+credit test timed out waiting for ready backends"); assert_eq!(response.status(), StatusCode::OK); } #[tokio::test] async fn test_proxy_forwards_client_auth_headers_when_route_has_no_auth() { let upstream_addr = start_upstream().await; let routes = build_routes(vec![route_config( "passthrough", "/v1", &format!("http://{}", upstream_addr), false, )]) .unwrap(); let state = Arc::new(ServerState { routes, client: Client::new(), upstream_timeout: Duration::from_secs(5), max_body_bytes: 1024 * 1024, max_response_bytes: 1024 * 1024, auth_providers: HashMap::new(), credit_providers: 
HashMap::new(), trust_forwarded_headers: false, }); let request = Request::builder() .method("GET") .uri("/v1/echo-auth") .header(axum::http::header::AUTHORIZATION, "Bearer passthrough-token") .header(PHOTON_AUTH_TOKEN_HEADER, "photon-token") .body(Body::empty()) .expect("request build"); let response = proxy( State(state), ConnectInfo("127.0.0.1:40000".parse().unwrap()), request, ) .await .unwrap(); assert_eq!(response.status(), StatusCode::OK); let body = to_bytes(response.into_body(), 1024 * 1024).await.unwrap(); let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); assert_eq!(json.get("authorization").and_then(|v| v.as_str()), Some("Bearer passthrough-token")); assert_eq!(json.get("photon_token").and_then(|v| v.as_str()), Some("photon-token")); } #[test] fn test_extract_auth_token_accepts_bearer_authorization() { let mut headers = HeaderMap::new(); headers.insert( axum::http::header::AUTHORIZATION, "Bearer abc123".parse().unwrap(), ); assert_eq!(extract_auth_token(&headers).as_deref(), Some("abc123")); } #[test] fn test_extract_auth_token_accepts_legacy_raw_authorization() { let mut headers = HeaderMap::new(); headers.insert( axum::http::header::AUTHORIZATION, "raw-token".parse().unwrap(), ); assert_eq!(extract_auth_token(&headers).as_deref(), Some("raw-token")); } #[test] fn test_extract_auth_token_falls_back_to_photon_header() { let mut headers = HeaderMap::new(); headers.insert( axum::http::header::AUTHORIZATION, "Basic abc".parse().unwrap(), ); headers.insert(PHOTON_AUTH_TOKEN_HEADER, "photon-token".parse().unwrap()); assert_eq!( extract_auth_token(&headers).as_deref(), Some("photon-token") ); } }