//! TLS Configuration and Certificate Management //! //! Provides rustls-based TLS termination with SNI support for L7 HTTPS listeners. use rustls::crypto::ring::sign::any_supported_type; use rustls::pki_types::CertificateDer; use rustls::server::{ClientHello, ResolvesServerCert}; use rustls::sign::CertifiedKey; use rustls::ServerConfig; use std::collections::HashMap; use std::io::Cursor; use std::sync::Arc; use fiberlb_types::{Certificate, CertificateId, LoadBalancerId, TlsVersion}; type Result = std::result::Result; #[derive(Debug, thiserror::Error)] pub enum TlsError { #[error("Invalid certificate PEM: {0}")] InvalidCertificate(String), #[error("Invalid private key PEM: {0}")] InvalidPrivateKey(String), #[error("No private key found in PEM")] NoPrivateKey, #[error("TLS configuration error: {0}")] ConfigError(String), #[error("Certificate not found: {0}")] CertificateNotFound(String), } /// Build TLS server configuration from certificate and private key pub fn build_tls_config( cert_pem: &str, key_pem: &str, min_version: TlsVersion, ) -> Result { // Parse certificate chain from PEM let mut cert_reader = Cursor::new(cert_pem.as_bytes()); let certs: Vec = rustls_pemfile::certs(&mut cert_reader) .collect::, _>>() .map_err(|e| TlsError::InvalidCertificate(format!("Failed to parse certificates: {}", e)))?; if certs.is_empty() { return Err(TlsError::InvalidCertificate("No certificates found in PEM".to_string())); } // Parse private key from PEM let mut key_reader = Cursor::new(key_pem.as_bytes()); let key = rustls_pemfile::private_key(&mut key_reader) .map_err(|e| TlsError::InvalidPrivateKey(format!("Failed to parse private key: {}", e)))? .ok_or(TlsError::NoPrivateKey)?; let mut config = match min_version { TlsVersion::Tls12 => ServerConfig::builder() .with_no_client_auth() .with_single_cert(certs, key.into()) .map_err(|e| TlsError::ConfigError(format!("Failed to build config: {}", e)))?, TlsVersion::Tls13 => ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) .with_no_client_auth() .with_single_cert(certs, key.into()) .map_err(|e| TlsError::ConfigError(format!("Failed to build config: {}", e)))?, }; // Enable ALPN for HTTP/2 and HTTP/1.1 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; Ok(config) } pub fn build_certified_key(cert_pem: &str, key_pem: &str) -> Result> { let mut cert_reader = Cursor::new(cert_pem.as_bytes()); let certs: Vec = rustls_pemfile::certs(&mut cert_reader) .collect::, _>>() .map_err(|e| TlsError::InvalidCertificate(format!("Failed to parse certificates: {}", e)))?; if certs.is_empty() { return Err(TlsError::InvalidCertificate("No certificates found in PEM".to_string())); } let mut key_reader = Cursor::new(key_pem.as_bytes()); let key = rustls_pemfile::private_key(&mut key_reader) .map_err(|e| TlsError::InvalidPrivateKey(format!("Failed to parse private key: {}", e)))? .ok_or(TlsError::NoPrivateKey)?; let signing_key = any_supported_type(&key) .map_err(|e| TlsError::InvalidPrivateKey(format!("Unsupported key: {}", e)))?; Ok(Arc::new(CertifiedKey::new(certs, signing_key))) } /// SNI-based certificate resolver for multiple domains /// /// Allows a single listener to serve multiple domains with different certificates /// based on the SNI (Server Name Indication) extension in the TLS handshake. #[derive(Debug)] pub struct SniCertResolver { /// Map of SNI hostname -> CertifiedKey certs: HashMap>, /// Default certificate when SNI doesn't match default: Arc, } impl SniCertResolver { /// Create a new SNI resolver with a default certificate pub fn new(default_cert: Arc) -> Self { Self { certs: HashMap::new(), default: default_cert, } } /// Add a certificate for a specific SNI hostname pub fn add_cert(&mut self, hostname: String, cert: Arc) { self.certs.insert(hostname, cert); } /// Get certificate for a hostname pub fn get_cert(&self, hostname: &str) -> Arc { self.certs .get(hostname) .cloned() .unwrap_or_else(|| self.default.clone()) } } impl ResolvesServerCert for SniCertResolver { fn resolve(&self, client_hello: ClientHello) -> Option> { let sni = client_hello.server_name()?; Some(self.get_cert(sni)) } } /// Certificate store for managing TLS certificates pub struct CertificateStore { certificates: HashMap, } impl CertificateStore { /// Create a new empty certificate store pub fn new() -> Self { Self { certificates: HashMap::new(), } } /// Add a certificate to the store pub fn add(&mut self, cert: Certificate) { self.certificates.insert(cert.id, cert); } /// Get a certificate by ID pub fn get(&self, id: &CertificateId) -> Option<&Certificate> { self.certificates.get(id) } /// List all certificates for a load balancer pub fn list_for_lb(&self, lb_id: &LoadBalancerId) -> Vec<&Certificate> { self.certificates .values() .filter(|cert| cert.loadbalancer_id == *lb_id) .collect() } /// Remove a certificate pub fn remove(&mut self, id: &CertificateId) -> Option { self.certificates.remove(id) } /// Build TLS configuration from a certificate ID pub fn build_config( &self, cert_id: &CertificateId, min_version: TlsVersion, ) -> Result { let cert = self .get(cert_id) .ok_or_else(|| TlsError::CertificateNotFound(cert_id.to_string()))?; build_tls_config(&cert.certificate, &cert.private_key, min_version) } pub fn build_certified_key(&self, cert_id: &CertificateId) -> Result> { let cert = self .get(cert_id) .ok_or_else(|| TlsError::CertificateNotFound(cert_id.to_string()))?; build_certified_key(&cert.certificate, &cert.private_key) } } impl Default for CertificateStore { fn default() -> Self { Self::new() } } #[cfg(test)] mod tests { use super::*; #[test] fn test_certificate_store() { let mut store = CertificateStore::new(); let lb_id = LoadBalancerId::new(); let cert = Certificate { id: CertificateId::new(), loadbalancer_id: lb_id, name: "test-cert".to_string(), certificate: "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----".to_string(), private_key: "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----".to_string(), cert_type: fiberlb_types::CertificateType::Server, expires_at: 0, created_at: 0, updated_at: 0, }; store.add(cert.clone()); assert!(store.get(&cert.id).is_some()); assert_eq!(store.list_for_lb(&lb_id).len(), 1); let removed = store.remove(&cert.id); assert!(removed.is_some()); assert!(store.get(&cert.id).is_none()); } }