photoncloud-monorepo/apigateway/crates/apigateway-server/src/main.rs

1871 lines
56 KiB
Rust

use std::collections::HashMap;
use std::io;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use apigateway_api::proto::{
AuthorizeRequest, CreditCommitRequest, CreditReserveRequest, CreditRollbackRequest,
};
use apigateway_api::{GatewayAuthServiceClient, GatewayCreditServiceClient};
use axum::{
body::{to_bytes, Body},
extract::{ConnectInfo, State},
http::{HeaderMap, Request, StatusCode, Uri},
response::Response,
routing::{any, get},
Json, Router,
};
use clap::Parser;
use reqwest::{Client, Url};
use serde::{Deserialize, Serialize};
use tonic::transport::{Certificate, Channel, ClientTlsConfig, Endpoint, Identity};
use tonic::Request as TonicRequest;
use tracing::{info, warn};
use tracing_subscriber::EnvFilter;
use uuid::Uuid;
/// Header carrying the per-request correlation id: reused from the client when it
/// is present and readable as text, otherwise minted by the gateway, and always
/// set on the upstream request.
const DEFAULT_REQUEST_ID_HEADER: &str = "x-request-id";
/// Fallback header for the client's auth token, consulted when `Authorization`
/// yields no usable token.
const PHOTON_AUTH_TOKEN_HEADER: &str = "x-photon-auth-token";
/// Timeout (ms) for auth providers whose config omits `timeout_ms`.
const DEFAULT_AUTH_TIMEOUT_MS: u64 = 500;
/// Timeout (ms) for credit providers whose config omits `timeout_ms`.
const DEFAULT_CREDIT_TIMEOUT_MS: u64 = 500;
/// Upstream request timeout (ms) when `upstream_timeout_ms` is not configured; a
/// route's own `timeout_ms` still takes precedence.
const DEFAULT_UPSTREAM_TIMEOUT_MS: u64 = 10_000;
/// Identity and credential headers a client must not be able to forge upstream.
///
/// Client-supplied copies are dropped while proxying. The only exception is that
/// `authorization` and `x-photon-auth-token` pass through on routes that have no
/// `auth` block. `x-subject-id`, `x-org-id`, `x-project-id`, `x-roles` and
/// `x-scopes` are instead re-derived by the gateway from the auth outcome.
const RESERVED_AUTH_HEADERS: [&str; 10] = [
"authorization",
"x-photon-auth-token",
"x-subject-id",
"x-org-id",
"x-project-id",
"x-roles",
"x-scopes",
"x-iam-session-id",
"x-iam-principal-kind",
"x-iam-auth-method",
];
/// Headers an auth provider may never inject into the upstream request.
///
/// The list covers identity headers the gateway derives itself, credentials,
/// cookies and connection-management headers. Beyond this list, only
/// `x-`-prefixed provider headers are forwarded (see `should_forward_auth_header`).
const AUTH_PROVIDER_BLOCK_HEADERS: [&str; 17] = [
"authorization",
"x-photon-auth-token",
"x-subject-id",
"x-org-id",
"x-project-id",
"x-roles",
"x-scopes",
"proxy-authorization",
"cookie",
"set-cookie",
"host",
"connection",
"upgrade",
"keep-alive",
"te",
"trailer",
"transfer-encoding",
];
/// How strictly a route enforces an auth or credit check (`snake_case` in config).
///
/// - `Disabled` skips the check.
/// - `Optional` proceeds without the check's result when the provider denies or errors.
/// - `Required` rejects on denial. It also rejects on provider errors unless the
///   route sets `fail_open`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
enum PolicyMode {
Disabled,
Optional,
Required,
}
/// What happens to a credit reservation once the upstream has responded.
///
/// - `Success` commits on a 2xx/3xx status and rolls back otherwise.
/// - `Always` commits regardless of status.
/// - `Never` leaves the reservation untouched, even when delivery fails.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
enum CommitPolicy {
Success,
Always,
Never,
}
/// Serde default for a route's `auth.mode` / `credit.mode`: enforce the check.
fn default_policy_mode() -> PolicyMode {
PolicyMode::Required
}
/// Serde default for `credit.commit_on`.
fn default_commit_policy() -> CommitPolicy {
CommitPolicy::Success
}
/// Serde default for `credit.units`: reserve one unit per request.
fn default_credit_units() -> u64 {
1
}
/// Serde default for the server-wide `upstream_timeout_ms`.
fn default_upstream_timeout_ms() -> u64 {
DEFAULT_UPSTREAM_TIMEOUT_MS
}
/// Client-side TLS settings for a gRPC provider connection.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TlsConfig {
/// PEM file with the CA certificate used to verify the provider.
#[serde(default)]
ca_file: Option<String>,
/// PEM client certificate for mutual TLS; must be set together with `key_file`.
#[serde(default)]
cert_file: Option<String>,
/// PEM private key matching `cert_file`.
#[serde(default)]
key_file: Option<String>,
/// Overrides the name the server certificate is verified against.
#[serde(default)]
domain_name: Option<String>,
}
/// One `[[auth_providers]]` entry from the config file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct AuthProviderConfig {
/// Unique name that routes reference through `auth.provider`.
name: String,
/// Backend kind, compared after trimming, lowercasing and mapping `-` to `_`.
/// Only `grpc` is currently accepted.
#[serde(rename = "type")]
provider_type: String,
/// Provider URI handed to tonic's `Endpoint::from_shared`.
endpoint: String,
/// Connect timeout and per-call deadline in milliseconds (default 500).
#[serde(default)]
timeout_ms: Option<u64>,
/// Optional TLS settings; no TLS configuration is applied when absent.
#[serde(default)]
tls: Option<TlsConfig>,
}
/// One `[[credit_providers]]` entry from the config file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CreditProviderConfig {
/// Unique name that routes reference through `credit.provider`.
name: String,
/// Backend kind, compared after trimming, lowercasing and mapping `-` to `_`.
/// Only `grpc` is currently accepted.
#[serde(rename = "type")]
provider_type: String,
/// Provider URI handed to tonic's `Endpoint::from_shared`.
endpoint: String,
/// Connect timeout and per-call deadline in milliseconds (default 500).
#[serde(default)]
timeout_ms: Option<u64>,
/// Optional TLS settings; no TLS configuration is applied when absent.
#[serde(default)]
tls: Option<TlsConfig>,
}
/// Per-route auth policy.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct RouteAuthConfig {
/// Name of an entry in `auth_providers`.
provider: String,
/// Enforcement level (default `required`).
#[serde(default = "default_policy_mode")]
mode: PolicyMode,
/// In `required` mode only: let the request through without a subject when the
/// provider call fails, instead of answering 502.
#[serde(default)]
fail_open: bool,
}
/// Per-route credit (metering) policy.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct RouteCreditConfig {
/// Name of an entry in `credit_providers`.
provider: String,
/// Enforcement level (default `required`).
#[serde(default = "default_policy_mode")]
mode: PolicyMode,
/// Units reserved, and later committed, for each request (default 1).
#[serde(default = "default_credit_units")]
units: u64,
/// In `required` mode only: proceed without a reservation when the provider
/// call fails, instead of answering 502.
#[serde(default)]
fail_open: bool,
/// When to commit versus roll back after the upstream responds (default `success`).
#[serde(default = "default_commit_policy")]
commit_on: CommitPolicy,
/// When no auth subject is available, take the credit scope from the client's
/// `x-org-id` / `x-project-id` / `x-subject-id` headers. Those headers are
/// client-controlled, so enable this only behind a trusted hop.
#[serde(default)]
allow_header_subject: bool,
/// Static key/value pairs sent with every reservation request.
#[serde(default)]
attributes: HashMap<String, String>,
}
/// One `[[routes]]` entry; also the shape reported by `GET /routes`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct RouteConfig {
/// Route identifier passed to providers; must not be blank.
name: String,
/// Segment-aligned request path prefix served by this route (normalized at startup).
path_prefix: String,
/// Base `http`/`https` URL (with host) that matching requests are forwarded to;
/// its path is prepended to the forwarded request path.
upstream: String,
/// Remove `path_prefix` from the request path before forwarding.
#[serde(default)]
strip_prefix: bool,
/// Per-route upstream timeout in milliseconds, overriding `upstream_timeout_ms`.
#[serde(default)]
timeout_ms: Option<u64>,
/// Auth policy; without one, client credentials are passed through untouched.
#[serde(default)]
auth: Option<RouteAuthConfig>,
/// Credit policy; without one, no credit is reserved.
#[serde(default)]
credit: Option<RouteCreditConfig>,
}
/// A validated route: its config plus the parsed upstream URL and that URL's
/// normalized base path, which is prepended to every forwarded path.
#[derive(Clone)]
struct Route {
config: RouteConfig,
upstream: Url,
upstream_base_path: String,
}
/// Top-level configuration, loaded from TOML.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ServerConfig {
/// Listen address (default `127.0.0.1:8080`).
#[serde(default = "default_http_addr")]
http_addr: SocketAddr,
/// Tracing filter used when `RUST_LOG` is not set (default `info`).
#[serde(default = "default_log_level")]
log_level: String,
/// Maximum buffered request body; larger bodies are answered with 413.
#[serde(default = "default_max_body_bytes")]
max_body_bytes: usize,
/// Maximum buffered upstream response body; `0` disables the check.
#[serde(default = "default_max_response_bytes")]
max_response_bytes: usize,
/// Default upstream timeout in milliseconds; routes may override it.
#[serde(default = "default_upstream_timeout_ms")]
upstream_timeout_ms: u64,
/// Derive the client IP from `x-forwarded-for` / `x-real-ip` instead of the
/// socket peer address. Enable only behind a proxy that sets these headers.
#[serde(default)]
trust_forwarded_headers: bool,
#[serde(default)]
auth_providers: Vec<AuthProviderConfig>,
#[serde(default)]
credit_providers: Vec<CreditProviderConfig>,
#[serde(default)]
routes: Vec<RouteConfig>,
}
impl Default for ServerConfig {
    /// Configuration used when no config file exists.
    ///
    /// Every scalar setting takes its serde default, forwarded headers are not
    /// trusted, and no providers or routes are defined.
    fn default() -> Self {
        let upstream_timeout_ms = default_upstream_timeout_ms();
        Self {
            routes: Vec::new(),
            credit_providers: Vec::new(),
            auth_providers: Vec::new(),
            trust_forwarded_headers: false,
            upstream_timeout_ms,
            max_response_bytes: default_max_response_bytes(),
            max_body_bytes: default_max_body_bytes(),
            log_level: default_log_level(),
            http_addr: default_http_addr(),
        }
    }
}
// Command-line flags. `--http-addr` and `--log-level` take precedence over the
// values read from the config file. The `///` lines below are clap help text,
// so they are part of the program's output.
#[derive(Debug, Parser)]
#[command(author, version, about)]
struct Args {
/// Configuration file path
#[arg(short, long, default_value = "apigateway.toml")]
config: PathBuf,
/// HTTP listen address (overrides config)
#[arg(long)]
http_addr: Option<String>,
/// Log level (overrides config)
#[arg(short, long)]
log_level: Option<String>,
}
/// State shared by all request handlers; built once in `main` and not mutated afterwards.
#[derive(Clone)]
struct ServerState {
/// Routes sorted longest prefix first, so the first match is the most specific.
routes: Vec<Route>,
client: Client,
/// Default upstream timeout; a route's `timeout_ms` overrides it.
upstream_timeout: Duration,
max_body_bytes: usize,
/// `0` disables the response size check.
max_response_bytes: usize,
/// Auth providers keyed by configured name.
auth_providers: HashMap<String, AuthProvider>,
/// Credit providers keyed by configured name.
credit_providers: HashMap<String, CreditProvider>,
trust_forwarded_headers: bool,
}
/// Connected gRPC auth backend.
#[derive(Clone)]
struct GrpcAuthProvider {
channel: Channel,
/// Deadline applied to every `authorize` call.
timeout: Duration,
}
/// Auth backend variants; gRPC is the only kind implemented.
#[derive(Clone)]
enum AuthProvider {
Grpc(GrpcAuthProvider),
}
/// Connected gRPC credit backend.
#[derive(Clone)]
struct GrpcCreditProvider {
channel: Channel,
/// Deadline applied to every reserve/commit/rollback call.
timeout: Duration,
}
/// Credit backend variants; gRPC is the only kind implemented.
#[derive(Clone)]
enum CreditProvider {
Grpc(GrpcCreditProvider),
}
/// Authenticated identity returned by the auth provider.
///
/// It is forwarded upstream as `x-subject-id`, `x-org-id` and `x-project-id`, plus
/// comma-joined `x-roles` / `x-scopes` when those lists are non-empty.
#[derive(Clone, Debug)]
struct SubjectInfo {
subject_id: String,
org_id: String,
project_id: String,
roles: Vec<String>,
scopes: Vec<String>,
}
/// Scope that credit is reserved against. It comes from the auth subject or, when
/// the route opts in, from client headers.
#[derive(Clone, Debug)]
struct CreditSubject {
subject_id: String,
org_id: String,
project_id: String,
}
/// Raw auth provider answer, before the route's policy mode is applied.
#[derive(Clone, Debug)]
struct AuthDecision {
allow: bool,
subject: Option<SubjectInfo>,
/// Extra headers the provider wants set upstream (filtered before forwarding).
headers: HashMap<String, String>,
/// Provider's explanation; `None` when the provider sent an empty string.
reason: Option<String>,
}
/// Raw credit provider answer to a reservation request.
#[derive(Clone, Debug)]
struct CreditDecision {
allow: bool,
reservation_id: String,
/// Provider's explanation; `None` when the provider sent an empty string.
reason: Option<String>,
}
/// What an auth check contributes to the upstream request. The default (no
/// subject, no headers) is used when auth is skipped, denied in optional mode,
/// or failed open.
#[derive(Clone, Debug, Default)]
struct AuthOutcome {
subject: Option<SubjectInfo>,
headers: HashMap<String, String>,
}
/// A live reservation that must later be committed or rolled back on `provider`.
#[derive(Clone, Debug)]
struct CreditReservation {
provider: String,
reservation_id: String,
}
/// Request metadata sent to auth and credit providers.
#[derive(Clone, Debug)]
struct RequestContext {
request_id: String,
method: String,
path: String,
/// Query string without the leading `?`; empty when absent.
raw_query: String,
/// Client headers flattened by `headers_to_map`.
headers: HashMap<String, String>,
/// Address from `extract_client_ip`.
client_ip: String,
route_name: String,
}
/// Loopback listener used when the config does not set `http_addr`.
fn default_http_addr() -> SocketAddr {
    SocketAddr::from(([127, 0, 0, 1], 8080))
}
/// Log filter used when neither the config nor the CLI sets one.
fn default_log_level() -> String {
    String::from("info")
}
/// Request body cap: 16 MiB.
fn default_max_body_bytes() -> usize {
    16 << 20
}
/// Response body cap; mirrors the request body cap by default.
fn default_max_response_bytes() -> usize {
    default_max_body_bytes()
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = Args::parse();
let mut used_default_config = false;
let mut config = if args.config.exists() {
let contents = tokio::fs::read_to_string(&args.config).await?;
toml::from_str(&contents)?
} else {
used_default_config = true;
ServerConfig::default()
};
if let Some(http_addr) = args.http_addr {
config.http_addr = http_addr.parse()?;
}
if let Some(log_level) = args.log_level {
config.log_level = log_level;
}
tracing_subscriber::fmt()
.with_env_filter(
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&config.log_level)),
)
.init();
if used_default_config {
info!("Config file not found: {}, using defaults", args.config.display());
}
let routes = build_routes(config.routes)?;
let auth_providers = build_auth_providers(config.auth_providers).await?;
let credit_providers = build_credit_providers(config.credit_providers).await?;
let upstream_timeout = Duration::from_millis(config.upstream_timeout_ms);
let client = Client::builder().build()?;
info!("Starting API gateway");
info!(" HTTP: {}", config.http_addr);
info!(" Max body bytes: {}", config.max_body_bytes);
info!(" Max response bytes: {}", config.max_response_bytes);
if !routes.is_empty() {
info!("Configured {} routes", routes.len());
} else {
warn!("No routes configured; proxy will return 404s");
}
if auth_providers.is_empty() {
warn!("No auth providers configured");
}
if credit_providers.is_empty() {
warn!("No credit providers configured");
}
let state = Arc::new(ServerState {
routes,
client,
upstream_timeout,
max_body_bytes: config.max_body_bytes,
max_response_bytes: config.max_response_bytes,
auth_providers,
credit_providers,
trust_forwarded_headers: config.trust_forwarded_headers,
});
let app = Router::new()
.route("/", get(index))
.route("/health", get(health))
.route("/routes", get(list_routes))
.route("/*path", any(proxy))
.with_state(state);
let listener = tokio::net::TcpListener::bind(config.http_addr).await?;
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await?;
Ok(())
}
/// `GET /`: plain-text banner confirming the process is up.
async fn index() -> &'static str {
    const BANNER: &str = "apigateway-server running";
    BANNER
}
/// `GET /health`: static liveness probe that always reports `{"status": "ok"}`.
async fn health() -> Json<serde_json::Value> {
    let body = serde_json::json!({ "status": "ok" });
    Json(body)
}
/// `GET /routes`: the effective route configs (normalized prefixes, longest first).
async fn list_routes(State(state): State<Arc<ServerState>>) -> Json<Vec<RouteConfig>> {
    let configs: Vec<RouteConfig> = state
        .routes
        .iter()
        .map(|route| &route.config)
        .cloned()
        .collect();
    Json(configs)
}
async fn proxy(
State(state): State<Arc<ServerState>>,
ConnectInfo(remote_addr): ConnectInfo<SocketAddr>,
request: Request<Body>,
) -> Result<Response<Body>, StatusCode> {
let path = request.uri().path();
let route = match_route(&state.routes, path)
.ok_or(StatusCode::NOT_FOUND)?
.clone();
let request_id = extract_request_id(request.headers());
let context = RequestContext {
request_id: request_id.clone(),
method: request.method().to_string(),
path: request.uri().path().to_string(),
raw_query: request.uri().query().unwrap_or("").to_string(),
headers: headers_to_map(request.headers()),
client_ip: extract_client_ip(
request.headers(),
remote_addr,
state.trust_forwarded_headers,
),
route_name: route.config.name.clone(),
};
let auth_token = extract_auth_token(request.headers());
let forward_client_auth_headers = route.config.auth.is_none();
let auth_outcome = enforce_auth(&state, &route, &context, auth_token).await?;
let credit_reservation =
enforce_credit(&state, &route, &context, auth_outcome.subject.as_ref()).await?;
let target_url = build_upstream_url(&route, request.uri())?;
let request_timeout =
Duration::from_millis(route.config.timeout_ms.unwrap_or(state.upstream_timeout.as_millis() as u64));
let mut builder = state
.client
.request(request.method().clone(), target_url)
.timeout(request_timeout);
for (name, value) in request.headers().iter() {
if name == axum::http::header::HOST || name == axum::http::header::CONNECTION {
continue;
}
if is_reserved_auth_header(name) {
if forward_client_auth_headers && should_preserve_client_auth_header(name.as_str()) {
builder = builder.header(name, value);
}
continue;
}
builder = builder.header(name, value);
}
builder = builder.header(DEFAULT_REQUEST_ID_HEADER, request_id.clone());
builder = apply_auth_headers(builder, &auth_outcome);
let body_bytes = to_bytes(request.into_body(), state.max_body_bytes)
.await
.map_err(|_| StatusCode::PAYLOAD_TOO_LARGE)?;
let response = match builder.body(body_bytes).send().await {
Ok(response) => response,
Err(_) => {
finalize_credit_abort(&state, &route, credit_reservation).await;
return Err(StatusCode::BAD_GATEWAY);
}
};
let status = response.status();
if let Some(content_length) = response.content_length() {
if state.max_response_bytes > 0 && content_length as usize > state.max_response_bytes {
finalize_credit_abort(&state, &route, credit_reservation).await;
return Err(StatusCode::PAYLOAD_TOO_LARGE);
}
}
let mut response_builder = Response::builder().status(status);
let headers = response_builder
.headers_mut()
.ok_or(StatusCode::BAD_GATEWAY)?;
for (name, value) in response.headers().iter() {
if name == axum::http::header::CONNECTION {
continue;
}
headers.insert(name, value.clone());
}
let body = match response.bytes().await {
Ok(body) => body,
Err(_) => {
finalize_credit_abort(&state, &route, credit_reservation).await;
return Err(StatusCode::BAD_GATEWAY);
}
};
if state.max_response_bytes > 0 && body.len() > state.max_response_bytes {
finalize_credit_abort(&state, &route, credit_reservation).await;
return Err(StatusCode::PAYLOAD_TOO_LARGE);
}
finalize_credit(&state, &route, credit_reservation, status).await;
response_builder
.body(Body::from(body))
.map_err(|_| StatusCode::BAD_GATEWAY)
}
/// Applies the route's auth policy to the request.
///
/// Routes without an auth block, or with `mode = disabled`, get an empty
/// outcome without contacting any provider. Otherwise the provider is asked,
/// and its answer is filtered through `apply_auth_mode`.
async fn enforce_auth(
    state: &ServerState,
    route: &Route,
    context: &RequestContext,
    token: Option<String>,
) -> Result<AuthOutcome, StatusCode> {
    match &route.config.auth {
        Some(auth_cfg) if auth_cfg.mode != PolicyMode::Disabled => {
            let decision = authorize_request(state, auth_cfg, context, token).await;
            apply_auth_mode(auth_cfg.mode, auth_cfg.fail_open, decision)
        }
        _ => Ok(AuthOutcome::default()),
    }
}
fn apply_auth_mode(
mode: PolicyMode,
fail_open: bool,
decision: Result<AuthDecision, StatusCode>,
) -> Result<AuthOutcome, StatusCode> {
match mode {
PolicyMode::Disabled => Ok(AuthOutcome::default()),
PolicyMode::Optional => match decision {
Ok(decision) if decision.allow => Ok(AuthOutcome {
subject: decision.subject,
headers: decision.headers,
}),
Ok(decision) => {
if let Some(reason) = decision.reason {
warn!("Auth denied (optional mode): {}", reason);
}
Ok(AuthOutcome::default())
}
Err(err) => {
warn!("Auth provider error (optional mode): {}", err);
Ok(AuthOutcome::default())
}
},
PolicyMode::Required => match decision {
Ok(decision) if decision.allow => Ok(AuthOutcome {
subject: decision.subject,
headers: decision.headers,
}),
Ok(decision) => {
if let Some(reason) = decision.reason {
warn!("Auth denied (required mode): {}", reason);
}
Err(StatusCode::FORBIDDEN)
}
Err(err) => {
warn!("Auth provider error (required mode): {}", err);
if fail_open {
Ok(AuthOutcome::default())
} else {
Err(StatusCode::BAD_GATEWAY)
}
}
},
}
}
/// Applies the route's credit policy, reserving credit when required.
///
/// Returns `Ok(None)` when the route has no credit block, the mode is
/// `disabled`, the scope is missing in `optional` mode, or the provider outcome
/// is tolerated by `apply_credit_mode`. Otherwise returns the reservation that
/// `finalize_credit` / `finalize_credit_abort` must settle.
///
/// # Errors
/// - `401` when `required` mode cannot determine an org/project scope;
/// - whatever `apply_credit_mode` rejects with (`402`, `502`).
async fn enforce_credit(
    state: &ServerState,
    route: &Route,
    context: &RequestContext,
    subject: Option<&SubjectInfo>,
) -> Result<Option<CreditReservation>, StatusCode> {
    let Some(credit_cfg) = &route.config.credit else {
        return Ok(None);
    };
    if credit_cfg.mode == PolicyMode::Disabled {
        return Ok(None);
    }
    let Some(credit_subject) =
        resolve_credit_subject(context, subject, credit_cfg.allow_header_subject)
    else {
        if credit_cfg.mode == PolicyMode::Required {
            return Err(StatusCode::UNAUTHORIZED);
        }
        warn!("Credit skipped: missing org/project scope");
        return Ok(None);
    };
    let decision = reserve_credit(state, credit_cfg, context, &credit_subject).await;
    let granted = apply_credit_mode(credit_cfg.mode, credit_cfg.fail_open, decision)?;
    Ok(granted.map(|decision| CreditReservation {
        provider: credit_cfg.provider.clone(),
        reservation_id: decision.reservation_id,
    }))
}
fn apply_credit_mode(
mode: PolicyMode,
fail_open: bool,
decision: Result<CreditDecision, StatusCode>,
) -> Result<Option<CreditDecision>, StatusCode> {
match mode {
PolicyMode::Disabled => Ok(None),
PolicyMode::Optional => match decision {
Ok(decision) if decision.allow => Ok(Some(decision)),
Ok(decision) => {
if let Some(reason) = decision.reason {
warn!("Credit denied (optional mode): {}", reason);
}
Ok(None)
}
Err(err) => {
warn!("Credit provider error (optional mode): {}", err);
Ok(None)
}
},
PolicyMode::Required => match decision {
Ok(decision) if decision.allow => Ok(Some(decision)),
Ok(decision) => {
if let Some(reason) = decision.reason {
warn!("Credit denied (required mode): {}", reason);
}
Err(StatusCode::PAYMENT_REQUIRED)
}
Err(err) => {
warn!("Credit provider error (required mode): {}", err);
if fail_open {
Ok(None)
} else {
Err(StatusCode::BAD_GATEWAY)
}
}
},
}
}
/// Asks the route's auth provider whether the request may proceed.
///
/// An empty `reason` from the provider is reported as `None`, and a missing
/// token is sent as an empty string.
///
/// # Errors
/// `500` when the configured provider name is unknown; `502` when the RPC fails.
async fn authorize_request(
    state: &ServerState,
    auth_cfg: &RouteAuthConfig,
    context: &RequestContext,
    token: Option<String>,
) -> Result<AuthDecision, StatusCode> {
    let Some(provider) = state.auth_providers.get(&auth_cfg.provider) else {
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    };
    match provider {
        AuthProvider::Grpc(grpc) => {
            let payload = AuthorizeRequest {
                request_id: context.request_id.clone(),
                token: token.unwrap_or_default(),
                method: context.method.clone(),
                path: context.path.clone(),
                raw_query: context.raw_query.clone(),
                headers: context.headers.clone(),
                client_ip: context.client_ip.clone(),
                route_name: context.route_name.clone(),
            };
            let mut call = TonicRequest::new(payload);
            call.set_timeout(grpc.timeout);
            let mut client = GatewayAuthServiceClient::new(grpc.channel.clone());
            let reply = client
                .authorize(call)
                .await
                .map_err(|_| StatusCode::BAD_GATEWAY)?
                .into_inner();
            let subject = reply.subject.map(|who| SubjectInfo {
                subject_id: who.subject_id,
                org_id: who.org_id,
                project_id: who.project_id,
                roles: who.roles,
                scopes: who.scopes,
            });
            Ok(AuthDecision {
                allow: reply.allow,
                subject,
                headers: reply.headers,
                reason: Some(reply.reason).filter(|reason| !reason.is_empty()),
            })
        }
    }
}
fn resolve_credit_subject(
context: &RequestContext,
subject: Option<&SubjectInfo>,
allow_header_subject: bool,
) -> Option<CreditSubject> {
if let Some(subject) = subject {
return Some(CreditSubject {
subject_id: subject.subject_id.clone(),
org_id: subject.org_id.clone(),
project_id: subject.project_id.clone(),
});
}
if !allow_header_subject {
return None;
}
let org_id = context.headers.get("x-org-id")?.trim();
let project_id = context.headers.get("x-project-id")?.trim();
if org_id.is_empty() || project_id.is_empty() {
return None;
}
let subject_id = context
.headers
.get("x-subject-id")
.map(|value| value.trim().to_string())
.unwrap_or_default();
Some(CreditSubject {
subject_id,
org_id: org_id.to_string(),
project_id: project_id.to_string(),
})
}
/// Asks the route's credit provider to reserve `credit_cfg.units` for the request.
///
/// An empty `reason` from the provider is reported as `None`.
///
/// # Errors
/// `500` when the configured provider name is unknown; `502` when the RPC fails.
async fn reserve_credit(
    state: &ServerState,
    credit_cfg: &RouteCreditConfig,
    context: &RequestContext,
    credit_subject: &CreditSubject,
) -> Result<CreditDecision, StatusCode> {
    let Some(provider) = state.credit_providers.get(&credit_cfg.provider) else {
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    };
    match provider {
        CreditProvider::Grpc(grpc) => {
            let payload = CreditReserveRequest {
                request_id: context.request_id.clone(),
                subject_id: credit_subject.subject_id.clone(),
                org_id: credit_subject.org_id.clone(),
                project_id: credit_subject.project_id.clone(),
                route_name: context.route_name.clone(),
                method: context.method.clone(),
                path: context.path.clone(),
                raw_query: context.raw_query.clone(),
                units: credit_cfg.units,
                attributes: credit_cfg.attributes.clone(),
            };
            let mut call = TonicRequest::new(payload);
            call.set_timeout(grpc.timeout);
            let mut client = GatewayCreditServiceClient::new(grpc.channel.clone());
            let reply = client
                .reserve(call)
                .await
                .map_err(|_| StatusCode::BAD_GATEWAY)?
                .into_inner();
            Ok(CreditDecision {
                allow: reply.allow,
                reservation_id: reply.reservation_id,
                reason: Some(reply.reason).filter(|reason| !reason.is_empty()),
            })
        }
    }
}
/// Settles a reservation after the upstream response was fully received.
///
/// `never` leaves it alone and `always` commits. `success` commits on a 2xx/3xx
/// `status` and rolls back otherwise. Provider failures are logged, not returned,
/// because the response is delivered regardless.
async fn finalize_credit(
    state: &ServerState,
    route: &Route,
    reservation: Option<CreditReservation>,
    status: reqwest::StatusCode,
) {
    let (Some(credit_cfg), Some(reservation)) = (&route.config.credit, reservation) else {
        return;
    };
    let commit = match credit_cfg.commit_on {
        CommitPolicy::Never => return,
        CommitPolicy::Always => true,
        CommitPolicy::Success => status.is_success() || status.is_redirection(),
    };
    if commit {
        if let Err(err) = commit_credit(state, credit_cfg, &reservation).await {
            warn!("Failed to commit credit reservation {}: {}", reservation.reservation_id, err);
        }
    } else if let Err(err) = rollback_credit(state, credit_cfg, &reservation).await {
        warn!(
            "Failed to rollback credit reservation {}: {}",
            reservation.reservation_id, err
        );
    }
}
/// Rolls back a reservation when the response could not be delivered.
///
/// Routes with `commit_on = never` are left untouched. A rollback failure is
/// logged rather than returned, since the caller is already on an error path.
async fn finalize_credit_abort(
    state: &ServerState,
    route: &Route,
    reservation: Option<CreditReservation>,
) {
    let (Some(credit_cfg), Some(reservation)) = (&route.config.credit, reservation) else {
        return;
    };
    if credit_cfg.commit_on != CommitPolicy::Never {
        if let Err(err) = rollback_credit(state, credit_cfg, &reservation).await {
            warn!(
                "Failed to rollback credit reservation {} after delivery failure: {}",
                reservation.reservation_id, err
            );
        }
    }
}
/// Commits `credit_cfg.units` against `reservation` on the provider that issued it.
///
/// # Errors
/// `500` when the provider is unknown; `502` when the RPC fails or the provider
/// reports `success = false`.
async fn commit_credit(
    state: &ServerState,
    credit_cfg: &RouteCreditConfig,
    reservation: &CreditReservation,
) -> Result<(), StatusCode> {
    let Some(provider) = state.credit_providers.get(&reservation.provider) else {
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    };
    match provider {
        CreditProvider::Grpc(grpc) => {
            let mut call = TonicRequest::new(CreditCommitRequest {
                reservation_id: reservation.reservation_id.clone(),
                units: credit_cfg.units,
            });
            call.set_timeout(grpc.timeout);
            let mut client = GatewayCreditServiceClient::new(grpc.channel.clone());
            let committed = client
                .commit(call)
                .await
                .map_err(|_| StatusCode::BAD_GATEWAY)?
                .into_inner()
                .success;
            committed.then_some(()).ok_or(StatusCode::BAD_GATEWAY)
        }
    }
}
/// Releases `reservation` on the provider that issued it.
///
/// The route config is accepted for symmetry with `commit_credit` but is not needed.
///
/// # Errors
/// `500` when the provider is unknown; `502` when the RPC fails or the provider
/// reports `success = false`.
async fn rollback_credit(
    state: &ServerState,
    _credit_cfg: &RouteCreditConfig,
    reservation: &CreditReservation,
) -> Result<(), StatusCode> {
    let Some(provider) = state.credit_providers.get(&reservation.provider) else {
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    };
    match provider {
        CreditProvider::Grpc(grpc) => {
            let mut call = TonicRequest::new(CreditRollbackRequest {
                reservation_id: reservation.reservation_id.clone(),
            });
            call.set_timeout(grpc.timeout);
            let mut client = GatewayCreditServiceClient::new(grpc.channel.clone());
            let released = client
                .rollback(call)
                .await
                .map_err(|_| StatusCode::BAD_GATEWAY)?
                .into_inner()
                .success;
            released.then_some(()).ok_or(StatusCode::BAD_GATEWAY)
        }
    }
}
/// Adds the auth outcome to the upstream request.
///
/// Provider headers are added first, keeping only those that pass
/// `should_forward_auth_header`. Then, when a subject is present, the gateway's
/// own identity headers follow; roles and scopes are comma-joined and omitted
/// when empty.
fn apply_auth_headers(
    builder: reqwest::RequestBuilder,
    outcome: &AuthOutcome,
) -> reqwest::RequestBuilder {
    let builder = outcome
        .headers
        .iter()
        .filter(|(key, _)| should_forward_auth_header(key))
        .fold(builder, |acc, (key, value)| acc.header(key, value));
    let Some(subject) = outcome.subject.as_ref() else {
        return builder;
    };
    let mut builder = builder
        .header("x-subject-id", &subject.subject_id)
        .header("x-org-id", &subject.org_id)
        .header("x-project-id", &subject.project_id);
    if !subject.roles.is_empty() {
        builder = builder.header("x-roles", subject.roles.join(","));
    }
    if !subject.scopes.is_empty() {
        builder = builder.header("x-scopes", subject.scopes.join(","));
    }
    builder
}
/// Builds tonic client TLS settings from a provider's optional `tls` block.
///
/// Returns `Ok(None)` when no block is configured. PEM files are read in this
/// order: CA, then certificate, then key.
///
/// # Errors
/// Fails when a file cannot be read, or when only one of `cert_file` /
/// `key_file` is set.
async fn build_client_tls_config(
    tls: &Option<TlsConfig>,
) -> Result<Option<ClientTlsConfig>, Box<dyn std::error::Error>> {
    let tls = match tls {
        Some(tls) => tls,
        None => return Ok(None),
    };
    let mut settings = ClientTlsConfig::new();
    if let Some(ca_path) = tls.ca_file.as_deref() {
        let ca_pem = tokio::fs::read(ca_path).await?;
        settings = settings.ca_certificate(Certificate::from_pem(ca_pem));
    }
    let identity_pems = match (tls.cert_file.as_deref(), tls.key_file.as_deref()) {
        (None, None) => None,
        (Some(cert_path), Some(key_path)) => Some((
            tokio::fs::read(cert_path).await?,
            tokio::fs::read(key_path).await?,
        )),
        _ => return Err(config_error("tls requires both cert_file and key_file").into()),
    };
    if let Some((cert_pem, key_pem)) = identity_pems {
        settings = settings.identity(Identity::from_pem(cert_pem, key_pem));
    }
    if let Some(domain) = tls.domain_name.as_deref() {
        settings = settings.domain_name(domain);
    }
    Ok(Some(settings))
}
/// Connects to every configured auth provider and indexes them by name.
///
/// # Errors
/// Fails on:
/// - a duplicate name or an unsupported `type`;
/// - an invalid endpoint or TLS setting;
/// - a failed initial connection.
async fn build_auth_providers(
    configs: Vec<AuthProviderConfig>,
) -> Result<HashMap<String, AuthProvider>, Box<dyn std::error::Error>> {
    let mut providers = HashMap::with_capacity(configs.len());
    for config in configs {
        if providers.contains_key(&config.name) {
            return Err(
                config_error(format!("duplicate auth provider name {}", config.name)).into(),
            );
        }
        if normalize_name(&config.provider_type) != "grpc" {
            return Err(config_error(format!(
                "unsupported auth provider type {}",
                config.provider_type
            ))
            .into());
        }
        // A single budget bounds the connection handshake and every RPC deadline.
        let timeout =
            Duration::from_millis(config.timeout_ms.unwrap_or(DEFAULT_AUTH_TIMEOUT_MS));
        let mut endpoint = Endpoint::from_shared(config.endpoint)?
            .connect_timeout(timeout)
            .timeout(timeout);
        if let Some(tls) = build_client_tls_config(&config.tls).await? {
            endpoint = endpoint.tls_config(tls)?;
        }
        let channel = endpoint.connect().await?;
        providers.insert(
            config.name,
            AuthProvider::Grpc(GrpcAuthProvider { channel, timeout }),
        );
    }
    Ok(providers)
}
/// Connects to every configured credit provider and indexes them by name.
///
/// # Errors
/// Fails on:
/// - a duplicate name or an unsupported `type`;
/// - an invalid endpoint or TLS setting;
/// - a failed initial connection.
async fn build_credit_providers(
    configs: Vec<CreditProviderConfig>,
) -> Result<HashMap<String, CreditProvider>, Box<dyn std::error::Error>> {
    let mut providers = HashMap::with_capacity(configs.len());
    for config in configs {
        if providers.contains_key(&config.name) {
            return Err(
                config_error(format!("duplicate credit provider name {}", config.name)).into(),
            );
        }
        if normalize_name(&config.provider_type) != "grpc" {
            return Err(config_error(format!(
                "unsupported credit provider type {}",
                config.provider_type
            ))
            .into());
        }
        // A single budget bounds the connection handshake and every RPC deadline.
        let timeout =
            Duration::from_millis(config.timeout_ms.unwrap_or(DEFAULT_CREDIT_TIMEOUT_MS));
        let mut endpoint = Endpoint::from_shared(config.endpoint)?
            .connect_timeout(timeout)
            .timeout(timeout);
        if let Some(tls) = build_client_tls_config(&config.tls).await? {
            endpoint = endpoint.tls_config(tls)?;
        }
        let channel = endpoint.connect().await?;
        providers.insert(
            config.name,
            CreditProvider::Grpc(GrpcCreditProvider { channel, timeout }),
        );
    }
    Ok(providers)
}
fn build_routes(configs: Vec<RouteConfig>) -> Result<Vec<Route>, Box<dyn std::error::Error>> {
let mut routes = Vec::new();
for mut config in configs {
if config.name.trim().is_empty() {
return Err(config_error("route name is required").into());
}
let path_prefix = normalize_path_prefix(&config.path_prefix);
config.path_prefix = path_prefix;
let upstream = Url::parse(&config.upstream)?;
if upstream.scheme() != "http" && upstream.scheme() != "https" {
return Err(config_error(format!(
"route {} upstream must be http or https",
config.name
))
.into());
}
if upstream.host_str().is_none() {
return Err(config_error(format!(
"route {} upstream must include host",
config.name
))
.into());
}
let upstream_base_path = normalize_upstream_base_path(upstream.path());
routes.push(Route {
config,
upstream,
upstream_base_path,
});
}
routes.sort_by(|a, b| b.config.path_prefix.len().cmp(&a.config.path_prefix.len()));
Ok(routes)
}
/// Wraps a configuration problem as an `InvalidInput` I/O error.
fn config_error(message: impl Into<String>) -> io::Error {
    let text: String = message.into();
    io::Error::new(io::ErrorKind::InvalidInput, text)
}
/// Canonical form of an identifier: trimmed, lowercased, `-` replaced by `_`.
fn normalize_name(value: &str) -> String {
    value
        .trim()
        .to_lowercase()
        .chars()
        .map(|ch| if ch == '-' { '_' } else { ch })
        .collect()
}
/// Returns the client's `x-request-id`, or a fresh UUIDv4 when none is usable.
///
/// A header that is missing, not valid header text, or blank after trimming is
/// treated as absent. Otherwise every request sharing an empty header would be
/// correlated under the same empty id.
fn extract_request_id(headers: &HeaderMap) -> String {
    headers
        .get(DEFAULT_REQUEST_ID_HEADER)
        .and_then(|value| value.to_str().ok())
        .map(str::trim)
        .filter(|value| !value.is_empty())
        .map(str::to_owned)
        .unwrap_or_else(|| Uuid::new_v4().to_string())
}
/// Determines the client address reported to providers.
///
/// When `trust_forwarded_headers` is set, the address comes from:
/// 1. the first entry of `x-forwarded-for`;
/// 2. otherwise `x-real-ip`;
/// 3. otherwise the socket peer.
///
/// Blank header values are skipped. Without the flag, the peer address is always used.
fn extract_client_ip(
    headers: &HeaderMap,
    remote_addr: SocketAddr,
    trust_forwarded_headers: bool,
) -> String {
    let peer = || remote_addr.ip().to_string();
    if !trust_forwarded_headers {
        return peer();
    }
    let first_forwarded = headers
        .get("x-forwarded-for")
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.split(',').next())
        .map(str::trim)
        .filter(|value| !value.is_empty());
    let real_ip = || {
        headers
            .get("x-real-ip")
            .and_then(|value| value.to_str().ok())
            .map(str::trim)
            .filter(|value| !value.is_empty())
    };
    first_forwarded
        .or_else(real_ip)
        .map(str::to_owned)
        .unwrap_or_else(peer)
}
fn headers_to_map(headers: &HeaderMap) -> HashMap<String, String> {
let mut map: HashMap<String, String> = HashMap::new();
for (name, value) in headers.iter() {
if let Ok(value) = value.to_str() {
map.entry(name.as_str().to_string())
.and_modify(|entry| {
entry.push(',');
entry.push_str(value);
})
.or_insert_with(|| value.to_string());
}
}
map
}
/// Finds the client's auth token.
///
/// `Authorization` is tried first, then `x-photon-auth-token`. The first header
/// that `parse_auth_token_value` accepts wins.
fn extract_auth_token(headers: &HeaderMap) -> Option<String> {
    [axum::http::header::AUTHORIZATION.as_str(), PHOTON_AUTH_TOKEN_HEADER]
        .into_iter()
        .find_map(|name| {
            headers
                .get(name)
                .and_then(|value| value.to_str().ok())
                .and_then(parse_auth_token_value)
        })
}
/// `HeaderName` convenience wrapper around `is_reserved_auth_header_str`.
fn is_reserved_auth_header(name: &axum::http::header::HeaderName) -> bool {
    is_reserved_auth_header_str(name.as_ref())
}
/// Case-insensitive membership test against `RESERVED_AUTH_HEADERS`.
fn is_reserved_auth_header_str(name: &str) -> bool {
    RESERVED_AUTH_HEADERS
        .iter()
        .any(|reserved| reserved.eq_ignore_ascii_case(name))
}
/// Whether a header returned by an auth provider may be set upstream.
///
/// Only `x-` headers (case-insensitive) qualify, and never those listed in
/// `AUTH_PROVIDER_BLOCK_HEADERS`.
fn should_forward_auth_header(name: &str) -> bool {
    let lowered = name.to_ascii_lowercase();
    let blocked = AUTH_PROVIDER_BLOCK_HEADERS.contains(&lowered.as_str());
    !blocked && lowered.starts_with("x-")
}
/// Whether a reserved header carries raw client credentials (`authorization` or
/// `x-photon-auth-token`, case-insensitive). Such headers may pass through on
/// routes without auth.
fn should_preserve_client_auth_header(name: &str) -> bool {
    ["authorization", PHOTON_AUTH_TOKEN_HEADER]
        .iter()
        .any(|credential| name.eq_ignore_ascii_case(credential))
}
/// Extracts a token from an auth header value.
///
/// Accepts `Bearer <token>` (scheme case-insensitive). For legacy clients it also
/// accepts a bare single-word value without a scheme. Returns `None` for blank
/// values, for other schemes (`Basic abc`), and for a lone `Bearer` whose token
/// is missing. That last case would otherwise be taken as the legacy token
/// `"Bearer"`.
fn parse_auth_token_value(value: &str) -> Option<String> {
    let trimmed = value.trim();
    if trimmed.is_empty() {
        return None;
    }
    if let Some(token) = parse_bearer_token(trimmed) {
        return Some(token);
    }
    // Legacy support: allow raw token values without a scheme.
    let mut words = trimmed.split_whitespace();
    match (words.next(), words.next()) {
        (Some(word), None) if !word.eq_ignore_ascii_case("bearer") => Some(word.to_string()),
        _ => None,
    }
}
/// Parses exactly `Bearer <token>`: two whitespace-separated words with a
/// case-insensitive `bearer` scheme. Anything else yields `None`.
fn parse_bearer_token(value: &str) -> Option<String> {
    let mut words = value.split_whitespace();
    match (words.next(), words.next(), words.next()) {
        (Some(scheme), Some(token), None) if scheme.eq_ignore_ascii_case("bearer") => {
            Some(token.to_owned())
        }
        _ => None,
    }
}
/// Normalizes a configured route prefix to `/segment[/segment...]` form.
///
/// Surrounding whitespace and *all* trailing slashes are removed, and a leading
/// `/` is added when missing. Blank or slash-only input becomes `/` (match
/// everything). All trailing slashes must go: a prefix like `/api/` can never
/// satisfy the segment-aligned match in `path_matches_prefix`, except on the
/// exact path `/api/`.
fn normalize_path_prefix(prefix: &str) -> String {
    let trimmed = prefix.trim().trim_end_matches('/');
    if trimmed.is_empty() {
        return "/".to_string();
    }
    if trimmed.starts_with('/') {
        trimmed.to_string()
    } else {
        format!("/{}", trimmed)
    }
}
/// Normalizes an upstream URL's path for use as a join base.
///
/// Whitespace and trailing slashes are removed. A path that is empty, `/`, or
/// consists only of slashes (e.g. `//`) becomes `/`, never the empty string.
fn normalize_upstream_base_path(path: &str) -> String {
    let trimmed = path.trim().trim_end_matches('/');
    if trimmed.is_empty() {
        "/".to_string()
    } else {
        trimmed.to_string()
    }
}
/// Returns the first route whose prefix covers `path`.
///
/// The slice is expected to be ordered longest prefix first, as `build_routes`
/// produces it, so the first match is the most specific one.
fn match_route<'a>(routes: &'a [Route], path: &str) -> Option<&'a Route> {
    for route in routes {
        if path_matches_prefix(path, &route.config.path_prefix) {
            return Some(route);
        }
    }
    None
}
/// Segment-aligned prefix test.
///
/// `/` matches every path. Any other prefix matches the identical path, or a
/// path that continues with `/` right after the prefix (`/api` matches
/// `/api/x` but not `/apix`).
fn path_matches_prefix(path: &str, prefix: &str) -> bool {
    prefix == "/"
        || path
            .strip_prefix(prefix)
            .map_or(false, |rest| rest.is_empty() || rest.starts_with('/'))
}
/// Removes `prefix` from `path`, keeping the result rooted at `/`.
///
/// A `/` prefix, or a prefix that does not match, leaves the path unchanged.
/// Stripping the whole path yields `/`.
fn strip_prefix_path(path: &str, prefix: &str) -> String {
    let rest = match path.strip_prefix(prefix) {
        Some(rest) if prefix != "/" => rest,
        _ => return path.to_string(),
    };
    if rest.starts_with('/') {
        rest.to_string()
    } else {
        format!("/{}", rest)
    }
}
/// Joins an upstream base path and a request path with exactly one `/` between them.
///
/// A `/` base returns the request path as-is. A `/` request path returns the base
/// without trailing slashes (or `/` if nothing is left).
fn join_paths(base: &str, path: &str) -> String {
    match (base, path) {
        ("/", _) => path.to_string(),
        (_, "/") => {
            let head = base.trim_end_matches('/');
            if head.is_empty() {
                "/".to_string()
            } else {
                head.to_string()
            }
        }
        _ => format!(
            "{}/{}",
            base.trim_end_matches('/'),
            path.trim_start_matches('/')
        ),
    }
}
fn build_upstream_url(route: &Route, uri: &Uri) -> Result<Url, StatusCode> {
let path = if route.config.strip_prefix {
strip_prefix_path(uri.path(), &route.config.path_prefix)
} else {
uri.path().to_string()
};
let merged_path = join_paths(&route.upstream_base_path, &path);
let mut url = route.upstream.clone();
url.set_path(&merged_path);
url.set_query(uri.query());
Ok(url)
}
// Tests for the path/route helpers and the auth/credit policy-mode helpers,
// plus end-to-end `proxy` tests. The end-to-end tests spin up in-process
// upstream, IAM and credit servers on ephemeral 127.0.0.1 ports.
#[cfg(test)]
mod tests {
    use super::*;
    use axum::routing::get;
    use creditservice_api::{CreditServiceImpl, CreditStorage, GatewayCreditServiceImpl};
    use apigateway_api::GatewayCreditServiceServer;
    use creditservice_types::Wallet;
    use iam_api::{GatewayAuthServiceImpl, GatewayAuthServiceServer};
    use iam_authn::{InternalTokenConfig, InternalTokenService, SigningKey};
    use iam_authz::{PolicyCache, PolicyEvaluator};
    use iam_store::{Backend, BackendConfig, BindingStore, PrincipalStore, RoleStore, TokenStore};
    use iam_types::{Permission, PolicyBinding, Principal, PrincipalRef, Role, Scope};
    use tokio_stream::wrappers::TcpListenerStream;
    use tonic::transport::Server;
    use uuid::Uuid;

    /// Tries a TCP connect to `addr` every 25 ms until one succeeds.
    /// Panics if the listener is still unreachable after 2 seconds.
    async fn wait_for_test_tcp(addr: SocketAddr) {
        let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
        loop {
            if tokio::net::TcpStream::connect(addr).await.is_ok() {
                return;
            }
            assert!(
                tokio::time::Instant::now() < deadline,
                "timed out waiting for test listener {}",
                addr
            );
            tokio::time::sleep(Duration::from_millis(25)).await;
        }
    }

    /// Builds a `RouteConfig` with no per-route timeout, auth or credit policy.
    fn route_config(name: &str, prefix: &str, upstream: &str, strip_prefix: bool) -> RouteConfig {
        RouteConfig {
            name: name.to_string(),
            path_prefix: prefix.to_string(),
            upstream: upstream.to_string(),
            strip_prefix,
            timeout_ms: None,
            auth: None,
            credit: None,
        }
    }

    /// Builds an `AuthDecision` with the given verdict and no subject,
    /// headers or reason.
    fn auth_decision(allow: bool) -> AuthDecision {
        AuthDecision {
            allow,
            subject: None,
            headers: HashMap::new(),
            reason: None,
        }
    }

    /// Builds a `CreditDecision` with the given verdict and a fixed
    /// reservation id of `"resv"`.
    fn credit_decision(allow: bool) -> CreditDecision {
        CreditDecision {
            allow,
            reservation_id: "resv".to_string(),
            reason: None,
        }
    }

    /// Starts a plain HTTP upstream and returns its address once it accepts
    /// connections.
    ///
    /// - `GET /v1/echo` answers `"ok"`.
    /// - `GET /v1/echo-auth` returns JSON with the received `Authorization` and
    ///   `x-photon-auth-token` values (`null` when a header is absent or not
    ///   valid visible ASCII).
    async fn start_upstream() -> SocketAddr {
        let app = Router::new()
            .route("/v1/echo", get(|| async { "ok" }))
            .route(
                "/v1/echo-auth",
                get(|headers: HeaderMap| async move {
                    Json(serde_json::json!({
                        "authorization": headers
                            .get(axum::http::header::AUTHORIZATION)
                            .and_then(|value| value.to_str().ok()),
                        "photon_token": headers
                            .get(PHOTON_AUTH_TOKEN_HEADER)
                            .and_then(|value| value.to_str().ok()),
                    }))
                }),
            );
        // Port 0 lets the OS pick a free port so tests can run in parallel.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind upstream");
        let addr = listener.local_addr().expect("upstream addr");
        tokio::spawn(async move {
            axum::serve(listener, app).await.expect("upstream serve");
        });
        wait_for_test_tcp(addr).await;
        addr
    }

    /// Starts an IAM `GatewayAuthService` gRPC server backed by in-memory
    /// stores.
    ///
    /// The stores are seeded with user `user-1` (org `org-1`, project
    /// `proj-1`). That user is bound to a `GatewayReader` role granting
    /// `gateway:public:read` (resource `"*"`) in the proj-1 scope.
    ///
    /// Returns the server address and a token issued for that user.
    async fn start_iam_gateway() -> (SocketAddr, String) {
        // All four stores share one in-memory backend.
        let backend = Arc::new(Backend::new(BackendConfig::Memory).await.expect("iam backend"));
        let principal_store = Arc::new(PrincipalStore::new(backend.clone()));
        let role_store = Arc::new(RoleStore::new(backend.clone()));
        let binding_store = Arc::new(BindingStore::new(backend.clone()));
        let token_store = Arc::new(TokenStore::new(backend));
        let signing_key = SigningKey::generate("gateway-test-key");
        let token_config = InternalTokenConfig::new(signing_key, "iam-gateway-test");
        let token_service = Arc::new(InternalTokenService::new(token_config));
        let mut principal = Principal::new_user("user-1", "User One");
        principal.org_id = Some("org-1".into());
        principal.project_id = Some("proj-1".into());
        principal_store
            .create(&principal)
            .await
            .expect("principal create");
        let role = Role::new(
            "GatewayReader",
            Scope::project("proj-1", "org-1"),
            vec![Permission::new("gateway:public:read", "*")],
        );
        role_store.create(&role).await.expect("role create");
        // Random binding id so repeated test runs never collide.
        let binding = PolicyBinding::new(
            format!("binding-{}", Uuid::new_v4()),
            PrincipalRef::new(principal.kind.clone(), principal.id.clone()),
            role.to_ref(),
            Scope::project("proj-1", "org-1"),
        );
        binding_store
            .create(&binding)
            .await
            .expect("binding create");
        // This token is what the gateway tests send in the Authorization header.
        let issued = token_service
            .issue(&principal, vec![], Scope::project("proj-1", "org-1"), None)
            .await
            .expect("issue token");
        let cache = Arc::new(PolicyCache::default_config());
        let evaluator = Arc::new(PolicyEvaluator::new(
            binding_store.clone(),
            role_store.clone(),
            cache,
        ));
        let gateway_auth = GatewayAuthServiceImpl::new(
            token_service,
            principal_store,
            token_store,
            evaluator,
        );
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind iam");
        let addr = listener.local_addr().expect("iam addr");
        tokio::spawn(async move {
            Server::builder()
                .add_service(GatewayAuthServiceServer::new(gateway_auth))
                .serve_with_incoming(TcpListenerStream::new(listener))
                .await
                .expect("iam gateway serve");
        });
        wait_for_test_tcp(addr).await;
        (addr, issued.token)
    }

    /// Starts a `GatewayCreditService` gRPC server and returns its address.
    ///
    /// The server uses in-memory storage seeded with one wallet for
    /// proj-1/org-1. It is wired to an IAM `AuthService` client pointed at
    /// `iam_addr`.
    async fn start_credit_gateway(iam_addr: &SocketAddr) -> SocketAddr {
        let storage = creditservice_api::InMemoryStorage::new();
        // NOTE(review): the third argument (100) is presumably the initial
        // balance — confirm against `Wallet::new`.
        let wallet = Wallet::new("proj-1".into(), "org-1".into(), 100);
        storage
            .create_wallet(wallet)
            .await
            .expect("wallet create");
        let auth_service = Arc::new(
            iam_service_auth::AuthService::new(&format!("http://{}", iam_addr))
                .await
                .expect("auth service"),
        );
        let credit_service = Arc::new(CreditServiceImpl::new(storage, auth_service));
        let gateway_credit = GatewayCreditServiceImpl::new(credit_service);
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind credit");
        let addr = listener.local_addr().expect("credit addr");
        tokio::spawn(async move {
            Server::builder()
                .add_service(GatewayCreditServiceServer::new(gateway_credit))
                .serve_with_incoming(TcpListenerStream::new(listener))
                .await
                .expect("credit gateway serve");
        });
        wait_for_test_tcp(addr).await;
        addr
    }

    /// Prefixes gain a leading slash, lose trailing slashes, and the empty
    /// prefix becomes the root.
    #[test]
    fn test_normalize_path_prefix() {
        assert_eq!(normalize_path_prefix(""), "/");
        assert_eq!(normalize_path_prefix("/"), "/");
        assert_eq!(normalize_path_prefix("api"), "/api");
        assert_eq!(normalize_path_prefix("/api/"), "/api");
    }

    /// Stripping keeps the remainder rooted at `/` and preserves trailing
    /// slashes; the root prefix strips nothing.
    #[test]
    fn test_strip_prefix_path() {
        assert_eq!(strip_prefix_path("/api", "/api"), "/");
        assert_eq!(strip_prefix_path("/api/v1", "/api"), "/v1");
        assert_eq!(strip_prefix_path("/api/v1/", "/api"), "/v1/");
        assert_eq!(strip_prefix_path("/v1", "/"), "/v1");
    }

    /// Joining a base and a path never yields a doubled or dangling slash.
    #[test]
    fn test_join_paths() {
        assert_eq!(join_paths("/", "/v1"), "/v1");
        assert_eq!(join_paths("/api", "/v1"), "/api/v1");
        assert_eq!(join_paths("/api/", "/"), "/api");
    }

    /// The more specific `/api/v1` route wins over `/api`.
    ///
    /// `match_route` returns the first matching route in slice order, so this
    /// relies on the ordering `build_routes` produces (presumably longest
    /// prefix first).
    #[test]
    fn test_match_route_longest_prefix() {
        let routes = build_routes(vec![
            route_config("api", "/api", "http://example.com", false),
            route_config("api-v1", "/api/v1", "http://example.com", false),
        ])
        .unwrap();
        let matched = match_route(&routes, "/api/v1/users").unwrap();
        assert_eq!(matched.config.name, "api-v1");
    }

    /// Prefixes only match on whole path segments: `/api` must not capture
    /// `/api2` or `/apiary`.
    #[test]
    fn test_match_route_segment_boundary() {
        let routes = build_routes(vec![
            route_config("api", "/api", "http://example.com", false),
            route_config("api2", "/api2", "http://example.com", false),
        ])
        .unwrap();
        let matched = match_route(&routes, "/api2").unwrap();
        assert_eq!(matched.config.name, "api2");
        let matched = match_route(&routes, "/api2/health").unwrap();
        assert_eq!(matched.config.name, "api2");
        assert!(match_route(&routes, "/apiary").is_none());
    }

    /// Without stripping, the full request path goes under the upstream base
    /// path and the query string is carried over.
    #[test]
    fn test_build_upstream_url_preserves_query() {
        let routes = build_routes(vec![route_config(
            "api",
            "/api",
            "http://example.com/base",
            false,
        )])
        .unwrap();
        let route = routes.first().unwrap();
        let uri: Uri = "/api/v1/users?debug=true".parse().unwrap();
        let url = build_upstream_url(route, &uri).unwrap();
        assert_eq!(url.as_str(), "http://example.com/base/api/v1/users?debug=true");
    }

    /// With `strip_prefix`, the route prefix is removed before the path is
    /// joined onto the upstream base path.
    #[test]
    fn test_build_upstream_url_strip_prefix() {
        let routes = build_routes(vec![route_config(
            "api",
            "/api",
            "http://example.com/base",
            true,
        )])
        .unwrap();
        let route = routes.first().unwrap();
        let uri: Uri = "/api/v1".parse().unwrap();
        let url = build_upstream_url(route, &uri).unwrap();
        assert_eq!(url.as_str(), "http://example.com/base/v1");
    }

    /// Required auth passes an allow decision through and maps a deny to 403.
    #[test]
    fn test_apply_auth_mode_required() {
        let decision = Ok(auth_decision(true));
        let outcome = apply_auth_mode(PolicyMode::Required, false, decision).unwrap();
        assert!(outcome.subject.is_none());
        let decision = Ok(auth_decision(false));
        let err = apply_auth_mode(PolicyMode::Required, false, decision).unwrap_err();
        assert_eq!(err, StatusCode::FORBIDDEN);
    }

    /// Optional auth lets the request through on a deny and on a provider
    /// error, with no subject attached.
    #[test]
    fn test_apply_auth_mode_optional() {
        let decision = Ok(auth_decision(false));
        let outcome = apply_auth_mode(PolicyMode::Optional, false, decision).unwrap();
        assert!(outcome.subject.is_none());
        let outcome = apply_auth_mode(PolicyMode::Optional, false, Err(StatusCode::BAD_GATEWAY)).unwrap();
        assert!(outcome.subject.is_none());
    }

    /// Required credit keeps the reservation on allow and maps a deny to 402.
    #[test]
    fn test_apply_credit_mode_required() {
        let decision = Ok(credit_decision(true));
        let outcome = apply_credit_mode(PolicyMode::Required, false, decision).unwrap();
        assert!(outcome.is_some());
        let decision = Ok(credit_decision(false));
        let err = apply_credit_mode(PolicyMode::Required, false, decision).unwrap_err();
        assert_eq!(err, StatusCode::PAYMENT_REQUIRED);
    }

    /// Optional credit proceeds without a reservation on a deny and on a
    /// provider error.
    #[test]
    fn test_apply_credit_mode_optional() {
        let decision = Ok(credit_decision(false));
        let outcome = apply_credit_mode(PolicyMode::Optional, false, decision).unwrap();
        assert!(outcome.is_none());
        let outcome = apply_credit_mode(PolicyMode::Optional, false, Err(StatusCode::BAD_GATEWAY)).unwrap();
        assert!(outcome.is_none());
    }

    /// End-to-end: a route with required IAM auth and required credit
    /// (1 unit, commit on success) proxies `GET /v1/echo` to the upstream
    /// when called with a valid token.
    #[tokio::test]
    async fn test_gateway_auth_and_credit_flow() {
        let upstream_addr = start_upstream().await;
        let (iam_addr, token) = start_iam_gateway().await;
        let credit_addr = start_credit_gateway(&iam_addr).await;
        let routes = build_routes(vec![RouteConfig {
            name: "public".to_string(),
            path_prefix: "/v1".to_string(),
            upstream: format!("http://{}", upstream_addr),
            strip_prefix: false,
            timeout_ms: None,
            auth: Some(RouteAuthConfig {
                provider: "iam".to_string(),
                mode: PolicyMode::Required,
                fail_open: false,
            }),
            credit: Some(RouteCreditConfig {
                provider: "credit".to_string(),
                mode: PolicyMode::Required,
                units: 1,
                fail_open: false,
                commit_on: CommitPolicy::Success,
                allow_header_subject: false,
                attributes: HashMap::new(),
            }),
        }])
        .unwrap();
        let auth_providers = build_auth_providers(vec![AuthProviderConfig {
            name: "iam".to_string(),
            provider_type: "grpc".to_string(),
            endpoint: format!("http://{}", iam_addr),
            timeout_ms: Some(1000),
            tls: None,
        }])
        .await
        .unwrap();
        let credit_providers = build_credit_providers(vec![CreditProviderConfig {
            name: "credit".to_string(),
            provider_type: "grpc".to_string(),
            endpoint: format!("http://{}", credit_addr),
            timeout_ms: Some(1000),
            tls: None,
        }])
        .await
        .unwrap();
        let state = Arc::new(ServerState {
            routes,
            client: Client::new(),
            upstream_timeout: Duration::from_secs(5),
            max_body_bytes: 1024 * 1024,
            max_response_bytes: 1024 * 1024,
            auth_providers,
            credit_providers,
            trust_forwarded_headers: false,
        });
        // The backends may not be ready right away. Retry on 502 for up to
        // 10 s; any other error status fails the test immediately.
        let deadline = tokio::time::Instant::now() + Duration::from_secs(10);
        let mut response = None;
        while tokio::time::Instant::now() < deadline {
            let request = Request::builder()
                .method("GET")
                .uri("/v1/echo")
                .header(axum::http::header::AUTHORIZATION, &token)
                .body(Body::empty())
                .expect("request build");
            match proxy(
                State(Arc::clone(&state)),
                ConnectInfo("127.0.0.1:40000".parse().unwrap()),
                request,
            )
            .await
            {
                Ok(ok) => {
                    response = Some(ok);
                    break;
                }
                Err(StatusCode::BAD_GATEWAY) => {
                    tokio::time::sleep(Duration::from_millis(25)).await;
                }
                Err(status) => panic!("unexpected proxy status: {}", status),
            }
        }
        let response = response.expect("gateway auth+credit test timed out waiting for ready backends");
        assert_eq!(response.status(), StatusCode::OK);
    }

    /// A route with no auth policy forwards the client's `Authorization` and
    /// `x-photon-auth-token` headers to the upstream unchanged.
    #[tokio::test]
    async fn test_proxy_forwards_client_auth_headers_when_route_has_no_auth() {
        let upstream_addr = start_upstream().await;
        let routes = build_routes(vec![route_config(
            "passthrough",
            "/v1",
            &format!("http://{}", upstream_addr),
            false,
        )])
        .unwrap();
        let state = Arc::new(ServerState {
            routes,
            client: Client::new(),
            upstream_timeout: Duration::from_secs(5),
            max_body_bytes: 1024 * 1024,
            max_response_bytes: 1024 * 1024,
            auth_providers: HashMap::new(),
            credit_providers: HashMap::new(),
            trust_forwarded_headers: false,
        });
        let request = Request::builder()
            .method("GET")
            .uri("/v1/echo-auth")
            .header(axum::http::header::AUTHORIZATION, "Bearer passthrough-token")
            .header(PHOTON_AUTH_TOKEN_HEADER, "photon-token")
            .body(Body::empty())
            .expect("request build");
        let response = proxy(
            State(state),
            ConnectInfo("127.0.0.1:40000".parse().unwrap()),
            request,
        )
        .await
        .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        // The upstream echoes the headers it received as JSON.
        let body = to_bytes(response.into_body(), 1024 * 1024).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json.get("authorization").and_then(|v| v.as_str()), Some("Bearer passthrough-token"));
        assert_eq!(json.get("photon_token").and_then(|v| v.as_str()), Some("photon-token"));
    }

    /// A `Bearer` Authorization header yields the token without the scheme.
    #[test]
    fn test_extract_auth_token_accepts_bearer_authorization() {
        let mut headers = HeaderMap::new();
        headers.insert(
            axum::http::header::AUTHORIZATION,
            "Bearer abc123".parse().unwrap(),
        );
        assert_eq!(extract_auth_token(&headers).as_deref(), Some("abc123"));
    }

    /// A scheme-less Authorization value is accepted as the raw token.
    #[test]
    fn test_extract_auth_token_accepts_legacy_raw_authorization() {
        let mut headers = HeaderMap::new();
        headers.insert(
            axum::http::header::AUTHORIZATION,
            "raw-token".parse().unwrap(),
        );
        assert_eq!(extract_auth_token(&headers).as_deref(), Some("raw-token"));
    }

    /// A non-bearer scheme (`Basic`) in Authorization is not used as the
    /// token; `x-photon-auth-token` is used instead.
    #[test]
    fn test_extract_auth_token_falls_back_to_photon_header() {
        let mut headers = HeaderMap::new();
        headers.insert(
            axum::http::header::AUTHORIZATION,
            "Basic abc".parse().unwrap(),
        );
        headers.insert(PHOTON_AUTH_TOKEN_HEADER, "photon-token".parse().unwrap());
        assert_eq!(
            extract_auth_token(&headers).as_deref(),
            Some("photon-token")
        );
    }
}