228 lines
7.5 KiB
Rust
228 lines
7.5 KiB
Rust
//! TLS Configuration and Certificate Management
|
|
//!
|
|
//! Provides rustls-based TLS termination with SNI support for L7 HTTPS listeners.
|
|
|
|
use rustls::crypto::ring::sign::any_supported_type;
|
|
use rustls::pki_types::CertificateDer;
|
|
use rustls::server::{ClientHello, ResolvesServerCert};
|
|
use rustls::sign::CertifiedKey;
|
|
use rustls::ServerConfig;
|
|
use std::collections::HashMap;
|
|
use std::io::Cursor;
|
|
use std::sync::Arc;
|
|
|
|
use fiberlb_types::{Certificate, CertificateId, LoadBalancerId, TlsVersion};
|
|
|
|
type Result<T> = std::result::Result<T, TlsError>;
|
|
|
|
#[derive(Debug, thiserror::Error)]
|
|
pub enum TlsError {
|
|
#[error("Invalid certificate PEM: {0}")]
|
|
InvalidCertificate(String),
|
|
#[error("Invalid private key PEM: {0}")]
|
|
InvalidPrivateKey(String),
|
|
#[error("No private key found in PEM")]
|
|
NoPrivateKey,
|
|
#[error("TLS configuration error: {0}")]
|
|
ConfigError(String),
|
|
#[error("Certificate not found: {0}")]
|
|
CertificateNotFound(String),
|
|
}
|
|
|
|
/// Build TLS server configuration from certificate and private key
|
|
pub fn build_tls_config(
|
|
cert_pem: &str,
|
|
key_pem: &str,
|
|
min_version: TlsVersion,
|
|
) -> Result<ServerConfig> {
|
|
// Parse certificate chain from PEM
|
|
let mut cert_reader = Cursor::new(cert_pem.as_bytes());
|
|
let certs: Vec<CertificateDer> = rustls_pemfile::certs(&mut cert_reader)
|
|
.collect::<std::result::Result<Vec<_>, _>>()
|
|
.map_err(|e| TlsError::InvalidCertificate(format!("Failed to parse certificates: {}", e)))?;
|
|
|
|
if certs.is_empty() {
|
|
return Err(TlsError::InvalidCertificate("No certificates found in PEM".to_string()));
|
|
}
|
|
|
|
// Parse private key from PEM
|
|
let mut key_reader = Cursor::new(key_pem.as_bytes());
|
|
let key = rustls_pemfile::private_key(&mut key_reader)
|
|
.map_err(|e| TlsError::InvalidPrivateKey(format!("Failed to parse private key: {}", e)))?
|
|
.ok_or(TlsError::NoPrivateKey)?;
|
|
|
|
let mut config = match min_version {
|
|
TlsVersion::Tls12 => ServerConfig::builder()
|
|
.with_no_client_auth()
|
|
.with_single_cert(certs, key.into())
|
|
.map_err(|e| TlsError::ConfigError(format!("Failed to build config: {}", e)))?,
|
|
TlsVersion::Tls13 => ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
|
|
.with_no_client_auth()
|
|
.with_single_cert(certs, key.into())
|
|
.map_err(|e| TlsError::ConfigError(format!("Failed to build config: {}", e)))?,
|
|
};
|
|
|
|
// Enable ALPN for HTTP/2 and HTTP/1.1
|
|
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
|
|
|
Ok(config)
|
|
}
|
|
|
|
pub fn build_certified_key(cert_pem: &str, key_pem: &str) -> Result<Arc<CertifiedKey>> {
|
|
let mut cert_reader = Cursor::new(cert_pem.as_bytes());
|
|
let certs: Vec<CertificateDer> = rustls_pemfile::certs(&mut cert_reader)
|
|
.collect::<std::result::Result<Vec<_>, _>>()
|
|
.map_err(|e| TlsError::InvalidCertificate(format!("Failed to parse certificates: {}", e)))?;
|
|
|
|
if certs.is_empty() {
|
|
return Err(TlsError::InvalidCertificate("No certificates found in PEM".to_string()));
|
|
}
|
|
|
|
let mut key_reader = Cursor::new(key_pem.as_bytes());
|
|
let key = rustls_pemfile::private_key(&mut key_reader)
|
|
.map_err(|e| TlsError::InvalidPrivateKey(format!("Failed to parse private key: {}", e)))?
|
|
.ok_or(TlsError::NoPrivateKey)?;
|
|
|
|
let signing_key = any_supported_type(&key)
|
|
.map_err(|e| TlsError::InvalidPrivateKey(format!("Unsupported key: {}", e)))?;
|
|
|
|
Ok(Arc::new(CertifiedKey::new(certs, signing_key)))
|
|
}
|
|
|
|
/// SNI-based certificate resolver for multiple domains
|
|
///
|
|
/// Allows a single listener to serve multiple domains with different certificates
|
|
/// based on the SNI (Server Name Indication) extension in the TLS handshake.
|
|
#[derive(Debug)]
|
|
pub struct SniCertResolver {
|
|
/// Map of SNI hostname -> CertifiedKey
|
|
certs: HashMap<String, Arc<CertifiedKey>>,
|
|
/// Default certificate when SNI doesn't match
|
|
default: Arc<CertifiedKey>,
|
|
}
|
|
|
|
impl SniCertResolver {
|
|
/// Create a new SNI resolver with a default certificate
|
|
pub fn new(default_cert: Arc<CertifiedKey>) -> Self {
|
|
Self {
|
|
certs: HashMap::new(),
|
|
default: default_cert,
|
|
}
|
|
}
|
|
|
|
/// Add a certificate for a specific SNI hostname
|
|
pub fn add_cert(&mut self, hostname: String, cert: Arc<CertifiedKey>) {
|
|
self.certs.insert(hostname, cert);
|
|
}
|
|
|
|
/// Get certificate for a hostname
|
|
pub fn get_cert(&self, hostname: &str) -> Arc<CertifiedKey> {
|
|
self.certs
|
|
.get(hostname)
|
|
.cloned()
|
|
.unwrap_or_else(|| self.default.clone())
|
|
}
|
|
}
|
|
|
|
impl ResolvesServerCert for SniCertResolver {
|
|
fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
|
|
let sni = client_hello.server_name()?;
|
|
Some(self.get_cert(sni))
|
|
}
|
|
}
|
|
|
|
/// Certificate store for managing TLS certificates
|
|
pub struct CertificateStore {
|
|
certificates: HashMap<CertificateId, Certificate>,
|
|
}
|
|
|
|
impl CertificateStore {
|
|
/// Create a new empty certificate store
|
|
pub fn new() -> Self {
|
|
Self {
|
|
certificates: HashMap::new(),
|
|
}
|
|
}
|
|
|
|
/// Add a certificate to the store
|
|
pub fn add(&mut self, cert: Certificate) {
|
|
self.certificates.insert(cert.id, cert);
|
|
}
|
|
|
|
/// Get a certificate by ID
|
|
pub fn get(&self, id: &CertificateId) -> Option<&Certificate> {
|
|
self.certificates.get(id)
|
|
}
|
|
|
|
/// List all certificates for a load balancer
|
|
pub fn list_for_lb(&self, lb_id: &LoadBalancerId) -> Vec<&Certificate> {
|
|
self.certificates
|
|
.values()
|
|
.filter(|cert| cert.loadbalancer_id == *lb_id)
|
|
.collect()
|
|
}
|
|
|
|
/// Remove a certificate
|
|
pub fn remove(&mut self, id: &CertificateId) -> Option<Certificate> {
|
|
self.certificates.remove(id)
|
|
}
|
|
|
|
/// Build TLS configuration from a certificate ID
|
|
pub fn build_config(
|
|
&self,
|
|
cert_id: &CertificateId,
|
|
min_version: TlsVersion,
|
|
) -> Result<ServerConfig> {
|
|
let cert = self
|
|
.get(cert_id)
|
|
.ok_or_else(|| TlsError::CertificateNotFound(cert_id.to_string()))?;
|
|
|
|
build_tls_config(&cert.certificate, &cert.private_key, min_version)
|
|
}
|
|
|
|
pub fn build_certified_key(&self, cert_id: &CertificateId) -> Result<Arc<CertifiedKey>> {
|
|
let cert = self
|
|
.get(cert_id)
|
|
.ok_or_else(|| TlsError::CertificateNotFound(cert_id.to_string()))?;
|
|
|
|
build_certified_key(&cert.certificate, &cert.private_key)
|
|
}
|
|
}
|
|
|
|
impl Default for CertificateStore {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Certificate record whose PEM bodies are placeholders (`...`), not real DER data.
    fn placeholder_cert(lb_id: LoadBalancerId) -> Certificate {
        Certificate {
            id: CertificateId::new(),
            loadbalancer_id: lb_id,
            name: "test-cert".to_string(),
            certificate: "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----".to_string(),
            private_key: "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----".to_string(),
            cert_type: fiberlb_types::CertificateType::Server,
            expires_at: 0,
            created_at: 0,
            updated_at: 0,
        }
    }

    #[test]
    fn test_certificate_store() {
        let mut store = CertificateStore::new();

        let lb_id = LoadBalancerId::new();
        let cert = placeholder_cert(lb_id);

        store.add(cert.clone());

        assert!(store.get(&cert.id).is_some());
        assert_eq!(store.list_for_lb(&lb_id).len(), 1);

        let removed = store.remove(&cert.id);
        assert!(removed.is_some());
        assert!(store.get(&cert.id).is_none());
        assert!(store.list_for_lb(&lb_id).is_empty());
    }

    #[test]
    fn missing_certificate_is_reported() {
        let store = CertificateStore::new();

        assert!(matches!(
            store.build_config(&CertificateId::new(), TlsVersion::Tls12),
            Err(TlsError::CertificateNotFound(_))
        ));
        assert!(matches!(
            store.build_certified_key(&CertificateId::new()),
            Err(TlsError::CertificateNotFound(_))
        ));
    }

    #[test]
    fn empty_pem_is_rejected() {
        assert!(matches!(
            build_tls_config("", "", TlsVersion::Tls12),
            Err(TlsError::InvalidCertificate(_))
        ));
        assert!(matches!(
            build_certified_key("", ""),
            Err(TlsError::InvalidCertificate(_))
        ));
    }

    #[test]
    fn placeholder_pem_is_rejected() {
        let mut store = CertificateStore::new();
        let cert = placeholder_cert(LoadBalancerId::new());
        store.add(cert.clone());

        assert!(store.build_config(&cert.id, TlsVersion::Tls13).is_err());
        assert!(store.build_certified_key(&cert.id).is_err());
    }
}
|