photoncloud-monorepo/fiberlb/crates/fiberlb-server/src/tls.rs

228 lines
7.5 KiB
Rust

//! TLS Configuration and Certificate Management
//!
//! Provides rustls-based TLS termination with SNI support for L7 HTTPS listeners.
use rustls::crypto::ring::sign::any_supported_type;
use rustls::pki_types::CertificateDer;
use rustls::server::{ClientHello, ResolvesServerCert};
use rustls::sign::CertifiedKey;
use rustls::ServerConfig;
use std::collections::HashMap;
use std::io::Cursor;
use std::sync::Arc;
use fiberlb_types::{Certificate, CertificateId, LoadBalancerId, TlsVersion};
/// Module-local shorthand: every fallible function in this file reports a [`TlsError`].
type Result<T> = core::result::Result<T, TlsError>;

/// Errors raised while decoding PEM material or assembling a rustls configuration.
#[derive(Debug, thiserror::Error)]
pub enum TlsError {
    /// The certificate PEM could not be decoded, or it held no certificates.
    #[error("Invalid certificate PEM: {0}")]
    InvalidCertificate(String),

    /// The private-key PEM could not be decoded, or the key could not be turned
    /// into a signing key.
    #[error("Invalid private key PEM: {0}")]
    InvalidPrivateKey(String),

    /// The private-key PEM was readable but contained no private-key section.
    #[error("No private key found in PEM")]
    NoPrivateKey,

    /// rustls refused to build a server configuration from the supplied material.
    #[error("TLS configuration error: {0}")]
    ConfigError(String),

    /// No certificate with the requested ID is registered.
    #[error("Certificate not found: {0}")]
    CertificateNotFound(String),
}
/// Build TLS server configuration from certificate and private key
pub fn build_tls_config(
cert_pem: &str,
key_pem: &str,
min_version: TlsVersion,
) -> Result<ServerConfig> {
// Parse certificate chain from PEM
let mut cert_reader = Cursor::new(cert_pem.as_bytes());
let certs: Vec<CertificateDer> = rustls_pemfile::certs(&mut cert_reader)
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| TlsError::InvalidCertificate(format!("Failed to parse certificates: {}", e)))?;
if certs.is_empty() {
return Err(TlsError::InvalidCertificate("No certificates found in PEM".to_string()));
}
// Parse private key from PEM
let mut key_reader = Cursor::new(key_pem.as_bytes());
let key = rustls_pemfile::private_key(&mut key_reader)
.map_err(|e| TlsError::InvalidPrivateKey(format!("Failed to parse private key: {}", e)))?
.ok_or(TlsError::NoPrivateKey)?;
let mut config = match min_version {
TlsVersion::Tls12 => ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key.into())
.map_err(|e| TlsError::ConfigError(format!("Failed to build config: {}", e)))?,
TlsVersion::Tls13 => ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
.with_no_client_auth()
.with_single_cert(certs, key.into())
.map_err(|e| TlsError::ConfigError(format!("Failed to build config: {}", e)))?,
};
// Enable ALPN for HTTP/2 and HTTP/1.1
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
Ok(config)
}
pub fn build_certified_key(cert_pem: &str, key_pem: &str) -> Result<Arc<CertifiedKey>> {
let mut cert_reader = Cursor::new(cert_pem.as_bytes());
let certs: Vec<CertificateDer> = rustls_pemfile::certs(&mut cert_reader)
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| TlsError::InvalidCertificate(format!("Failed to parse certificates: {}", e)))?;
if certs.is_empty() {
return Err(TlsError::InvalidCertificate("No certificates found in PEM".to_string()));
}
let mut key_reader = Cursor::new(key_pem.as_bytes());
let key = rustls_pemfile::private_key(&mut key_reader)
.map_err(|e| TlsError::InvalidPrivateKey(format!("Failed to parse private key: {}", e)))?
.ok_or(TlsError::NoPrivateKey)?;
let signing_key = any_supported_type(&key)
.map_err(|e| TlsError::InvalidPrivateKey(format!("Unsupported key: {}", e)))?;
Ok(Arc::new(CertifiedKey::new(certs, signing_key)))
}
/// SNI-based certificate resolver for multiple domains.
///
/// Allows a single listener to serve multiple domains with different certificates,
/// chosen by the SNI (Server Name Indication) extension in the TLS ClientHello.
/// Hostnames are matched ASCII-case-insensitively, because DNS names are
/// case-insensitive (RFC 4343). A ClientHello whose SNI is missing or does not match
/// any registered hostname is served the default certificate.
#[derive(Debug)]
pub struct SniCertResolver {
    /// Map of ASCII-lowercased SNI hostname -> CertifiedKey
    certs: HashMap<String, Arc<CertifiedKey>>,
    /// Certificate served when the client sends no SNI or an unregistered hostname
    default: Arc<CertifiedKey>,
}

impl SniCertResolver {
    /// Create a new SNI resolver with a default certificate and no per-host entries.
    pub fn new(default_cert: Arc<CertifiedKey>) -> Self {
        Self {
            certs: HashMap::new(),
            default: default_cert,
        }
    }

    /// Add a certificate for a specific SNI hostname.
    ///
    /// The hostname is stored ASCII-lowercased so that lookups are case-insensitive.
    /// Registering the same hostname again replaces the earlier certificate.
    pub fn add_cert(&mut self, mut hostname: String, cert: Arc<CertifiedKey>) {
        hostname.make_ascii_lowercase();
        self.certs.insert(hostname, cert);
    }

    /// Get the certificate for a hostname, falling back to the default certificate.
    ///
    /// The lookup ignores ASCII case.
    pub fn get_cert(&self, hostname: &str) -> Arc<CertifiedKey> {
        self.certs
            .get(&hostname.to_ascii_lowercase())
            .cloned()
            .unwrap_or_else(|| Arc::clone(&self.default))
    }
}

impl ResolvesServerCert for SniCertResolver {
    fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
        // Clients that omit SNI (e.g. when connecting by IP address) still need a
        // certificate: returning `None` here would abort the handshake even though a
        // default certificate is configured.
        let cert = match client_hello.server_name() {
            Some(sni) => self.get_cert(sni),
            None => Arc::clone(&self.default),
        };
        Some(cert)
    }
}
/// Certificate store for managing TLS certificates
pub struct CertificateStore {
certificates: HashMap<CertificateId, Certificate>,
}
impl CertificateStore {
/// Create a new empty certificate store
pub fn new() -> Self {
Self {
certificates: HashMap::new(),
}
}
/// Add a certificate to the store
pub fn add(&mut self, cert: Certificate) {
self.certificates.insert(cert.id, cert);
}
/// Get a certificate by ID
pub fn get(&self, id: &CertificateId) -> Option<&Certificate> {
self.certificates.get(id)
}
/// List all certificates for a load balancer
pub fn list_for_lb(&self, lb_id: &LoadBalancerId) -> Vec<&Certificate> {
self.certificates
.values()
.filter(|cert| cert.loadbalancer_id == *lb_id)
.collect()
}
/// Remove a certificate
pub fn remove(&mut self, id: &CertificateId) -> Option<Certificate> {
self.certificates.remove(id)
}
/// Build TLS configuration from a certificate ID
pub fn build_config(
&self,
cert_id: &CertificateId,
min_version: TlsVersion,
) -> Result<ServerConfig> {
let cert = self
.get(cert_id)
.ok_or_else(|| TlsError::CertificateNotFound(cert_id.to_string()))?;
build_tls_config(&cert.certificate, &cert.private_key, min_version)
}
pub fn build_certified_key(&self, cert_id: &CertificateId) -> Result<Arc<CertifiedKey>> {
let cert = self
.get(cert_id)
.ok_or_else(|| TlsError::CertificateNotFound(cert_id.to_string()))?;
build_certified_key(&cert.certificate, &cert.private_key)
}
}
impl Default for CertificateStore {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Certificate record owned by `lb_id` whose PEM bodies are placeholders
    /// (`...` is not valid base64), not real key material.
    fn placeholder_cert(lb_id: LoadBalancerId) -> Certificate {
        Certificate {
            id: CertificateId::new(),
            loadbalancer_id: lb_id,
            name: "test-cert".to_string(),
            certificate: "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----".to_string(),
            private_key: "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----".to_string(),
            cert_type: fiberlb_types::CertificateType::Server,
            expires_at: 0,
            created_at: 0,
            updated_at: 0,
        }
    }

    #[test]
    fn test_certificate_store() {
        let mut store = CertificateStore::new();
        let lb_id = LoadBalancerId::new();
        let cert = placeholder_cert(lb_id);
        store.add(cert.clone());
        assert!(store.get(&cert.id).is_some());
        assert_eq!(store.list_for_lb(&lb_id).len(), 1);
        let removed = store.remove(&cert.id);
        assert!(removed.is_some());
        assert!(store.get(&cert.id).is_none());
        assert!(store.list_for_lb(&lb_id).is_empty());
    }

    #[test]
    fn test_build_from_unknown_id_reports_not_found() {
        let store = CertificateStore::new();
        let missing = CertificateId::new();
        assert!(matches!(
            store.build_config(&missing, TlsVersion::Tls12),
            Err(TlsError::CertificateNotFound(_))
        ));
        assert!(matches!(
            store.build_certified_key(&missing),
            Err(TlsError::CertificateNotFound(_))
        ));
    }

    #[test]
    fn test_build_from_placeholder_pem_is_rejected() {
        let mut store = CertificateStore::new();
        let cert = placeholder_cert(LoadBalancerId::new());
        let id = cert.id;
        store.add(cert);
        // The record exists, so failure must come from PEM validation, not lookup.
        assert!(store.build_config(&id, TlsVersion::Tls12).is_err());
        assert!(store.build_config(&id, TlsVersion::Tls13).is_err());
        assert!(store.build_certified_key(&id).is_err());
    }
}