//! FlareDB storage implementation for CreditService use async_trait::async_trait; use creditservice_types::{Error, Quota, Reservation, ResourceType, Result, Transaction, Wallet}; use flaredb_client::RdbClient; use serde::{Deserialize, Serialize}; use std::sync::Arc; use tokio::sync::Mutex; use tracing::debug; use super::CreditStorage; /// FlareDB storage implementation for CreditService data pub struct FlareDbStorage { client: Arc>, } impl FlareDbStorage { const CAS_RETRY_LIMIT: usize = 8; /// Create a new FlareDB storage pub async fn new(flaredb_endpoint: &str) -> Result> { Self::new_with_pd(flaredb_endpoint, None).await } /// Create a new FlareDB storage with an explicit PD address. pub async fn new_with_pd( flaredb_endpoint: &str, pd_endpoint: Option<&str>, ) -> Result> { let endpoint = normalize_flaredb_endpoint(flaredb_endpoint); let pd_endpoint = pd_endpoint .map(normalize_flaredb_endpoint) .or_else(|| std::env::var("CREDITSERVICE_CHAINFIRE_ENDPOINT").ok()) .map(|value| normalize_flaredb_endpoint(&value)) .unwrap_or_else(|| endpoint.clone()); debug!(endpoint = %endpoint, "Connecting to FlareDB"); let client = RdbClient::connect_with_pd_namespace(endpoint, pd_endpoint, "creditservice") .await .map_err(|e| Error::Storage(format!("Failed to connect to FlareDB: {}", e)))?; Ok(Arc::new(Self { client: Arc::new(Mutex::new(client)), })) } fn wallet_key(project_id: &str) -> String { format!("/creditservice/wallets/{}", project_id) } fn transaction_key(project_id: &str, transaction_id: &str, timestamp_nanos: u64) -> String { format!( "/creditservice/transactions/{}/{}_{}", project_id, timestamp_nanos, transaction_id ) } fn reservation_key(id: &str) -> String { format!("/creditservice/reservations/{}", id) } fn quota_key(project_id: &str, resource_type: ResourceType) -> String { format!("/creditservice/quotas/{}/{}", project_id, resource_type.as_str()) } fn transactions_prefix(project_id: &str) -> String { format!("/creditservice/transactions/{}/", project_id) } fn quotas_prefix(project_id: &str) -> String { format!("/creditservice/quotas/{}/", project_id) } fn reservations_prefix() -> String { "/creditservice/reservations/".to_string() } fn serialize(value: &T) -> Result> { serde_json::to_vec(value) .map_err(|e| Error::Storage(format!("Failed to serialize data: {}", e))) } fn deserialize Deserialize<'de>>(bytes: &[u8]) -> Result { serde_json::from_slice(bytes) .map_err(|e| Error::Storage(format!("Failed to deserialize data: {}", e))) } async fn scan_prefix_values(&self, prefix: &str) -> Result>> { let mut client = self.client.lock().await; let mut start_key = prefix.as_bytes().to_vec(); let end_key = prefix_end_key(prefix.as_bytes()); let mut values = Vec::new(); loop { let (entries, next_key) = client .cas_scan(start_key.clone(), end_key.clone(), 1000) .await .map_err(|e| Error::Storage(e.to_string()))?; values.extend(entries.into_iter().map(|(_, value, _)| value)); if let Some(next) = next_key { start_key = next; } else { break; } } Ok(values) } async fn get_value_with_version(&self, key: &str) -> Result)>> { let mut client = self.client.lock().await; client .cas_get(key.as_bytes().to_vec()) .await .map_err(|e| Error::Storage(e.to_string())) } async fn put_value(&self, key: &str, value: Vec) -> Result<()> { for _ in 0..Self::CAS_RETRY_LIMIT { let expected_version = self .get_value_with_version(key) .await? .map(|(version, _)| version) .unwrap_or(0); let mut client = self.client.lock().await; let (success, _current, _new) = client .cas(key.as_bytes().to_vec(), value.clone(), expected_version) .await .map_err(|e| Error::Storage(e.to_string()))?; if success { return Ok(()); } } Err(Error::Storage(format!( "CAS write retry budget exhausted for key {}", key ))) } async fn delete_value(&self, key: &str) -> Result { for _ in 0..Self::CAS_RETRY_LIMIT { let Some((version, _)) = self.get_value_with_version(key).await? else { return Ok(false); }; let mut client = self.client.lock().await; let (success, _current, existed) = client .cas_delete(key.as_bytes().to_vec(), version) .await .map_err(|e| Error::Storage(e.to_string()))?; if success { return Ok(existed); } } Err(Error::Storage(format!( "CAS delete retry budget exhausted for key {}", key ))) } } #[async_trait] impl CreditStorage for FlareDbStorage { async fn get_wallet(&self, project_id: &str) -> Result> { let key = Self::wallet_key(project_id); self.get_value_with_version(&key) .await? .map(|(_, value)| Self::deserialize(value.as_slice())) .transpose() } async fn create_wallet(&self, wallet: Wallet) -> Result { let key = Self::wallet_key(&wallet.project_id); let serialized_wallet = Self::serialize(&wallet)?; let mut client = self.client.lock().await; let (success, _current, _new) = client .cas(key.as_bytes().to_vec(), serialized_wallet, 0) .await .map_err(|e| Error::Storage(e.to_string()))?; if success { Ok(wallet) } else { Err(Error::WalletAlreadyExists(wallet.project_id)) } } async fn update_wallet(&self, wallet: Wallet) -> Result { let key = Self::wallet_key(&wallet.project_id); let serialized_wallet = Self::serialize(&wallet)?; self.put_value(&key, serialized_wallet).await?; Ok(wallet) } async fn delete_wallet(&self, project_id: &str) -> Result { let key = Self::wallet_key(project_id); self.delete_value(&key).await } async fn add_transaction(&self, transaction: Transaction) -> Result { let key = Self::transaction_key( &transaction.project_id, &transaction.id, transaction.created_at.timestamp_nanos() as u64, ); let serialized_txn = Self::serialize(&transaction)?; self.put_value(&key, serialized_txn).await?; Ok(transaction) } async fn get_transactions( &self, project_id: &str, limit: usize, offset: usize, ) -> Result> { let prefix = Self::transactions_prefix(project_id); let mut transactions: Vec = self .scan_prefix_values(&prefix) .await? .into_iter() .filter_map(|v| Self::deserialize(v.as_slice()).ok()) .collect(); transactions.sort_by(|a, b| b.created_at.cmp(&a.created_at)); Ok(transactions.into_iter().skip(offset).take(limit).collect()) } async fn get_reservation(&self, id: &str) -> Result> { let key = Self::reservation_key(id); self.get_value_with_version(&key) .await? .map(|(_, value)| Self::deserialize(value.as_slice())) .transpose() } async fn create_reservation(&self, reservation: Reservation) -> Result { let key = Self::reservation_key(&reservation.id); let serialized_reservation = Self::serialize(&reservation)?; self.put_value(&key, serialized_reservation).await?; Ok(reservation) } async fn update_reservation(&self, reservation: Reservation) -> Result { let key = Self::reservation_key(&reservation.id); let serialized_reservation = Self::serialize(&reservation)?; self.put_value(&key, serialized_reservation).await?; Ok(reservation) } async fn delete_reservation(&self, id: &str) -> Result { let key = Self::reservation_key(id); self.delete_value(&key).await } async fn get_pending_reservations(&self, project_id: &str) -> Result> { let prefix = Self::reservations_prefix(); let reservations: Vec = self .scan_prefix_values(&prefix) .await? .into_iter() .filter_map(|v| Self::deserialize(v.as_slice()).ok()) .filter(|r: &Reservation| { r.status == creditservice_types::ReservationStatus::Pending && r.project_id == project_id }) .collect(); Ok(reservations) } async fn get_quota(&self, project_id: &str, resource_type: ResourceType) -> Result> { let key = Self::quota_key(project_id, resource_type); self.get_value_with_version(&key) .await? .map(|(_, value)| Self::deserialize(value.as_slice())) .transpose() } async fn set_quota(&self, quota: Quota) -> Result { let key = Self::quota_key("a.project_id, quota.resource_type); let serialized_quota = Self::serialize("a)?; self.put_value(&key, serialized_quota).await?; Ok(quota) } async fn list_quotas(&self, project_id: &str) -> Result> { let prefix = Self::quotas_prefix(project_id); let quotas: Vec = self .scan_prefix_values(&prefix) .await? .into_iter() .filter_map(|v| Self::deserialize(v.as_slice()).ok()) .collect(); Ok(quotas) } } fn prefix_end_key(prefix: &[u8]) -> Vec { let mut end_key = prefix.to_vec(); if let Some(last) = end_key.last_mut() { if *last == 0xff { end_key.push(0x00); } else { *last += 1; } } else { end_key.push(0xff); } end_key } fn normalize_flaredb_endpoint(endpoint: &str) -> String { endpoint .trim() .trim_start_matches("http://") .trim_start_matches("https://") .to_string() }