//! FlashDNS authoritative DNS server binary use flashdns_api::{RecordServiceServer, ZoneServiceServer}; use flashdns_server::{ config::{MetadataBackend, ServerConfig}, dns::DnsHandler, metadata::DnsMetadataStore, RecordServiceImpl, ZoneServiceImpl, }; use chainfire_client::Client as ChainFireClient; use iam_service_auth::AuthService; use metrics_exporter_prometheus::PrometheusBuilder; use std::sync::Arc; use tonic::transport::{Certificate, Identity, Server, ServerTlsConfig}; use tonic::{Request, Status}; use tonic_health::server::health_reporter; use tracing_subscriber::EnvFilter; use anyhow::Result; use clap::Parser; use std::path::PathBuf; use std::time::{SystemTime, UNIX_EPOCH}; use config::{Config as Cfg, Environment, File, FileFormat}; /// Command-line arguments for FlashDNS server. #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct CliArgs { /// Configuration file path #[arg(short, long, default_value = "flashdns.toml")] config: PathBuf, /// gRPC management API address (overrides config) #[arg(long)] grpc_addr: Option, /// DNS UDP address (overrides config) #[arg(long)] dns_addr: Option, /// ChainFire endpoint for cluster coordination (overrides config) #[arg(long, env = "FLASHDNS_CHAINFIRE_ENDPOINT")] chainfire_endpoint: Option, /// FlareDB endpoint for metadata and tenant data storage (overrides config) #[arg(long, env = "FLASHDNS_FLAREDB_ENDPOINT")] flaredb_endpoint: Option, /// Metadata backend (flaredb, postgres, sqlite) #[arg(long, env = "FLASHDNS_METADATA_BACKEND")] metadata_backend: Option, /// SQL database URL for metadata (required for postgres/sqlite backend) #[arg(long, env = "FLASHDNS_METADATA_DATABASE_URL")] metadata_database_url: Option, /// Run in single-node mode (required when metadata backend is SQLite) #[arg(long, env = "FLASHDNS_SINGLE_NODE")] single_node: bool, /// Log level (overrides config) #[arg(short, long)] log_level: Option, /// Metrics port for Prometheus scraping #[arg(long, default_value = "9097")] metrics_port: u16, } #[tokio::main] async fn main() -> Result<(), Box> { let cli_args = CliArgs::parse(); // Load configuration using config-rs let mut settings = Cfg::builder() // Layer 1: Application defaults. Serialize ServerConfig::default() into TOML. .add_source(File::from_str( toml::to_string(&ServerConfig::default())?.as_str(), FileFormat::Toml, )) // Layer 2: Environment variables (e.g., FLASHDNS_GRPC_ADDR, FLASHDNS_LOG_LEVEL) .add_source( Environment::with_prefix("FLASHDNS") .separator("__") // Use double underscore for nested fields ); // Layer 3: Configuration file (if specified) if cli_args.config.exists() { tracing::info!("Loading config from file: {}", cli_args.config.display()); settings = settings.add_source(File::from(cli_args.config.as_path())); } else { tracing::info!("Config file not found, using defaults and environment variables."); } let mut config: ServerConfig = settings .build()? .try_deserialize() .map_err(|e| anyhow::anyhow!("Failed to load configuration: {}", e))?; // Apply command line overrides (Layer 4: highest precedence) if let Some(grpc_addr_str) = cli_args.grpc_addr { config.grpc_addr = grpc_addr_str.parse()?; } if let Some(dns_addr_str) = cli_args.dns_addr { config.dns_addr = dns_addr_str.parse()?; } if let Some(chainfire_endpoint) = cli_args.chainfire_endpoint { config.chainfire_endpoint = Some(chainfire_endpoint); } if let Some(flaredb_endpoint) = cli_args.flaredb_endpoint { config.flaredb_endpoint = Some(flaredb_endpoint); } if let Some(metadata_backend) = cli_args.metadata_backend { config.metadata_backend = parse_metadata_backend(&metadata_backend)?; } if let Some(metadata_database_url) = cli_args.metadata_database_url { config.metadata_database_url = Some(metadata_database_url); } if cli_args.single_node { config.single_node = true; } if let Some(log_level) = cli_args.log_level { config.log_level = log_level; } // Initialize tracing tracing_subscriber::fmt() .with_env_filter( EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&config.log_level)), ) .init(); tracing::info!("Starting FlashDNS server"); tracing::info!(" gRPC: {}", config.grpc_addr); tracing::info!(" DNS UDP: {}", config.dns_addr); // Initialize Prometheus metrics exporter let metrics_addr = format!("0.0.0.0:{}", cli_args.metrics_port); let builder = PrometheusBuilder::new(); builder .with_http_listener(metrics_addr.parse::()?) .install() .expect("Failed to install Prometheus metrics exporter"); tracing::info!( "Prometheus metrics available at http://{}/metrics", metrics_addr ); if let Some(endpoint) = &config.chainfire_endpoint { tracing::info!(" Cluster coordination: ChainFire at {}", endpoint); let endpoint = endpoint.clone(); let addr = config.grpc_addr.to_string(); tokio::spawn(async move { if let Err(error) = register_chainfire_membership(&endpoint, "flashdns", addr).await { tracing::warn!(error = %error, "Failed to register ChainFire membership"); } }); } // Create metadata store from explicitly selected backend. let metadata = match config.metadata_backend { MetadataBackend::FlareDb => { if let Some(endpoint) = config.flaredb_endpoint.as_deref() { tracing::info!(" Metadata backend: FlareDB @ {}", endpoint); } else { tracing::info!(" Metadata backend: FlareDB"); } Arc::new( DnsMetadataStore::new_flaredb_with_pd( config.flaredb_endpoint.clone(), config.chainfire_endpoint.clone(), ) .await .map_err(|e| anyhow::anyhow!("Failed to initialize FlareDB metadata store: {}", e))?, ) } MetadataBackend::Postgres | MetadataBackend::Sqlite => { let database_url = config .metadata_database_url .as_deref() .ok_or_else(|| { anyhow::anyhow!( "metadata_database_url is required when metadata_backend={} (env: FLASHDNS_METADATA_DATABASE_URL)", metadata_backend_name(config.metadata_backend) ) })?; ensure_sql_backend_matches_url(config.metadata_backend, database_url)?; tracing::info!( " Metadata backend: {} @ {}", metadata_backend_name(config.metadata_backend), database_url ); Arc::new( DnsMetadataStore::new_sql(database_url, config.single_node) .await .map_err(|e| anyhow::anyhow!("Failed to initialize SQL metadata store: {}", e))?, ) } }; // Initialize IAM authentication service tracing::info!("Connecting to IAM server at {}", config.auth.iam_server_addr); let auth_service = AuthService::new(&config.auth.iam_server_addr) .await .map_err(|e| anyhow::anyhow!("Failed to connect to IAM server: {}", e))?; let auth_service = Arc::new(auth_service); // Dedicated runtime for auth interceptors to avoid blocking the main async runtime let auth_runtime = Arc::new(tokio::runtime::Runtime::new()?); let make_interceptor = |auth: Arc| { let rt = auth_runtime.clone(); move |mut req: Request<()>| -> Result, Status> { let auth = auth.clone(); tokio::task::block_in_place(|| { rt.block_on(async move { let tenant_context = auth.authenticate_request(&req).await?; req.extensions_mut().insert(tenant_context); Ok(req) }) }) } }; // Create gRPC services let zone_service = ZoneServiceImpl::new(metadata.clone(), auth_service.clone()); let record_service = RecordServiceImpl::new(metadata.clone(), auth_service.clone()); // Setup health service let (mut health_reporter, health_service) = health_reporter(); health_reporter .set_serving::>() .await; health_reporter .set_serving::>() .await; // Start DNS handler let dns_handler = DnsHandler::bind(config.dns_addr, metadata.clone()).await?; let dns_task = tokio::spawn(async move { dns_handler.run().await; }); // Configure TLS if enabled let mut server = Server::builder(); if let Some(tls_config) = &config.tls { tracing::info!("TLS enabled, loading certificates..."); let cert = tokio::fs::read(&tls_config.cert_file).await?; let key = tokio::fs::read(&tls_config.key_file).await?; let server_identity = Identity::from_pem(cert, key); let tls = if tls_config.require_client_cert { tracing::info!("mTLS enabled"); let ca_cert = tokio::fs::read( tls_config .ca_file .as_ref() .ok_or("ca_file required for mTLS")?, ) .await?; let ca = Certificate::from_pem(ca_cert); ServerTlsConfig::new() .identity(server_identity) .client_ca_root(ca) } else { ServerTlsConfig::new().identity(server_identity) }; server = server.tls_config(tls)?; } // Start gRPC server tracing::info!("gRPC server listening on {}", config.grpc_addr); let grpc_server = server .add_service(health_service) .add_service(tonic::codegen::InterceptedService::new( ZoneServiceServer::new(zone_service), make_interceptor(auth_service.clone()), )) .add_service(tonic::codegen::InterceptedService::new( RecordServiceServer::new(record_service), make_interceptor(auth_service.clone()), )) .serve(config.grpc_addr); // Run both servers tokio::select! { result = grpc_server => { if let Err(e) = result { tracing::error!("gRPC server error: {}", e); } } _ = dns_task => { tracing::error!("DNS handler unexpectedly terminated"); } } Ok(()) } fn parse_metadata_backend(value: &str) -> Result { match value.trim().to_ascii_lowercase().as_str() { "flaredb" => Ok(MetadataBackend::FlareDb), "postgres" => Ok(MetadataBackend::Postgres), "sqlite" => Ok(MetadataBackend::Sqlite), other => Err(anyhow::anyhow!( "invalid metadata backend '{}'; expected one of: flaredb, postgres, sqlite", other )), } } fn metadata_backend_name(backend: MetadataBackend) -> &'static str { match backend { MetadataBackend::FlareDb => "flaredb", MetadataBackend::Postgres => "postgres", MetadataBackend::Sqlite => "sqlite", } } fn ensure_sql_backend_matches_url(backend: MetadataBackend, database_url: &str) -> Result<()> { let normalized = database_url.trim().to_ascii_lowercase(); match backend { MetadataBackend::Postgres => { if normalized.starts_with("postgres://") || normalized.starts_with("postgresql://") { Ok(()) } else { Err(anyhow::anyhow!( "metadata_backend=postgres requires postgres:// or postgresql:// URL" )) } } MetadataBackend::Sqlite => { if normalized.starts_with("sqlite:") { Ok(()) } else { Err(anyhow::anyhow!( "metadata_backend=sqlite requires sqlite: URL" )) } } MetadataBackend::FlareDb => Ok(()), } } async fn register_chainfire_membership( endpoint: &str, service: &str, addr: String, ) -> Result<()> { let node_id = std::env::var("HOSTNAME").unwrap_or_else(|_| format!("{}-{}", service, std::process::id())); let ts = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap_or_default() .as_secs(); let key = format!("/cluster/{}/members/{}", service, node_id); let value = format!(r#"{{"addr":"{}","ts":{}}}"#, addr, ts); let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(120); let mut attempt = 0usize; let mut last_error = String::new(); loop { attempt += 1; match ChainFireClient::connect(endpoint).await { Ok(mut client) => match client.put_str(&key, &value).await { Ok(_) => return Ok(()), Err(error) => last_error = format!("put failed: {}", error), }, Err(error) => last_error = format!("connect failed: {}", error), } if tokio::time::Instant::now() >= deadline { break; } tracing::warn!( attempt, endpoint, service, error = %last_error, "retrying ChainFire membership registration" ); tokio::time::sleep(std::time::Duration::from_secs(2)).await; } anyhow::bail!( "failed to register ChainFire membership for {} via {} after {} attempts: {}", service, endpoint, attempt, last_error ) }