1871 lines
56 KiB
Rust
1871 lines
56 KiB
Rust
use std::collections::HashMap;
|
|
use std::io;
|
|
use std::net::SocketAddr;
|
|
use std::path::PathBuf;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
|
|
use apigateway_api::proto::{
|
|
AuthorizeRequest, CreditCommitRequest, CreditReserveRequest, CreditRollbackRequest,
|
|
};
|
|
use apigateway_api::{GatewayAuthServiceClient, GatewayCreditServiceClient};
|
|
use axum::{
|
|
body::{to_bytes, Body},
|
|
extract::{ConnectInfo, State},
|
|
http::{HeaderMap, Request, StatusCode, Uri},
|
|
response::Response,
|
|
routing::{any, get},
|
|
Json, Router,
|
|
};
|
|
use clap::Parser;
|
|
use reqwest::{Client, Url};
|
|
use serde::{Deserialize, Serialize};
|
|
use tonic::transport::{Certificate, Channel, ClientTlsConfig, Endpoint, Identity};
|
|
use tonic::Request as TonicRequest;
|
|
use tracing::{info, warn};
|
|
use tracing_subscriber::EnvFilter;
|
|
use uuid::Uuid;
|
|
|
|
/// Header read for an incoming request id and set on every upstream request.
const DEFAULT_REQUEST_ID_HEADER: &str = "x-request-id";
|
|
/// Secondary header consulted for a client token when `Authorization` yields none.
const PHOTON_AUTH_TOKEN_HEADER: &str = "x-photon-auth-token";
|
|
/// Timeout in milliseconds for auth providers that do not set `timeout_ms`.
const DEFAULT_AUTH_TIMEOUT_MS: u64 = 500;
|
|
/// Timeout in milliseconds for credit providers that do not set `timeout_ms`.
const DEFAULT_CREDIT_TIMEOUT_MS: u64 = 500;
|
|
/// Default for `ServerConfig::upstream_timeout_ms`, in milliseconds.
const DEFAULT_UPSTREAM_TIMEOUT_MS: u64 = 10_000;
|
|
/// Lower-case names of identity-bearing request headers.
///
/// `proxy` does not copy these from the client request to the upstream
/// request unless the matched route has no `auth` section and
/// `should_preserve_client_auth_header` accepts the name.
const RESERVED_AUTH_HEADERS: [&str; 10] = [
    // Client credentials.
    "authorization",
    "x-photon-auth-token",
    // Subject identity headers written by `apply_auth_headers`.
    "x-subject-id",
    "x-org-id",
    "x-project-id",
    "x-roles",
    "x-scopes",
    // IAM session metadata.
    "x-iam-session-id",
    "x-iam-principal-kind",
    "x-iam-auth-method",
];
|
|
/// Lower-case header names consulted by `should_forward_auth_header`.
///
/// NOTE(review): presumably the headers an auth provider is not allowed to
/// inject into the upstream request — confirm against that function's body.
const AUTH_PROVIDER_BLOCK_HEADERS: [&str; 17] = [
    // Credentials and subject identity.
    "authorization",
    "x-photon-auth-token",
    "x-subject-id",
    "x-org-id",
    "x-project-id",
    "x-roles",
    "x-scopes",
    "proxy-authorization",
    // Cookies.
    "cookie",
    "set-cookie",
    // Routing and connection-level headers.
    "host",
    "connection",
    "upgrade",
    "keep-alive",
    "te",
    "trailer",
    "transfer-encoding",
];
|
|
|
|
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
|
|
#[serde(rename_all = "snake_case")]
|
|
enum PolicyMode {
|
|
Disabled,
|
|
Optional,
|
|
Required,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
|
|
#[serde(rename_all = "snake_case")]
|
|
enum CommitPolicy {
|
|
Success,
|
|
Always,
|
|
Never,
|
|
}
|
|
|
|
fn default_policy_mode() -> PolicyMode {
|
|
PolicyMode::Required
|
|
}
|
|
|
|
fn default_commit_policy() -> CommitPolicy {
|
|
CommitPolicy::Success
|
|
}
|
|
|
|
/// Serde default for `RouteCreditConfig::units`: one unit per request.
fn default_credit_units() -> u64 {
    1_u64
}
|
|
|
|
fn default_upstream_timeout_ms() -> u64 {
|
|
DEFAULT_UPSTREAM_TIMEOUT_MS
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
struct TlsConfig {
|
|
#[serde(default)]
|
|
ca_file: Option<String>,
|
|
#[serde(default)]
|
|
cert_file: Option<String>,
|
|
#[serde(default)]
|
|
key_file: Option<String>,
|
|
#[serde(default)]
|
|
domain_name: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
struct AuthProviderConfig {
|
|
name: String,
|
|
#[serde(rename = "type")]
|
|
provider_type: String,
|
|
endpoint: String,
|
|
#[serde(default)]
|
|
timeout_ms: Option<u64>,
|
|
#[serde(default)]
|
|
tls: Option<TlsConfig>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
struct CreditProviderConfig {
|
|
name: String,
|
|
#[serde(rename = "type")]
|
|
provider_type: String,
|
|
endpoint: String,
|
|
#[serde(default)]
|
|
timeout_ms: Option<u64>,
|
|
#[serde(default)]
|
|
tls: Option<TlsConfig>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
struct RouteAuthConfig {
|
|
provider: String,
|
|
#[serde(default = "default_policy_mode")]
|
|
mode: PolicyMode,
|
|
#[serde(default)]
|
|
fail_open: bool,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
struct RouteCreditConfig {
|
|
provider: String,
|
|
#[serde(default = "default_policy_mode")]
|
|
mode: PolicyMode,
|
|
#[serde(default = "default_credit_units")]
|
|
units: u64,
|
|
#[serde(default)]
|
|
fail_open: bool,
|
|
#[serde(default = "default_commit_policy")]
|
|
commit_on: CommitPolicy,
|
|
#[serde(default)]
|
|
allow_header_subject: bool,
|
|
#[serde(default)]
|
|
attributes: HashMap<String, String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
struct RouteConfig {
|
|
name: String,
|
|
path_prefix: String,
|
|
upstream: String,
|
|
#[serde(default)]
|
|
strip_prefix: bool,
|
|
#[serde(default)]
|
|
timeout_ms: Option<u64>,
|
|
#[serde(default)]
|
|
auth: Option<RouteAuthConfig>,
|
|
#[serde(default)]
|
|
credit: Option<RouteCreditConfig>,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
struct Route {
|
|
config: RouteConfig,
|
|
upstream: Url,
|
|
upstream_base_path: String,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
struct ServerConfig {
|
|
#[serde(default = "default_http_addr")]
|
|
http_addr: SocketAddr,
|
|
#[serde(default = "default_log_level")]
|
|
log_level: String,
|
|
#[serde(default = "default_max_body_bytes")]
|
|
max_body_bytes: usize,
|
|
#[serde(default = "default_max_response_bytes")]
|
|
max_response_bytes: usize,
|
|
#[serde(default = "default_upstream_timeout_ms")]
|
|
upstream_timeout_ms: u64,
|
|
#[serde(default)]
|
|
trust_forwarded_headers: bool,
|
|
#[serde(default)]
|
|
auth_providers: Vec<AuthProviderConfig>,
|
|
#[serde(default)]
|
|
credit_providers: Vec<CreditProviderConfig>,
|
|
#[serde(default)]
|
|
routes: Vec<RouteConfig>,
|
|
}
|
|
|
|
impl Default for ServerConfig {
|
|
fn default() -> Self {
|
|
Self {
|
|
http_addr: default_http_addr(),
|
|
log_level: default_log_level(),
|
|
max_body_bytes: default_max_body_bytes(),
|
|
max_response_bytes: default_max_response_bytes(),
|
|
upstream_timeout_ms: default_upstream_timeout_ms(),
|
|
trust_forwarded_headers: false,
|
|
auth_providers: Vec::new(),
|
|
credit_providers: Vec::new(),
|
|
routes: Vec::new(),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Parser)]
|
|
#[command(author, version, about)]
|
|
struct Args {
|
|
/// Configuration file path
|
|
#[arg(short, long, default_value = "apigateway.toml")]
|
|
config: PathBuf,
|
|
|
|
/// HTTP listen address (overrides config)
|
|
#[arg(long)]
|
|
http_addr: Option<String>,
|
|
|
|
/// Log level (overrides config)
|
|
#[arg(short, long)]
|
|
log_level: Option<String>,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
struct ServerState {
|
|
routes: Vec<Route>,
|
|
client: Client,
|
|
upstream_timeout: Duration,
|
|
max_body_bytes: usize,
|
|
max_response_bytes: usize,
|
|
auth_providers: HashMap<String, AuthProvider>,
|
|
credit_providers: HashMap<String, CreditProvider>,
|
|
trust_forwarded_headers: bool,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
struct GrpcAuthProvider {
|
|
channel: Channel,
|
|
timeout: Duration,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
enum AuthProvider {
|
|
Grpc(GrpcAuthProvider),
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
struct GrpcCreditProvider {
|
|
channel: Channel,
|
|
timeout: Duration,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
enum CreditProvider {
|
|
Grpc(GrpcCreditProvider),
|
|
}
|
|
|
|
/// Authenticated subject returned by an auth provider.
#[derive(Debug, Clone)]
struct SubjectInfo {
    /// Forwarded upstream as `x-subject-id`.
    subject_id: String,
    /// Forwarded upstream as `x-org-id`.
    org_id: String,
    /// Forwarded upstream as `x-project-id`.
    project_id: String,
    /// Forwarded upstream, comma-joined, as `x-roles` when non-empty.
    roles: Vec<String>,
    /// Forwarded upstream, comma-joined, as `x-scopes` when non-empty.
    scopes: Vec<String>,
}
|
|
|
|
/// Identity a credit reservation is charged against.
#[derive(Debug, Clone)]
struct CreditSubject {
    /// Subject identifier; may be empty when taken from request headers.
    subject_id: String,
    /// Organization identifier.
    org_id: String,
    /// Project identifier.
    project_id: String,
}
|
|
|
|
#[derive(Clone, Debug)]
|
|
struct AuthDecision {
|
|
allow: bool,
|
|
subject: Option<SubjectInfo>,
|
|
headers: HashMap<String, String>,
|
|
reason: Option<String>,
|
|
}
|
|
|
|
/// A credit provider's answer to a reserve request.
#[derive(Debug, Clone)]
struct CreditDecision {
    /// Whether the reservation was granted.
    allow: bool,
    /// Identifier used later to commit or roll back the reservation.
    reservation_id: String,
    /// Provider-supplied reason; `None` when the provider sent an empty one.
    reason: Option<String>,
}
|
|
|
|
#[derive(Clone, Debug, Default)]
|
|
struct AuthOutcome {
|
|
subject: Option<SubjectInfo>,
|
|
headers: HashMap<String, String>,
|
|
}
|
|
|
|
/// A granted credit reservation awaiting commit or rollback.
#[derive(Debug, Clone)]
struct CreditReservation {
    /// Name of the credit provider that issued the reservation.
    provider: String,
    /// Provider-issued reservation identifier.
    reservation_id: String,
}
|
|
|
|
/// Snapshot of the incoming request that is passed to auth and credit providers.
#[derive(Debug, Clone)]
struct RequestContext {
    /// Request id taken from the client or generated by the gateway.
    request_id: String,
    /// HTTP method as text.
    method: String,
    /// Request path.
    path: String,
    /// Raw query string; empty when the request has none.
    raw_query: String,
    /// Request headers; repeated headers are joined with `,`.
    headers: HashMap<String, String>,
    /// Client IP as resolved by `extract_client_ip`.
    client_ip: String,
    /// Name of the matched route.
    route_name: String,
}
|
|
|
|
/// Default listen address: loopback, port 8080.
fn default_http_addr() -> SocketAddr {
    // Built from parts so no string parsing (or panic path) is involved.
    SocketAddr::from(([127, 0, 0, 1], 8080))
}
|
|
|
|
/// Default log filter directive.
fn default_log_level() -> String {
    String::from("info")
}
|
|
|
|
/// Default request body limit: 16 MiB.
fn default_max_body_bytes() -> usize {
    const MIB: usize = 1024 * 1024;
    16 * MIB
}
|
|
|
|
fn default_max_response_bytes() -> usize {
|
|
default_max_body_bytes()
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
let args = Args::parse();
|
|
|
|
let mut used_default_config = false;
|
|
let mut config = if args.config.exists() {
|
|
let contents = tokio::fs::read_to_string(&args.config).await?;
|
|
toml::from_str(&contents)?
|
|
} else {
|
|
used_default_config = true;
|
|
ServerConfig::default()
|
|
};
|
|
|
|
if let Some(http_addr) = args.http_addr {
|
|
config.http_addr = http_addr.parse()?;
|
|
}
|
|
if let Some(log_level) = args.log_level {
|
|
config.log_level = log_level;
|
|
}
|
|
|
|
tracing_subscriber::fmt()
|
|
.with_env_filter(
|
|
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&config.log_level)),
|
|
)
|
|
.init();
|
|
|
|
if used_default_config {
|
|
info!("Config file not found: {}, using defaults", args.config.display());
|
|
}
|
|
|
|
let routes = build_routes(config.routes)?;
|
|
let auth_providers = build_auth_providers(config.auth_providers).await?;
|
|
let credit_providers = build_credit_providers(config.credit_providers).await?;
|
|
let upstream_timeout = Duration::from_millis(config.upstream_timeout_ms);
|
|
let client = Client::builder().build()?;
|
|
|
|
info!("Starting API gateway");
|
|
info!(" HTTP: {}", config.http_addr);
|
|
info!(" Max body bytes: {}", config.max_body_bytes);
|
|
info!(" Max response bytes: {}", config.max_response_bytes);
|
|
|
|
if !routes.is_empty() {
|
|
info!("Configured {} routes", routes.len());
|
|
} else {
|
|
warn!("No routes configured; proxy will return 404s");
|
|
}
|
|
|
|
if auth_providers.is_empty() {
|
|
warn!("No auth providers configured");
|
|
}
|
|
if credit_providers.is_empty() {
|
|
warn!("No credit providers configured");
|
|
}
|
|
|
|
let state = Arc::new(ServerState {
|
|
routes,
|
|
client,
|
|
upstream_timeout,
|
|
max_body_bytes: config.max_body_bytes,
|
|
max_response_bytes: config.max_response_bytes,
|
|
auth_providers,
|
|
credit_providers,
|
|
trust_forwarded_headers: config.trust_forwarded_headers,
|
|
});
|
|
|
|
let app = Router::new()
|
|
.route("/", get(index))
|
|
.route("/health", get(health))
|
|
.route("/routes", get(list_routes))
|
|
.route("/*path", any(proxy))
|
|
.with_state(state);
|
|
|
|
let listener = tokio::net::TcpListener::bind(config.http_addr).await?;
|
|
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// `GET /` — fixed banner showing the gateway process is up.
async fn index() -> &'static str {
    const BANNER: &str = "apigateway-server running";
    BANNER
}
|
|
|
|
async fn health() -> Json<serde_json::Value> {
|
|
Json(serde_json::json!({"status": "ok"}))
|
|
}
|
|
|
|
async fn list_routes(State(state): State<Arc<ServerState>>) -> Json<Vec<RouteConfig>> {
|
|
Json(state.routes.iter().map(|route| route.config.clone()).collect())
|
|
}
|
|
|
|
async fn proxy(
|
|
State(state): State<Arc<ServerState>>,
|
|
ConnectInfo(remote_addr): ConnectInfo<SocketAddr>,
|
|
request: Request<Body>,
|
|
) -> Result<Response<Body>, StatusCode> {
|
|
let path = request.uri().path();
|
|
let route = match_route(&state.routes, path)
|
|
.ok_or(StatusCode::NOT_FOUND)?
|
|
.clone();
|
|
let request_id = extract_request_id(request.headers());
|
|
|
|
let context = RequestContext {
|
|
request_id: request_id.clone(),
|
|
method: request.method().to_string(),
|
|
path: request.uri().path().to_string(),
|
|
raw_query: request.uri().query().unwrap_or("").to_string(),
|
|
headers: headers_to_map(request.headers()),
|
|
client_ip: extract_client_ip(
|
|
request.headers(),
|
|
remote_addr,
|
|
state.trust_forwarded_headers,
|
|
),
|
|
route_name: route.config.name.clone(),
|
|
};
|
|
|
|
let auth_token = extract_auth_token(request.headers());
|
|
let forward_client_auth_headers = route.config.auth.is_none();
|
|
|
|
let auth_outcome = enforce_auth(&state, &route, &context, auth_token).await?;
|
|
let credit_reservation =
|
|
enforce_credit(&state, &route, &context, auth_outcome.subject.as_ref()).await?;
|
|
|
|
let target_url = build_upstream_url(&route, request.uri())?;
|
|
|
|
let request_timeout =
|
|
Duration::from_millis(route.config.timeout_ms.unwrap_or(state.upstream_timeout.as_millis() as u64));
|
|
let mut builder = state
|
|
.client
|
|
.request(request.method().clone(), target_url)
|
|
.timeout(request_timeout);
|
|
for (name, value) in request.headers().iter() {
|
|
if name == axum::http::header::HOST || name == axum::http::header::CONNECTION {
|
|
continue;
|
|
}
|
|
if is_reserved_auth_header(name) {
|
|
if forward_client_auth_headers && should_preserve_client_auth_header(name.as_str()) {
|
|
builder = builder.header(name, value);
|
|
}
|
|
continue;
|
|
}
|
|
builder = builder.header(name, value);
|
|
}
|
|
|
|
builder = builder.header(DEFAULT_REQUEST_ID_HEADER, request_id.clone());
|
|
builder = apply_auth_headers(builder, &auth_outcome);
|
|
|
|
let body_bytes = to_bytes(request.into_body(), state.max_body_bytes)
|
|
.await
|
|
.map_err(|_| StatusCode::PAYLOAD_TOO_LARGE)?;
|
|
|
|
let response = match builder.body(body_bytes).send().await {
|
|
Ok(response) => response,
|
|
Err(_) => {
|
|
finalize_credit_abort(&state, &route, credit_reservation).await;
|
|
return Err(StatusCode::BAD_GATEWAY);
|
|
}
|
|
};
|
|
|
|
let status = response.status();
|
|
if let Some(content_length) = response.content_length() {
|
|
if state.max_response_bytes > 0 && content_length as usize > state.max_response_bytes {
|
|
finalize_credit_abort(&state, &route, credit_reservation).await;
|
|
return Err(StatusCode::PAYLOAD_TOO_LARGE);
|
|
}
|
|
}
|
|
|
|
let mut response_builder = Response::builder().status(status);
|
|
let headers = response_builder
|
|
.headers_mut()
|
|
.ok_or(StatusCode::BAD_GATEWAY)?;
|
|
|
|
for (name, value) in response.headers().iter() {
|
|
if name == axum::http::header::CONNECTION {
|
|
continue;
|
|
}
|
|
headers.insert(name, value.clone());
|
|
}
|
|
|
|
let body = match response.bytes().await {
|
|
Ok(body) => body,
|
|
Err(_) => {
|
|
finalize_credit_abort(&state, &route, credit_reservation).await;
|
|
return Err(StatusCode::BAD_GATEWAY);
|
|
}
|
|
};
|
|
if state.max_response_bytes > 0 && body.len() > state.max_response_bytes {
|
|
finalize_credit_abort(&state, &route, credit_reservation).await;
|
|
return Err(StatusCode::PAYLOAD_TOO_LARGE);
|
|
}
|
|
|
|
finalize_credit(&state, &route, credit_reservation, status).await;
|
|
|
|
response_builder
|
|
.body(Body::from(body))
|
|
.map_err(|_| StatusCode::BAD_GATEWAY)
|
|
}
|
|
|
|
async fn enforce_auth(
|
|
state: &ServerState,
|
|
route: &Route,
|
|
context: &RequestContext,
|
|
token: Option<String>,
|
|
) -> Result<AuthOutcome, StatusCode> {
|
|
let Some(auth_cfg) = &route.config.auth else {
|
|
return Ok(AuthOutcome::default());
|
|
};
|
|
|
|
if auth_cfg.mode == PolicyMode::Disabled {
|
|
return Ok(AuthOutcome::default());
|
|
}
|
|
|
|
let decision = authorize_request(state, auth_cfg, context, token).await;
|
|
apply_auth_mode(auth_cfg.mode, auth_cfg.fail_open, decision)
|
|
}
|
|
|
|
fn apply_auth_mode(
|
|
mode: PolicyMode,
|
|
fail_open: bool,
|
|
decision: Result<AuthDecision, StatusCode>,
|
|
) -> Result<AuthOutcome, StatusCode> {
|
|
match mode {
|
|
PolicyMode::Disabled => Ok(AuthOutcome::default()),
|
|
PolicyMode::Optional => match decision {
|
|
Ok(decision) if decision.allow => Ok(AuthOutcome {
|
|
subject: decision.subject,
|
|
headers: decision.headers,
|
|
}),
|
|
Ok(decision) => {
|
|
if let Some(reason) = decision.reason {
|
|
warn!("Auth denied (optional mode): {}", reason);
|
|
}
|
|
Ok(AuthOutcome::default())
|
|
}
|
|
Err(err) => {
|
|
warn!("Auth provider error (optional mode): {}", err);
|
|
Ok(AuthOutcome::default())
|
|
}
|
|
},
|
|
PolicyMode::Required => match decision {
|
|
Ok(decision) if decision.allow => Ok(AuthOutcome {
|
|
subject: decision.subject,
|
|
headers: decision.headers,
|
|
}),
|
|
Ok(decision) => {
|
|
if let Some(reason) = decision.reason {
|
|
warn!("Auth denied (required mode): {}", reason);
|
|
}
|
|
Err(StatusCode::FORBIDDEN)
|
|
}
|
|
Err(err) => {
|
|
warn!("Auth provider error (required mode): {}", err);
|
|
if fail_open {
|
|
Ok(AuthOutcome::default())
|
|
} else {
|
|
Err(StatusCode::BAD_GATEWAY)
|
|
}
|
|
}
|
|
},
|
|
}
|
|
}
|
|
|
|
async fn enforce_credit(
|
|
state: &ServerState,
|
|
route: &Route,
|
|
context: &RequestContext,
|
|
subject: Option<&SubjectInfo>,
|
|
) -> Result<Option<CreditReservation>, StatusCode> {
|
|
let Some(credit_cfg) = &route.config.credit else {
|
|
return Ok(None);
|
|
};
|
|
|
|
if credit_cfg.mode == PolicyMode::Disabled {
|
|
return Ok(None);
|
|
}
|
|
|
|
let credit_subject = resolve_credit_subject(context, subject, credit_cfg.allow_header_subject);
|
|
if credit_subject.is_none() {
|
|
if credit_cfg.mode == PolicyMode::Required {
|
|
return Err(StatusCode::UNAUTHORIZED);
|
|
}
|
|
warn!("Credit skipped: missing org/project scope");
|
|
return Ok(None);
|
|
}
|
|
|
|
let decision = reserve_credit(
|
|
state,
|
|
credit_cfg,
|
|
context,
|
|
credit_subject.as_ref().expect("credit subject resolved"),
|
|
)
|
|
.await;
|
|
apply_credit_mode(credit_cfg.mode, credit_cfg.fail_open, decision)
|
|
.map(|decision| {
|
|
decision.map(|decision| CreditReservation {
|
|
provider: credit_cfg.provider.clone(),
|
|
reservation_id: decision.reservation_id,
|
|
})
|
|
})
|
|
}
|
|
|
|
fn apply_credit_mode(
|
|
mode: PolicyMode,
|
|
fail_open: bool,
|
|
decision: Result<CreditDecision, StatusCode>,
|
|
) -> Result<Option<CreditDecision>, StatusCode> {
|
|
match mode {
|
|
PolicyMode::Disabled => Ok(None),
|
|
PolicyMode::Optional => match decision {
|
|
Ok(decision) if decision.allow => Ok(Some(decision)),
|
|
Ok(decision) => {
|
|
if let Some(reason) = decision.reason {
|
|
warn!("Credit denied (optional mode): {}", reason);
|
|
}
|
|
Ok(None)
|
|
}
|
|
Err(err) => {
|
|
warn!("Credit provider error (optional mode): {}", err);
|
|
Ok(None)
|
|
}
|
|
},
|
|
PolicyMode::Required => match decision {
|
|
Ok(decision) if decision.allow => Ok(Some(decision)),
|
|
Ok(decision) => {
|
|
if let Some(reason) = decision.reason {
|
|
warn!("Credit denied (required mode): {}", reason);
|
|
}
|
|
Err(StatusCode::PAYMENT_REQUIRED)
|
|
}
|
|
Err(err) => {
|
|
warn!("Credit provider error (required mode): {}", err);
|
|
if fail_open {
|
|
Ok(None)
|
|
} else {
|
|
Err(StatusCode::BAD_GATEWAY)
|
|
}
|
|
}
|
|
},
|
|
}
|
|
}
|
|
|
|
async fn authorize_request(
|
|
state: &ServerState,
|
|
auth_cfg: &RouteAuthConfig,
|
|
context: &RequestContext,
|
|
token: Option<String>,
|
|
) -> Result<AuthDecision, StatusCode> {
|
|
let provider = state
|
|
.auth_providers
|
|
.get(&auth_cfg.provider)
|
|
.ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
|
|
|
|
match provider {
|
|
AuthProvider::Grpc(provider) => {
|
|
let mut client = GatewayAuthServiceClient::new(provider.channel.clone());
|
|
let mut request = TonicRequest::new(AuthorizeRequest {
|
|
request_id: context.request_id.clone(),
|
|
token: token.unwrap_or_default(),
|
|
method: context.method.clone(),
|
|
path: context.path.clone(),
|
|
raw_query: context.raw_query.clone(),
|
|
headers: context.headers.clone(),
|
|
client_ip: context.client_ip.clone(),
|
|
route_name: context.route_name.clone(),
|
|
});
|
|
request.set_timeout(provider.timeout);
|
|
|
|
let response = client
|
|
.authorize(request)
|
|
.await
|
|
.map_err(|_| StatusCode::BAD_GATEWAY)?
|
|
.into_inner();
|
|
|
|
let subject = response.subject.map(|subject| SubjectInfo {
|
|
subject_id: subject.subject_id,
|
|
org_id: subject.org_id,
|
|
project_id: subject.project_id,
|
|
roles: subject.roles,
|
|
scopes: subject.scopes,
|
|
});
|
|
|
|
Ok(AuthDecision {
|
|
allow: response.allow,
|
|
subject,
|
|
headers: response.headers,
|
|
reason: if response.reason.is_empty() {
|
|
None
|
|
} else {
|
|
Some(response.reason)
|
|
},
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
fn resolve_credit_subject(
|
|
context: &RequestContext,
|
|
subject: Option<&SubjectInfo>,
|
|
allow_header_subject: bool,
|
|
) -> Option<CreditSubject> {
|
|
if let Some(subject) = subject {
|
|
return Some(CreditSubject {
|
|
subject_id: subject.subject_id.clone(),
|
|
org_id: subject.org_id.clone(),
|
|
project_id: subject.project_id.clone(),
|
|
});
|
|
}
|
|
|
|
if !allow_header_subject {
|
|
return None;
|
|
}
|
|
|
|
let org_id = context.headers.get("x-org-id")?.trim();
|
|
let project_id = context.headers.get("x-project-id")?.trim();
|
|
if org_id.is_empty() || project_id.is_empty() {
|
|
return None;
|
|
}
|
|
|
|
let subject_id = context
|
|
.headers
|
|
.get("x-subject-id")
|
|
.map(|value| value.trim().to_string())
|
|
.unwrap_or_default();
|
|
|
|
Some(CreditSubject {
|
|
subject_id,
|
|
org_id: org_id.to_string(),
|
|
project_id: project_id.to_string(),
|
|
})
|
|
}
|
|
|
|
async fn reserve_credit(
|
|
state: &ServerState,
|
|
credit_cfg: &RouteCreditConfig,
|
|
context: &RequestContext,
|
|
credit_subject: &CreditSubject,
|
|
) -> Result<CreditDecision, StatusCode> {
|
|
let provider = state
|
|
.credit_providers
|
|
.get(&credit_cfg.provider)
|
|
.ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
|
|
|
|
let subject_id = credit_subject.subject_id.clone();
|
|
let org_id = credit_subject.org_id.clone();
|
|
let project_id = credit_subject.project_id.clone();
|
|
|
|
match provider {
|
|
CreditProvider::Grpc(provider) => {
|
|
let mut client = GatewayCreditServiceClient::new(provider.channel.clone());
|
|
let mut request = TonicRequest::new(CreditReserveRequest {
|
|
request_id: context.request_id.clone(),
|
|
subject_id,
|
|
org_id,
|
|
project_id,
|
|
route_name: context.route_name.clone(),
|
|
method: context.method.clone(),
|
|
path: context.path.clone(),
|
|
raw_query: context.raw_query.clone(),
|
|
units: credit_cfg.units,
|
|
attributes: credit_cfg.attributes.clone(),
|
|
});
|
|
request.set_timeout(provider.timeout);
|
|
|
|
let response = client
|
|
.reserve(request)
|
|
.await
|
|
.map_err(|_| StatusCode::BAD_GATEWAY)?
|
|
.into_inner();
|
|
|
|
Ok(CreditDecision {
|
|
allow: response.allow,
|
|
reservation_id: response.reservation_id,
|
|
reason: if response.reason.is_empty() {
|
|
None
|
|
} else {
|
|
Some(response.reason)
|
|
},
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn finalize_credit(
|
|
state: &ServerState,
|
|
route: &Route,
|
|
reservation: Option<CreditReservation>,
|
|
status: reqwest::StatusCode,
|
|
) {
|
|
let Some(credit_cfg) = &route.config.credit else {
|
|
return;
|
|
};
|
|
let Some(reservation) = reservation else {
|
|
return;
|
|
};
|
|
|
|
match credit_cfg.commit_on {
|
|
CommitPolicy::Never => return,
|
|
CommitPolicy::Always => {
|
|
if let Err(err) = commit_credit(state, credit_cfg, &reservation).await {
|
|
warn!("Failed to commit credit reservation {}: {}", reservation.reservation_id, err);
|
|
}
|
|
}
|
|
CommitPolicy::Success => {
|
|
if status.is_success() || status.is_redirection() {
|
|
if let Err(err) = commit_credit(state, credit_cfg, &reservation).await {
|
|
warn!("Failed to commit credit reservation {}: {}", reservation.reservation_id, err);
|
|
}
|
|
} else if let Err(err) = rollback_credit(state, credit_cfg, &reservation).await {
|
|
warn!(
|
|
"Failed to rollback credit reservation {}: {}",
|
|
reservation.reservation_id, err
|
|
);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn finalize_credit_abort(
|
|
state: &ServerState,
|
|
route: &Route,
|
|
reservation: Option<CreditReservation>,
|
|
) {
|
|
let Some(credit_cfg) = &route.config.credit else {
|
|
return;
|
|
};
|
|
let Some(reservation) = reservation else {
|
|
return;
|
|
};
|
|
|
|
if credit_cfg.commit_on == CommitPolicy::Never {
|
|
return;
|
|
}
|
|
|
|
if let Err(err) = rollback_credit(state, credit_cfg, &reservation).await {
|
|
warn!(
|
|
"Failed to rollback credit reservation {} after delivery failure: {}",
|
|
reservation.reservation_id, err
|
|
);
|
|
}
|
|
}
|
|
|
|
async fn commit_credit(
|
|
state: &ServerState,
|
|
credit_cfg: &RouteCreditConfig,
|
|
reservation: &CreditReservation,
|
|
) -> Result<(), StatusCode> {
|
|
let provider = state
|
|
.credit_providers
|
|
.get(&reservation.provider)
|
|
.ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
|
|
|
|
match provider {
|
|
CreditProvider::Grpc(provider) => {
|
|
let mut client = GatewayCreditServiceClient::new(provider.channel.clone());
|
|
let mut request = TonicRequest::new(CreditCommitRequest {
|
|
reservation_id: reservation.reservation_id.clone(),
|
|
units: credit_cfg.units,
|
|
});
|
|
request.set_timeout(provider.timeout);
|
|
let response = client
|
|
.commit(request)
|
|
.await
|
|
.map_err(|_| StatusCode::BAD_GATEWAY)?
|
|
.into_inner();
|
|
if response.success {
|
|
Ok(())
|
|
} else {
|
|
Err(StatusCode::BAD_GATEWAY)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn rollback_credit(
|
|
state: &ServerState,
|
|
_credit_cfg: &RouteCreditConfig,
|
|
reservation: &CreditReservation,
|
|
) -> Result<(), StatusCode> {
|
|
let provider = state
|
|
.credit_providers
|
|
.get(&reservation.provider)
|
|
.ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
|
|
|
|
match provider {
|
|
CreditProvider::Grpc(provider) => {
|
|
let mut client = GatewayCreditServiceClient::new(provider.channel.clone());
|
|
let mut request = TonicRequest::new(CreditRollbackRequest {
|
|
reservation_id: reservation.reservation_id.clone(),
|
|
});
|
|
request.set_timeout(provider.timeout);
|
|
let response = client
|
|
.rollback(request)
|
|
.await
|
|
.map_err(|_| StatusCode::BAD_GATEWAY)?
|
|
.into_inner();
|
|
if response.success {
|
|
Ok(())
|
|
} else {
|
|
Err(StatusCode::BAD_GATEWAY)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn apply_auth_headers(
|
|
mut builder: reqwest::RequestBuilder,
|
|
outcome: &AuthOutcome,
|
|
) -> reqwest::RequestBuilder {
|
|
for (key, value) in &outcome.headers {
|
|
if !should_forward_auth_header(key) {
|
|
continue;
|
|
}
|
|
builder = builder.header(key, value);
|
|
}
|
|
|
|
if let Some(subject) = &outcome.subject {
|
|
builder = builder
|
|
.header("x-subject-id", &subject.subject_id)
|
|
.header("x-org-id", &subject.org_id)
|
|
.header("x-project-id", &subject.project_id);
|
|
if !subject.roles.is_empty() {
|
|
builder = builder.header("x-roles", subject.roles.join(","));
|
|
}
|
|
if !subject.scopes.is_empty() {
|
|
builder = builder.header("x-scopes", subject.scopes.join(","));
|
|
}
|
|
}
|
|
|
|
builder
|
|
}
|
|
|
|
async fn build_client_tls_config(
|
|
tls: &Option<TlsConfig>,
|
|
) -> Result<Option<ClientTlsConfig>, Box<dyn std::error::Error>> {
|
|
let Some(tls) = tls else {
|
|
return Ok(None);
|
|
};
|
|
|
|
let mut tls_config = ClientTlsConfig::new();
|
|
|
|
if let Some(ca_file) = &tls.ca_file {
|
|
let ca = tokio::fs::read(ca_file).await?;
|
|
tls_config = tls_config.ca_certificate(Certificate::from_pem(ca));
|
|
}
|
|
|
|
match (&tls.cert_file, &tls.key_file) {
|
|
(Some(cert_file), Some(key_file)) => {
|
|
let cert = tokio::fs::read(cert_file).await?;
|
|
let key = tokio::fs::read(key_file).await?;
|
|
tls_config = tls_config.identity(Identity::from_pem(cert, key));
|
|
}
|
|
(None, None) => {}
|
|
_ => {
|
|
return Err(config_error("tls requires both cert_file and key_file").into());
|
|
}
|
|
}
|
|
|
|
if let Some(domain) = &tls.domain_name {
|
|
tls_config = tls_config.domain_name(domain);
|
|
}
|
|
|
|
Ok(Some(tls_config))
|
|
}
|
|
|
|
async fn build_auth_providers(
|
|
configs: Vec<AuthProviderConfig>,
|
|
) -> Result<HashMap<String, AuthProvider>, Box<dyn std::error::Error>> {
|
|
let mut providers = HashMap::new();
|
|
|
|
for config in configs {
|
|
let provider_type = normalize_name(&config.provider_type);
|
|
if providers.contains_key(&config.name) {
|
|
return Err(config_error(format!(
|
|
"duplicate auth provider name {}",
|
|
config.name
|
|
))
|
|
.into());
|
|
}
|
|
|
|
match provider_type.as_str() {
|
|
"grpc" => {
|
|
let mut endpoint = Endpoint::from_shared(config.endpoint.clone())?
|
|
.connect_timeout(Duration::from_millis(
|
|
config.timeout_ms.unwrap_or(DEFAULT_AUTH_TIMEOUT_MS),
|
|
))
|
|
.timeout(Duration::from_millis(
|
|
config.timeout_ms.unwrap_or(DEFAULT_AUTH_TIMEOUT_MS),
|
|
));
|
|
if let Some(tls) = build_client_tls_config(&config.tls).await? {
|
|
endpoint = endpoint.tls_config(tls)?;
|
|
}
|
|
let channel = endpoint.connect().await?;
|
|
let timeout =
|
|
Duration::from_millis(config.timeout_ms.unwrap_or(DEFAULT_AUTH_TIMEOUT_MS));
|
|
providers.insert(
|
|
config.name.clone(),
|
|
AuthProvider::Grpc(GrpcAuthProvider {
|
|
channel,
|
|
timeout,
|
|
}),
|
|
);
|
|
}
|
|
_ => {
|
|
return Err(config_error(format!(
|
|
"unsupported auth provider type {}",
|
|
config.provider_type
|
|
))
|
|
.into());
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(providers)
|
|
}
|
|
|
|
async fn build_credit_providers(
|
|
configs: Vec<CreditProviderConfig>,
|
|
) -> Result<HashMap<String, CreditProvider>, Box<dyn std::error::Error>> {
|
|
let mut providers = HashMap::new();
|
|
|
|
for config in configs {
|
|
let provider_type = normalize_name(&config.provider_type);
|
|
if providers.contains_key(&config.name) {
|
|
return Err(config_error(format!(
|
|
"duplicate credit provider name {}",
|
|
config.name
|
|
))
|
|
.into());
|
|
}
|
|
|
|
match provider_type.as_str() {
|
|
"grpc" => {
|
|
let mut endpoint = Endpoint::from_shared(config.endpoint.clone())?
|
|
.connect_timeout(Duration::from_millis(
|
|
config
|
|
.timeout_ms
|
|
.unwrap_or(DEFAULT_CREDIT_TIMEOUT_MS),
|
|
))
|
|
.timeout(Duration::from_millis(
|
|
config
|
|
.timeout_ms
|
|
.unwrap_or(DEFAULT_CREDIT_TIMEOUT_MS),
|
|
));
|
|
|
|
if let Some(tls) = build_client_tls_config(&config.tls).await? {
|
|
endpoint = endpoint.tls_config(tls)?;
|
|
}
|
|
|
|
let channel = endpoint.connect().await?;
|
|
let timeout = Duration::from_millis(
|
|
config
|
|
.timeout_ms
|
|
.unwrap_or(DEFAULT_CREDIT_TIMEOUT_MS),
|
|
);
|
|
providers.insert(
|
|
config.name.clone(),
|
|
CreditProvider::Grpc(GrpcCreditProvider {
|
|
channel,
|
|
timeout,
|
|
}),
|
|
);
|
|
}
|
|
_ => {
|
|
return Err(config_error(format!(
|
|
"unsupported credit provider type {}",
|
|
config.provider_type
|
|
))
|
|
.into());
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(providers)
|
|
}
|
|
|
|
fn build_routes(configs: Vec<RouteConfig>) -> Result<Vec<Route>, Box<dyn std::error::Error>> {
|
|
let mut routes = Vec::new();
|
|
|
|
for mut config in configs {
|
|
if config.name.trim().is_empty() {
|
|
return Err(config_error("route name is required").into());
|
|
}
|
|
let path_prefix = normalize_path_prefix(&config.path_prefix);
|
|
config.path_prefix = path_prefix;
|
|
|
|
let upstream = Url::parse(&config.upstream)?;
|
|
if upstream.scheme() != "http" && upstream.scheme() != "https" {
|
|
return Err(config_error(format!(
|
|
"route {} upstream must be http or https",
|
|
config.name
|
|
))
|
|
.into());
|
|
}
|
|
if upstream.host_str().is_none() {
|
|
return Err(config_error(format!(
|
|
"route {} upstream must include host",
|
|
config.name
|
|
))
|
|
.into());
|
|
}
|
|
let upstream_base_path = normalize_upstream_base_path(upstream.path());
|
|
|
|
routes.push(Route {
|
|
config,
|
|
upstream,
|
|
upstream_base_path,
|
|
});
|
|
}
|
|
|
|
routes.sort_by(|a, b| b.config.path_prefix.len().cmp(&a.config.path_prefix.len()));
|
|
Ok(routes)
|
|
}
|
|
|
|
/// Wraps a configuration problem in an `io::Error` of kind `InvalidInput`.
fn config_error(message: impl Into<String>) -> io::Error {
    let text: String = message.into();
    io::Error::new(io::ErrorKind::InvalidInput, text)
}
|
|
|
|
/// Canonical form of a configured type name: trimmed, lower-cased, with
/// `-` replaced by `_`.
fn normalize_name(value: &str) -> String {
    let lowered = value.trim().to_lowercase();
    lowered
        .chars()
        .map(|ch| if ch == '-' { '_' } else { ch })
        .collect()
}
|
|
|
|
fn extract_request_id(headers: &HeaderMap) -> String {
|
|
headers
|
|
.get(DEFAULT_REQUEST_ID_HEADER)
|
|
.and_then(|value| value.to_str().ok())
|
|
.map(|value| value.to_string())
|
|
.unwrap_or_else(|| Uuid::new_v4().to_string())
|
|
}
|
|
|
|
fn extract_client_ip(
|
|
headers: &HeaderMap,
|
|
remote_addr: SocketAddr,
|
|
trust_forwarded_headers: bool,
|
|
) -> String {
|
|
if trust_forwarded_headers {
|
|
if let Some(value) = headers
|
|
.get("x-forwarded-for")
|
|
.and_then(|value| value.to_str().ok())
|
|
.and_then(|value| value.split(',').next())
|
|
{
|
|
let trimmed = value.trim();
|
|
if !trimmed.is_empty() {
|
|
return trimmed.to_string();
|
|
}
|
|
}
|
|
if let Some(value) = headers
|
|
.get("x-real-ip")
|
|
.and_then(|value| value.to_str().ok())
|
|
{
|
|
let trimmed = value.trim();
|
|
if !trimmed.is_empty() {
|
|
return trimmed.to_string();
|
|
}
|
|
}
|
|
}
|
|
|
|
remote_addr.ip().to_string()
|
|
}
|
|
|
|
fn headers_to_map(headers: &HeaderMap) -> HashMap<String, String> {
|
|
let mut map: HashMap<String, String> = HashMap::new();
|
|
for (name, value) in headers.iter() {
|
|
if let Ok(value) = value.to_str() {
|
|
map.entry(name.as_str().to_string())
|
|
.and_modify(|entry| {
|
|
entry.push(',');
|
|
entry.push_str(value);
|
|
})
|
|
.or_insert_with(|| value.to_string());
|
|
}
|
|
}
|
|
map
|
|
}
|
|
|
|
fn extract_auth_token(headers: &HeaderMap) -> Option<String> {
|
|
let auth_header = headers
|
|
.get(axum::http::header::AUTHORIZATION)
|
|
.and_then(|value| value.to_str().ok());
|
|
if let Some(token) = auth_header.and_then(parse_auth_token_value) {
|
|
return Some(token);
|
|
}
|
|
|
|
let photon_header = headers
|
|
.get(PHOTON_AUTH_TOKEN_HEADER)
|
|
.and_then(|value| value.to_str().ok());
|
|
photon_header.and_then(parse_auth_token_value)
|
|
}
|
|
|
|
fn is_reserved_auth_header(name: &axum::http::header::HeaderName) -> bool {
|
|
is_reserved_auth_header_str(name.as_str())
|
|
}
|
|
|
|
fn is_reserved_auth_header_str(name: &str) -> bool {
|
|
let header = name.to_ascii_lowercase();
|
|
RESERVED_AUTH_HEADERS.iter().any(|value| *value == header)
|
|
}
|
|
|
|
fn should_forward_auth_header(name: &str) -> bool {
|
|
let header = name.to_ascii_lowercase();
|
|
if AUTH_PROVIDER_BLOCK_HEADERS
|
|
.iter()
|
|
.any(|value| *value == header)
|
|
{
|
|
return false;
|
|
}
|
|
header.starts_with("x-")
|
|
}
|
|
|
|
fn should_preserve_client_auth_header(name: &str) -> bool {
|
|
let header = name.to_ascii_lowercase();
|
|
header == "authorization" || header == PHOTON_AUTH_TOKEN_HEADER
|
|
}
|
|
|
|
fn parse_auth_token_value(value: &str) -> Option<String> {
|
|
let trimmed = value.trim();
|
|
if trimmed.is_empty() {
|
|
return None;
|
|
}
|
|
|
|
if let Some(token) = parse_bearer_token(trimmed) {
|
|
return Some(token);
|
|
}
|
|
|
|
// Legacy support: allow raw token values without a scheme.
|
|
if trimmed.split_whitespace().count() != 1 {
|
|
return None;
|
|
}
|
|
|
|
Some(trimmed.to_string())
|
|
}
|
|
|
|
/// Parses a `Bearer <token>` credential.
///
/// The value must consist of exactly two whitespace-separated words, the
/// first being `bearer` in any ASCII case; the second word is returned.
/// Anything else yields `None`.
fn parse_bearer_token(value: &str) -> Option<String> {
    let mut words = value.split_whitespace();
    match (words.next(), words.next(), words.next()) {
        (Some(scheme), Some(credential), None) if scheme.eq_ignore_ascii_case("bearer") => {
            Some(credential.to_owned())
        }
        _ => None,
    }
}
|
|
|
|
/// Normalizes a configured route prefix.
///
/// Surrounding whitespace and *all* trailing slashes are removed and a
/// leading `/` is added when missing. A blank prefix, or one made only of
/// slashes, becomes the root prefix `/`.
///
/// Stripping every trailing slash matters: a prefix that still ends in `/`
/// (e.g. `/api/`, previously produced from `/api//`) can never match a
/// sub-path, because `path_matches_prefix` requires the remainder after the
/// prefix to start with `/`.
fn normalize_path_prefix(prefix: &str) -> String {
    let body = prefix.trim().trim_end_matches('/');
    if body.is_empty() {
        return "/".to_string();
    }

    if body.starts_with('/') {
        body.to_string()
    } else {
        format!("/{}", body)
    }
}
|
|
|
|
/// Normalizes the path component of an upstream URL for use as a base path.
///
/// Surrounding whitespace and all trailing slashes are removed. If nothing
/// remains (blank input, `/`, or a path made only of slashes such as `//`),
/// the root path `/` is returned, so the result is never empty.
fn normalize_upstream_base_path(path: &str) -> String {
    let stripped = path.trim().trim_end_matches('/');
    if stripped.is_empty() {
        "/".to_string()
    } else {
        stripped.to_string()
    }
}
|
|
|
|
fn match_route<'a>(routes: &'a [Route], path: &str) -> Option<&'a Route> {
|
|
routes
|
|
.iter()
|
|
.find(|route| path_matches_prefix(path, &route.config.path_prefix))
|
|
}
|
|
|
|
/// Tests whether `path` falls under `prefix` on a path-segment boundary.
///
/// The root prefix `/` matches everything. Otherwise `path` must equal
/// `prefix` or continue with a `/` right after it, so `/api` matches
/// `/api/v1` but not `/apiary`.
fn path_matches_prefix(path: &str, prefix: &str) -> bool {
    prefix == "/"
        || path
            .strip_prefix(prefix)
            .map_or(false, |rest| rest.is_empty() || rest.starts_with('/'))
}
|
|
|
|
/// Removes `prefix` from the front of `path`, always returning a path that
/// starts with `/`.
///
/// The root prefix `/` strips nothing, a path equal to the prefix becomes
/// `/`, and a path that does not start with the prefix is returned
/// unchanged. A remainder without a leading slash gets one prepended.
fn strip_prefix_path(path: &str, prefix: &str) -> String {
    let remainder = if prefix == "/" {
        None
    } else {
        path.strip_prefix(prefix)
    };

    match remainder {
        None => path.to_owned(),
        Some("") => "/".to_owned(),
        Some(rest) if rest.starts_with('/') => rest.to_owned(),
        Some(rest) => format!("/{}", rest),
    }
}
|
|
|
|
/// Joins an upstream base path with a request path.
///
/// A root base (`/`) leaves `path` untouched. A root `path` yields the base
/// without trailing slashes (or `/` if nothing is left). Otherwise the two
/// are joined with exactly one `/` between them: trailing slashes of `base`
/// and leading slashes of `path` are dropped first.
fn join_paths(base: &str, path: &str) -> String {
    let stem = base.trim_end_matches('/');
    match (base, path) {
        ("/", _) => path.to_owned(),
        (_, "/") if stem.is_empty() => "/".to_owned(),
        (_, "/") => stem.to_owned(),
        _ => [stem, path.trim_start_matches('/')].join("/"),
    }
}
|
|
|
|
fn build_upstream_url(route: &Route, uri: &Uri) -> Result<Url, StatusCode> {
|
|
let path = if route.config.strip_prefix {
|
|
strip_prefix_path(uri.path(), &route.config.path_prefix)
|
|
} else {
|
|
uri.path().to_string()
|
|
};
|
|
|
|
let merged_path = join_paths(&route.upstream_base_path, &path);
|
|
let mut url = route.upstream.clone();
|
|
url.set_path(&merged_path);
|
|
url.set_query(uri.query());
|
|
|
|
Ok(url)
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
use axum::routing::get;
|
|
use creditservice_api::{CreditServiceImpl, CreditStorage, GatewayCreditServiceImpl};
|
|
use apigateway_api::GatewayCreditServiceServer;
|
|
use creditservice_types::Wallet;
|
|
use iam_api::{GatewayAuthServiceImpl, GatewayAuthServiceServer};
|
|
use iam_authn::{InternalTokenConfig, InternalTokenService, SigningKey};
|
|
use iam_authz::{PolicyCache, PolicyEvaluator};
|
|
use iam_store::{Backend, BackendConfig, BindingStore, PrincipalStore, RoleStore, TokenStore};
|
|
use iam_types::{Permission, PolicyBinding, Principal, PrincipalRef, Role, Scope};
|
|
use tokio_stream::wrappers::TcpListenerStream;
|
|
use tonic::transport::Server;
|
|
use uuid::Uuid;
|
|
|
|
async fn wait_for_test_tcp(addr: SocketAddr) {
|
|
let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
|
|
loop {
|
|
if tokio::net::TcpStream::connect(addr).await.is_ok() {
|
|
return;
|
|
}
|
|
assert!(
|
|
tokio::time::Instant::now() < deadline,
|
|
"timed out waiting for test listener {}",
|
|
addr
|
|
);
|
|
tokio::time::sleep(Duration::from_millis(25)).await;
|
|
}
|
|
}
|
|
|
|
fn route_config(name: &str, prefix: &str, upstream: &str, strip_prefix: bool) -> RouteConfig {
|
|
RouteConfig {
|
|
name: name.to_string(),
|
|
path_prefix: prefix.to_string(),
|
|
upstream: upstream.to_string(),
|
|
strip_prefix,
|
|
timeout_ms: None,
|
|
auth: None,
|
|
credit: None,
|
|
}
|
|
}
|
|
|
|
fn auth_decision(allow: bool) -> AuthDecision {
|
|
AuthDecision {
|
|
allow,
|
|
subject: None,
|
|
headers: HashMap::new(),
|
|
reason: None,
|
|
}
|
|
}
|
|
|
|
fn credit_decision(allow: bool) -> CreditDecision {
|
|
CreditDecision {
|
|
allow,
|
|
reservation_id: "resv".to_string(),
|
|
reason: None,
|
|
}
|
|
}
|
|
|
|
async fn start_upstream() -> SocketAddr {
|
|
let app = Router::new()
|
|
.route("/v1/echo", get(|| async { "ok" }))
|
|
.route(
|
|
"/v1/echo-auth",
|
|
get(|headers: HeaderMap| async move {
|
|
Json(serde_json::json!({
|
|
"authorization": headers
|
|
.get(axum::http::header::AUTHORIZATION)
|
|
.and_then(|value| value.to_str().ok()),
|
|
"photon_token": headers
|
|
.get(PHOTON_AUTH_TOKEN_HEADER)
|
|
.and_then(|value| value.to_str().ok()),
|
|
}))
|
|
}),
|
|
);
|
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
|
|
.await
|
|
.expect("bind upstream");
|
|
let addr = listener.local_addr().expect("upstream addr");
|
|
tokio::spawn(async move {
|
|
axum::serve(listener, app).await.expect("upstream serve");
|
|
});
|
|
wait_for_test_tcp(addr).await;
|
|
addr
|
|
}
|
|
|
|
async fn start_iam_gateway() -> (SocketAddr, String) {
|
|
let backend = Arc::new(Backend::new(BackendConfig::Memory).await.expect("iam backend"));
|
|
let principal_store = Arc::new(PrincipalStore::new(backend.clone()));
|
|
let role_store = Arc::new(RoleStore::new(backend.clone()));
|
|
let binding_store = Arc::new(BindingStore::new(backend.clone()));
|
|
let token_store = Arc::new(TokenStore::new(backend));
|
|
|
|
let signing_key = SigningKey::generate("gateway-test-key");
|
|
let token_config = InternalTokenConfig::new(signing_key, "iam-gateway-test");
|
|
let token_service = Arc::new(InternalTokenService::new(token_config));
|
|
|
|
let mut principal = Principal::new_user("user-1", "User One");
|
|
principal.org_id = Some("org-1".into());
|
|
principal.project_id = Some("proj-1".into());
|
|
principal_store
|
|
.create(&principal)
|
|
.await
|
|
.expect("principal create");
|
|
|
|
let role = Role::new(
|
|
"GatewayReader",
|
|
Scope::project("proj-1", "org-1"),
|
|
vec![Permission::new("gateway:public:read", "*")],
|
|
);
|
|
role_store.create(&role).await.expect("role create");
|
|
let binding = PolicyBinding::new(
|
|
format!("binding-{}", Uuid::new_v4()),
|
|
PrincipalRef::new(principal.kind.clone(), principal.id.clone()),
|
|
role.to_ref(),
|
|
Scope::project("proj-1", "org-1"),
|
|
);
|
|
binding_store
|
|
.create(&binding)
|
|
.await
|
|
.expect("binding create");
|
|
|
|
let issued = token_service
|
|
.issue(&principal, vec![], Scope::project("proj-1", "org-1"), None)
|
|
.await
|
|
.expect("issue token");
|
|
|
|
let cache = Arc::new(PolicyCache::default_config());
|
|
let evaluator = Arc::new(PolicyEvaluator::new(
|
|
binding_store.clone(),
|
|
role_store.clone(),
|
|
cache,
|
|
));
|
|
let gateway_auth = GatewayAuthServiceImpl::new(
|
|
token_service,
|
|
principal_store,
|
|
token_store,
|
|
evaluator,
|
|
);
|
|
|
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
|
|
.await
|
|
.expect("bind iam");
|
|
let addr = listener.local_addr().expect("iam addr");
|
|
tokio::spawn(async move {
|
|
Server::builder()
|
|
.add_service(GatewayAuthServiceServer::new(gateway_auth))
|
|
.serve_with_incoming(TcpListenerStream::new(listener))
|
|
.await
|
|
.expect("iam gateway serve");
|
|
});
|
|
|
|
wait_for_test_tcp(addr).await;
|
|
(addr, issued.token)
|
|
}
|
|
|
|
async fn start_credit_gateway(iam_addr: &SocketAddr) -> SocketAddr {
|
|
let storage = creditservice_api::InMemoryStorage::new();
|
|
let wallet = Wallet::new("proj-1".into(), "org-1".into(), 100);
|
|
storage
|
|
.create_wallet(wallet)
|
|
.await
|
|
.expect("wallet create");
|
|
|
|
let auth_service = Arc::new(
|
|
iam_service_auth::AuthService::new(&format!("http://{}", iam_addr))
|
|
.await
|
|
.expect("auth service"),
|
|
);
|
|
|
|
let credit_service = Arc::new(CreditServiceImpl::new(storage, auth_service));
|
|
let gateway_credit = GatewayCreditServiceImpl::new(credit_service);
|
|
|
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
|
|
.await
|
|
.expect("bind credit");
|
|
let addr = listener.local_addr().expect("credit addr");
|
|
tokio::spawn(async move {
|
|
Server::builder()
|
|
.add_service(GatewayCreditServiceServer::new(gateway_credit))
|
|
.serve_with_incoming(TcpListenerStream::new(listener))
|
|
.await
|
|
.expect("credit gateway serve");
|
|
});
|
|
|
|
wait_for_test_tcp(addr).await;
|
|
addr
|
|
}
|
|
|
|
/// Blank, root, slash-less, and trailing-slash prefixes all normalize to a
/// leading-slash form without a trailing slash.
#[test]
fn test_normalize_path_prefix() {
    let cases = [("", "/"), ("/", "/"), ("api", "/api"), ("/api/", "/api")];
    for &(input, expected) in cases.iter() {
        assert_eq!(normalize_path_prefix(input), expected);
    }
}
|
|
|
|
/// Stripping a prefix keeps the remainder rooted at `/`, and the root prefix
/// strips nothing.
#[test]
fn test_strip_prefix_path() {
    let cases = [
        ("/api", "/api", "/"),
        ("/api/v1", "/api", "/v1"),
        ("/api/v1/", "/api", "/v1/"),
        ("/v1", "/", "/v1"),
    ];
    for &(path, prefix, expected) in cases.iter() {
        assert_eq!(strip_prefix_path(path, prefix), expected);
    }
}
|
|
|
|
/// Joining handles a root base, a plain base, and a root request path.
#[test]
fn test_join_paths() {
    let cases = [
        ("/", "/v1", "/v1"),
        ("/api", "/v1", "/api/v1"),
        ("/api/", "/", "/api"),
    ];
    for &(base, path, expected) in cases.iter() {
        assert_eq!(join_paths(base, path), expected);
    }
}
|
|
|
|
/// With nested prefixes configured, the more specific (longer) one wins.
#[test]
fn test_match_route_longest_prefix() {
    let configs = vec![
        route_config("api", "/api", "http://example.com", false),
        route_config("api-v1", "/api/v1", "http://example.com", false),
    ];
    let routes = build_routes(configs).unwrap();

    let winner = match_route(&routes, "/api/v1/users").unwrap();
    assert_eq!(winner.config.name, "api-v1");
}
|
|
|
|
/// Prefixes only match on whole path segments: `/api2` is not caught by
/// `/api`, and `/apiary` matches neither route.
#[test]
fn test_match_route_segment_boundary() {
    let configs = vec![
        route_config("api", "/api", "http://example.com", false),
        route_config("api2", "/api2", "http://example.com", false),
    ];
    let routes = build_routes(configs).unwrap();

    let exact = match_route(&routes, "/api2").unwrap();
    assert_eq!(exact.config.name, "api2");

    let nested = match_route(&routes, "/api2/health").unwrap();
    assert_eq!(nested.config.name, "api2");

    assert!(match_route(&routes, "/apiary").is_none());
}
|
|
|
|
/// Without prefix stripping, the full request path is appended to the
/// upstream base path and the query string is carried over.
#[test]
fn test_build_upstream_url_preserves_query() {
    let config = route_config("api", "/api", "http://example.com/base", false);
    let routes = build_routes(vec![config]).unwrap();
    let route = routes.first().unwrap();

    let request_uri: Uri = "/api/v1/users?debug=true".parse().unwrap();
    let upstream_url = build_upstream_url(route, &request_uri).unwrap();

    assert_eq!(
        upstream_url.as_str(),
        "http://example.com/base/api/v1/users?debug=true"
    );
}
|
|
|
|
/// With prefix stripping enabled, the route prefix is removed before the
/// path is appended to the upstream base path.
#[test]
fn test_build_upstream_url_strip_prefix() {
    let config = route_config("api", "/api", "http://example.com/base", true);
    let routes = build_routes(vec![config]).unwrap();
    let route = routes.first().unwrap();

    let request_uri: Uri = "/api/v1".parse().unwrap();
    let upstream_url = build_upstream_url(route, &request_uri).unwrap();

    assert_eq!(upstream_url.as_str(), "http://example.com/base/v1");
}
|
|
|
|
/// Required auth: an allowing decision passes through, a denying decision
/// becomes `403 Forbidden`.
#[test]
fn test_apply_auth_mode_required() {
    let allowed = apply_auth_mode(PolicyMode::Required, false, Ok(auth_decision(true))).unwrap();
    assert!(allowed.subject.is_none());

    let denied =
        apply_auth_mode(PolicyMode::Required, false, Ok(auth_decision(false))).unwrap_err();
    assert_eq!(denied, StatusCode::FORBIDDEN);
}
|
|
|
|
/// Optional auth: neither a denying decision nor a provider error rejects
/// the request; both produce an outcome without a subject.
#[test]
fn test_apply_auth_mode_optional() {
    let after_denial =
        apply_auth_mode(PolicyMode::Optional, false, Ok(auth_decision(false))).unwrap();
    assert!(after_denial.subject.is_none());

    let after_error =
        apply_auth_mode(PolicyMode::Optional, false, Err(StatusCode::BAD_GATEWAY)).unwrap();
    assert!(after_error.subject.is_none());
}
|
|
|
|
/// Required credit: an allowing decision yields a reservation, a denying
/// decision becomes `402 Payment Required`.
#[test]
fn test_apply_credit_mode_required() {
    let reserved =
        apply_credit_mode(PolicyMode::Required, false, Ok(credit_decision(true))).unwrap();
    assert!(reserved.is_some());

    let denied =
        apply_credit_mode(PolicyMode::Required, false, Ok(credit_decision(false))).unwrap_err();
    assert_eq!(denied, StatusCode::PAYMENT_REQUIRED);
}
|
|
|
|
/// Optional credit: neither a denying decision nor a provider error rejects
/// the request; both yield no reservation.
#[test]
fn test_apply_credit_mode_optional() {
    let after_denial =
        apply_credit_mode(PolicyMode::Optional, false, Ok(credit_decision(false))).unwrap();
    assert!(after_denial.is_none());

    let after_error =
        apply_credit_mode(PolicyMode::Optional, false, Err(StatusCode::BAD_GATEWAY)).unwrap();
    assert!(after_error.is_none());
}
|
|
|
|
#[tokio::test]
|
|
async fn test_gateway_auth_and_credit_flow() {
|
|
let upstream_addr = start_upstream().await;
|
|
let (iam_addr, token) = start_iam_gateway().await;
|
|
let credit_addr = start_credit_gateway(&iam_addr).await;
|
|
|
|
let routes = build_routes(vec![RouteConfig {
|
|
name: "public".to_string(),
|
|
path_prefix: "/v1".to_string(),
|
|
upstream: format!("http://{}", upstream_addr),
|
|
strip_prefix: false,
|
|
timeout_ms: None,
|
|
auth: Some(RouteAuthConfig {
|
|
provider: "iam".to_string(),
|
|
mode: PolicyMode::Required,
|
|
fail_open: false,
|
|
}),
|
|
credit: Some(RouteCreditConfig {
|
|
provider: "credit".to_string(),
|
|
mode: PolicyMode::Required,
|
|
units: 1,
|
|
fail_open: false,
|
|
commit_on: CommitPolicy::Success,
|
|
allow_header_subject: false,
|
|
attributes: HashMap::new(),
|
|
}),
|
|
}])
|
|
.unwrap();
|
|
|
|
let auth_providers = build_auth_providers(vec![AuthProviderConfig {
|
|
name: "iam".to_string(),
|
|
provider_type: "grpc".to_string(),
|
|
endpoint: format!("http://{}", iam_addr),
|
|
timeout_ms: Some(1000),
|
|
tls: None,
|
|
}])
|
|
.await
|
|
.unwrap();
|
|
|
|
let credit_providers = build_credit_providers(vec![CreditProviderConfig {
|
|
name: "credit".to_string(),
|
|
provider_type: "grpc".to_string(),
|
|
endpoint: format!("http://{}", credit_addr),
|
|
timeout_ms: Some(1000),
|
|
tls: None,
|
|
}])
|
|
.await
|
|
.unwrap();
|
|
|
|
let state = Arc::new(ServerState {
|
|
routes,
|
|
client: Client::new(),
|
|
upstream_timeout: Duration::from_secs(5),
|
|
max_body_bytes: 1024 * 1024,
|
|
max_response_bytes: 1024 * 1024,
|
|
auth_providers,
|
|
credit_providers,
|
|
trust_forwarded_headers: false,
|
|
});
|
|
|
|
let deadline = tokio::time::Instant::now() + Duration::from_secs(10);
|
|
let mut response = None;
|
|
while tokio::time::Instant::now() < deadline {
|
|
let request = Request::builder()
|
|
.method("GET")
|
|
.uri("/v1/echo")
|
|
.header(axum::http::header::AUTHORIZATION, &token)
|
|
.body(Body::empty())
|
|
.expect("request build");
|
|
|
|
match proxy(
|
|
State(Arc::clone(&state)),
|
|
ConnectInfo("127.0.0.1:40000".parse().unwrap()),
|
|
request,
|
|
)
|
|
.await
|
|
{
|
|
Ok(ok) => {
|
|
response = Some(ok);
|
|
break;
|
|
}
|
|
Err(StatusCode::BAD_GATEWAY) => {
|
|
tokio::time::sleep(Duration::from_millis(25)).await;
|
|
}
|
|
Err(status) => panic!("unexpected proxy status: {}", status),
|
|
}
|
|
}
|
|
let response = response.expect("gateway auth+credit test timed out waiting for ready backends");
|
|
assert_eq!(response.status(), StatusCode::OK);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_proxy_forwards_client_auth_headers_when_route_has_no_auth() {
|
|
let upstream_addr = start_upstream().await;
|
|
let routes = build_routes(vec![route_config(
|
|
"passthrough",
|
|
"/v1",
|
|
&format!("http://{}", upstream_addr),
|
|
false,
|
|
)])
|
|
.unwrap();
|
|
|
|
let state = Arc::new(ServerState {
|
|
routes,
|
|
client: Client::new(),
|
|
upstream_timeout: Duration::from_secs(5),
|
|
max_body_bytes: 1024 * 1024,
|
|
max_response_bytes: 1024 * 1024,
|
|
auth_providers: HashMap::new(),
|
|
credit_providers: HashMap::new(),
|
|
trust_forwarded_headers: false,
|
|
});
|
|
|
|
let request = Request::builder()
|
|
.method("GET")
|
|
.uri("/v1/echo-auth")
|
|
.header(axum::http::header::AUTHORIZATION, "Bearer passthrough-token")
|
|
.header(PHOTON_AUTH_TOKEN_HEADER, "photon-token")
|
|
.body(Body::empty())
|
|
.expect("request build");
|
|
|
|
let response = proxy(
|
|
State(state),
|
|
ConnectInfo("127.0.0.1:40000".parse().unwrap()),
|
|
request,
|
|
)
|
|
.await
|
|
.unwrap();
|
|
assert_eq!(response.status(), StatusCode::OK);
|
|
|
|
let body = to_bytes(response.into_body(), 1024 * 1024).await.unwrap();
|
|
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
|
assert_eq!(json.get("authorization").and_then(|v| v.as_str()), Some("Bearer passthrough-token"));
|
|
assert_eq!(json.get("photon_token").and_then(|v| v.as_str()), Some("photon-token"));
|
|
}
|
|
|
|
/// A standard `Authorization: Bearer <token>` header yields the bare token.
#[test]
fn test_extract_auth_token_accepts_bearer_authorization() {
    let mut headers = HeaderMap::new();
    let bearer = axum::http::header::HeaderValue::from_static("Bearer abc123");
    headers.insert(axum::http::header::AUTHORIZATION, bearer);

    let token = extract_auth_token(&headers);
    assert_eq!(token.as_deref(), Some("abc123"));
}
|
|
|
|
/// A scheme-less `Authorization` value is accepted as a legacy raw token.
#[test]
fn test_extract_auth_token_accepts_legacy_raw_authorization() {
    let mut headers = HeaderMap::new();
    let raw = axum::http::header::HeaderValue::from_static("raw-token");
    headers.insert(axum::http::header::AUTHORIZATION, raw);

    let token = extract_auth_token(&headers);
    assert_eq!(token.as_deref(), Some("raw-token"));
}
|
|
|
|
/// When `Authorization` carries an unsupported scheme, the token is taken
/// from `PHOTON_AUTH_TOKEN_HEADER` instead.
#[test]
fn test_extract_auth_token_falls_back_to_photon_header() {
    let mut headers = HeaderMap::new();
    let basic = axum::http::header::HeaderValue::from_static("Basic abc");
    let photon = axum::http::header::HeaderValue::from_static("photon-token");
    headers.insert(axum::http::header::AUTHORIZATION, basic);
    headers.insert(PHOTON_AUTH_TOKEN_HEADER, photon);

    let token = extract_auth_token(&headers);
    assert_eq!(token.as_deref(), Some("photon-token"));
}
|
|
}
|