//! Backend abstraction for IAM storage
//!
//! Provides a unified interface for IAM data storage across
//! ChainFire/FlareDB/SQL/in-memory backends.

use async_trait::async_trait;
use bytes::Bytes;
use serde::{de::DeserializeOwned, Serialize};
use sqlx::pool::PoolOptions;
use sqlx::{Pool, Postgres, Sqlite};
use tokio::sync::Mutex;
use tokio::time::{timeout, Duration};
use tonic::Status;

use iam_types::{Error, Result, StorageError};

use chainfire_client::{
    CasOutcome, Client as ChainfireClient, ClientError as ChainfireClientError,
};
use flaredb_client::RdbClient;

const STORAGE_RPC_TIMEOUT: Duration = Duration::from_secs(5);

/// Key-value pair with version
#[derive(Debug, Clone)]
pub struct KvPair {
    pub key: Bytes,
    pub value: Bytes,
    pub version: u64,
}

/// Result of a CAS (Compare-And-Swap) operation
#[derive(Debug, Clone)]
pub enum CasResult {
    /// CAS succeeded, returning the new version
    Success(u64),
    /// CAS failed due to version mismatch
    Conflict { expected: u64, actual: u64 },
    /// Key not found (when expected version > 0)
    NotFound,
}

/// Backend trait for storage operations
#[async_trait]
pub trait StorageBackend: Send + Sync {
    /// Get a value by key
    async fn get(&self, key: &[u8]) -> Result<Option<(Bytes, u64)>>;

    /// Put a value (unconditional write)
    async fn put(&self, key: &[u8], value: &[u8]) -> Result<u64>;

    /// Compare-and-swap write
    /// - If expected_version is 0, only succeeds if key doesn't exist
    /// - Otherwise, only succeeds if current version matches expected_version
    async fn cas(&self, key: &[u8], expected_version: u64, value: &[u8]) -> Result<CasResult>;

    /// Delete a key
    async fn delete(&self, key: &[u8]) -> Result<bool>;

    /// Scan keys with a prefix
    async fn scan_prefix(&self, prefix: &[u8], limit: u32) -> Result<Vec<KvPair>>;

    /// Scan keys in a range [start, end)
    async fn scan_range(&self, start: &[u8], end: &[u8], limit: u32) -> Result<Vec<KvPair>>;

    /// Paginated scan by prefix; returns items and an optional cursor for the next page
    async fn scan_prefix_paged(
        &self,
        prefix: &[u8],
        start_after: Option<&[u8]>,
        limit: u32,
    ) -> Result<(Vec<KvPair>, Option<Bytes>)> {
        // Fallback implementation using scan_prefix (no pagination cursor)
        let mut effective_prefix = prefix.to_vec();
        if let Some(start_after) = start_after {
            effective_prefix = start_after.to_vec();
            effective_prefix.push(0); // ensure greater than start_after
        }
        let items = self.scan_prefix(&effective_prefix, limit).await?;
        Ok((items, None))
    }

    /// Paginated range scan [start, end); returns items and an optional cursor for the next page
    async fn scan_range_paged(
        &self,
        start: &[u8],
        end: &[u8],
        start_after: Option<&[u8]>,
        limit: u32,
    ) -> Result<(Vec<KvPair>, Option<Bytes>)> {
        let mut effective_start = start.to_vec();
        if let Some(after) = start_after {
            effective_start = after.to_vec();
            effective_start.push(0);
        }
        let items = self.scan_range(&effective_start, end, limit).await?;
        Ok((items, None))
    }
}

/// Backend configuration
#[derive(Debug, Clone)]
pub enum BackendConfig {
    /// Chainfire backend
    Chainfire {
        /// Chainfire endpoint addresses
        endpoints: Vec<String>,
    },
    /// FlareDB backend
    FlareDb {
        /// FlareDB endpoint address
        endpoint: String,
        /// ChainFire PD endpoint used for leader/region resolution
        pd_endpoint: String,
        /// Namespace for IAM data
        namespace: String,
    },
    /// SQL backend (Postgres or SQLite)
    Sql {
        /// Database URL (postgres://... or sqlite://...)
        database_url: String,
        /// Whether single-node mode is enabled (required for SQLite)
        single_node: bool,
    },
    /// In-memory backend (for testing)
    Memory,
}

/// Backend enum wrapping different implementations
pub enum Backend {
    /// Chainfire backend
    Chainfire(ChainfireBackend),
    /// FlareDB backend
    FlareDb(FlareDbBackend),
    /// SQL backend
    Sql(SqlBackend),
    /// In-memory backend (for testing)
    Memory(MemoryBackend),
}

impl Backend {
    /// Create a new backend from configuration
    pub async fn new(config: BackendConfig) -> Result<Self> {
        match config {
            BackendConfig::Chainfire { endpoints } => {
                let backend = ChainfireBackend::new(endpoints).await?;
                Ok(Backend::Chainfire(backend))
            }
            BackendConfig::FlareDb {
                endpoint,
                pd_endpoint,
                namespace,
            } => {
                let backend = FlareDbBackend::new(endpoint, pd_endpoint, namespace).await?;
                Ok(Backend::FlareDb(backend))
            }
            BackendConfig::Sql {
                database_url,
                single_node,
            } => {
                let backend = SqlBackend::new(database_url, single_node).await?;
                Ok(Backend::Sql(backend))
            }
            BackendConfig::Memory => Ok(Backend::Memory(MemoryBackend::new())),
        }
    }

    /// Create an in-memory backend for testing
    pub fn memory() -> Self {
        Backend::Memory(MemoryBackend::new())
    }
}

#[async_trait]
impl StorageBackend for Backend {
    async fn get(&self, key: &[u8]) -> Result<Option<(Bytes, u64)>> {
        match self {
            Backend::Chainfire(b) => b.get(key).await,
            Backend::FlareDb(b) => b.get(key).await,
            Backend::Sql(b) => b.get(key).await,
            Backend::Memory(b) => b.get(key).await,
        }
    }

    async fn put(&self, key: &[u8], value: &[u8]) -> Result<u64> {
        match self {
            Backend::Chainfire(b) => b.put(key, value).await,
            Backend::FlareDb(b) => b.put(key, value).await,
            Backend::Sql(b) => b.put(key, value).await,
            Backend::Memory(b) => b.put(key, value).await,
        }
    }

    async fn cas(&self, key: &[u8], expected_version: u64, value: &[u8]) -> Result<CasResult> {
        match self {
            Backend::Chainfire(b) => b.cas(key, expected_version, value).await,
            Backend::FlareDb(b) => b.cas(key, expected_version, value).await,
            Backend::Sql(b) => b.cas(key, expected_version, value).await,
            Backend::Memory(b) => b.cas(key, expected_version, value).await,
        }
    }

    async fn delete(&self, key: &[u8]) -> Result<bool> {
        match self {
            Backend::Chainfire(b) => b.delete(key).await,
            Backend::FlareDb(b) => b.delete(key).await,
            Backend::Sql(b) => b.delete(key).await,
            Backend::Memory(b) => b.delete(key).await,
        }
    }

    async fn scan_prefix(&self, prefix: &[u8], limit: u32) -> Result<Vec<KvPair>> {
        match self {
            Backend::Chainfire(b) => b.scan_prefix(prefix, limit).await,
            Backend::FlareDb(b) => b.scan_prefix(prefix, limit).await,
            Backend::Sql(b) => b.scan_prefix(prefix, limit).await,
            Backend::Memory(b) => b.scan_prefix(prefix, limit).await,
        }
    }

    async fn scan_range(&self, start: &[u8], end: &[u8], limit: u32) -> Result<Vec<KvPair>> {
        match self {
            Backend::Chainfire(b) => b.scan_range(start, end, limit).await,
            Backend::FlareDb(b) => b.scan_range(start, end, limit).await,
            Backend::Sql(b) => b.scan_range(start, end, limit).await,
            Backend::Memory(b) => b.scan_range(start, end, limit).await,
        }
    }

    async fn scan_prefix_paged(
        &self,
        prefix: &[u8],
        start_after: Option<&[u8]>,
        limit: u32,
    ) -> Result<(Vec<KvPair>, Option<Bytes>)> {
        match self {
            Backend::Chainfire(b) => b.scan_prefix_paged(prefix, start_after, limit).await,
            Backend::FlareDb(b) => b.scan_prefix_paged(prefix, start_after, limit).await,
            Backend::Sql(b) => b.scan_prefix_paged(prefix, start_after, limit).await,
            Backend::Memory(b) => b.scan_prefix_paged(prefix, start_after, limit).await,
        }
    }

    async fn scan_range_paged(
        &self,
        start: &[u8],
        end: &[u8],
        start_after: Option<&[u8]>,
        limit: u32,
    ) -> Result<(Vec<KvPair>, Option<Bytes>)> {
        match self {
            Backend::Chainfire(b) => b.scan_range_paged(start, end, start_after, limit).await,
            Backend::FlareDb(b) => b.scan_range_paged(start, end, start_after, limit).await,
            Backend::Sql(b) => b.scan_range_paged(start, end, start_after, limit).await,
            Backend::Memory(b) => b.scan_range_paged(start, end, start_after, limit).await,
        }
    }
}

// ============================================================================
// Chainfire Backend Implementation
// ============================================================================

/// Chainfire backend implementation
pub struct ChainfireBackend {
    client: Mutex<ChainfireClient>,
}

impl ChainfireBackend {
    /// Create a new Chainfire backend
    pub async fn new(endpoints: Vec<String>) -> Result<Self> {
        let client = Self::connect_any(&endpoints).await?;
        Ok(Self {
            client: Mutex::new(client),
        })
    }

    async fn connect_any(endpoints: &[String]) -> Result<ChainfireClient> {
        let mut last_err = None;
        for ep in endpoints {
            let addr = if ep.starts_with("http://") || ep.starts_with("https://") {
                ep.clone()
            } else {
                format!("http://{}", ep)
            };
            match ChainfireClient::connect(addr.clone()).await {
                Ok(client) => return Ok(client),
                Err(e) => {
                    last_err = Some(e);
                }
            }
        }
        Err(Error::Storage(StorageError::Connection(
            last_err
                .map(|e| e.to_string())
                .unwrap_or_else(|| "no endpoints available".into()),
        )))
    }
}

#[async_trait]
impl StorageBackend for ChainfireBackend {
    async fn get(&self, key: &[u8]) -> Result<Option<(Bytes, u64)>> {
        let mut client = timeout(STORAGE_RPC_TIMEOUT, self.client.lock())
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?;
        let result = timeout(STORAGE_RPC_TIMEOUT, client.get_with_revision(key))
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?
            .map_err(map_chainfire_error)?;
        Ok(result.map(|(v, rev)| (Bytes::from(v), rev)))
    }

    async fn put(&self, key: &[u8], value: &[u8]) -> Result<u64> {
        let mut client = timeout(STORAGE_RPC_TIMEOUT, self.client.lock())
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?;
        timeout(STORAGE_RPC_TIMEOUT, client.put(key, value))
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?
            .map_err(map_chainfire_error)
    }

    async fn cas(&self, key: &[u8], expected_version: u64, value: &[u8]) -> Result<CasResult> {
        let mut client = timeout(STORAGE_RPC_TIMEOUT, self.client.lock())
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?;
        let outcome: CasOutcome = timeout(
            STORAGE_RPC_TIMEOUT,
            client.compare_and_swap(key, expected_version, value),
        )
        .await
        .map_err(|_| Error::Storage(StorageError::Timeout))?
        .map_err(map_chainfire_error)?;

        if outcome.success {
            return Ok(CasResult::Success(outcome.new_version));
        }
        if expected_version == 0 {
            if outcome.current_version == 0 {
                Ok(CasResult::NotFound)
            } else {
                Ok(CasResult::Conflict {
                    expected: 0,
                    actual: outcome.current_version,
                })
            }
        } else {
            Ok(CasResult::Conflict {
                expected: expected_version,
                actual: outcome.current_version,
            })
        }
    }

    async fn delete(&self, key: &[u8]) -> Result<bool> {
        let mut client = timeout(STORAGE_RPC_TIMEOUT, self.client.lock())
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?;
        timeout(STORAGE_RPC_TIMEOUT, client.delete(key))
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?
            .map_err(map_chainfire_error)
    }

    async fn scan_prefix(&self, prefix: &[u8], limit: u32) -> Result<Vec<KvPair>> {
        let mut client = timeout(STORAGE_RPC_TIMEOUT, self.client.lock())
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?;
        let (results, _) = timeout(STORAGE_RPC_TIMEOUT, client.scan_prefix(prefix, limit as i64))
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?
            .map_err(map_chainfire_error)?;
        Ok(results
            .into_iter()
            .map(|(k, v, ver)| KvPair {
                key: Bytes::from(k),
                value: Bytes::from(v),
                version: ver,
            })
            .collect())
    }

    async fn scan_range(&self, start: &[u8], end: &[u8], limit: u32) -> Result<Vec<KvPair>> {
        let mut client = timeout(STORAGE_RPC_TIMEOUT, self.client.lock())
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?;
        let (results, _) = timeout(STORAGE_RPC_TIMEOUT, client.scan_range(start, end, limit as i64))
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?
            .map_err(map_chainfire_error)?;
        Ok(results
            .into_iter()
            .map(|(k, v, ver)| KvPair {
                key: Bytes::from(k),
                value: Bytes::from(v),
                version: ver,
            })
            .collect())
    }

    async fn scan_prefix_paged(
        &self,
        prefix: &[u8],
        start_after: Option<&[u8]>,
        limit: u32,
    ) -> Result<(Vec<KvPair>, Option<Bytes>)> {
        let mut start = prefix.to_vec();
        if let Some(after) = start_after {
            start = after.to_vec();
            start.push(0);
        }
        let end = prefix_end(prefix);
        let mut client = timeout(STORAGE_RPC_TIMEOUT, self.client.lock())
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?;
        let (results, next) = timeout(STORAGE_RPC_TIMEOUT, client.scan_range(&start, &end, limit as i64))
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?
            .map_err(map_chainfire_error)?;
        let kvs = results
            .into_iter()
            .map(|(k, v, ver)| KvPair {
                key: Bytes::from(k),
                value: Bytes::from(v),
                version: ver,
            })
            .collect();
        Ok((kvs, next.map(Bytes::from)))
    }

    async fn scan_range_paged(
        &self,
        start: &[u8],
        end: &[u8],
        start_after: Option<&[u8]>,
        limit: u32,
    ) -> Result<(Vec<KvPair>, Option<Bytes>)> {
        let mut effective_start = start.to_vec();
        if let Some(after) = start_after {
            effective_start = after.to_vec();
            effective_start.push(0);
        }
        let mut client = timeout(STORAGE_RPC_TIMEOUT, self.client.lock())
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?;
        let (results, next) = timeout(
            STORAGE_RPC_TIMEOUT,
            client.scan_range(&effective_start, end, limit as i64),
        )
        .await
        .map_err(|_| Error::Storage(StorageError::Timeout))?
        .map_err(map_chainfire_error)?;
        let kvs = results
            .into_iter()
            .map(|(k, v, ver)| KvPair {
                key: Bytes::from(k),
                value: Bytes::from(v),
                version: ver,
            })
            .collect();
        Ok((kvs, next.map(Bytes::from)))
    }
}

fn map_chainfire_error(err: ChainfireClientError) -> Error {
    match err {
        ChainfireClientError::Connection(msg) => Error::Storage(StorageError::Connection(msg)),
        ChainfireClientError::Transport(e) => {
            Error::Storage(StorageError::Connection(e.to_string()))
        }
        ChainfireClientError::Rpc(status) => {
            Error::Storage(StorageError::Backend(status.to_string()))
        }
        other => Error::Storage(StorageError::Backend(other.to_string())),
    }
}

fn map_flaredb_error(err: Status) -> Error {
    if err.code() == tonic::Code::Unavailable {
        Error::Storage(StorageError::Connection(err.to_string()))
    } else {
        Error::Storage(StorageError::Backend(err.to_string()))
    }
}

fn prefix_end(prefix: &[u8]) -> Vec<u8> {
    let mut end = prefix.to_vec();
    for i in (0..end.len()).rev() {
        if end[i] < 0xff {
            end[i] += 1;
            end.truncate(i + 1);
            return end;
        }
    }
    Vec::new()
}

// ============================================================================
// FlareDB Backend Implementation
// ============================================================================

/// FlareDB backend implementation
pub struct FlareDbBackend {
    client: Mutex<RdbClient>,
}

impl FlareDbBackend {
    /// Create a new FlareDB backend
    pub async fn new(endpoint: String, pd_endpoint: String, namespace: String) -> Result<Self> {
        let client = RdbClient::connect_with_pd_namespace(endpoint, pd_endpoint, namespace)
            .await
            .map_err(|e| Error::Storage(StorageError::Connection(e.to_string())))?;
        Ok(Self {
            client: Mutex::new(client),
        })
    }
}

#[async_trait]
impl StorageBackend for FlareDbBackend {
    async fn get(&self, key: &[u8]) -> Result<Option<(Bytes, u64)>> {
        let mut client = timeout(STORAGE_RPC_TIMEOUT, self.client.lock())
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?;
        let res = timeout(STORAGE_RPC_TIMEOUT, client.cas_get(key.to_vec()))
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?
            .map_err(map_flaredb_error)?;
        Ok(res.and_then(|(ver, val)| {
            if val.is_empty() {
                None
            } else {
                Some((Bytes::from(val), ver))
            }
        }))
    }

    async fn put(&self, key: &[u8], value: &[u8]) -> Result<u64> {
        let key = key.to_vec();
        let value = value.to_vec();
        let mut attempts = 0;
        loop {
            // Get current version (treat tombstone as absent)
            let current = {
                let mut client = timeout(STORAGE_RPC_TIMEOUT, self.client.lock())
                    .await
                    .map_err(|_| Error::Storage(StorageError::Timeout))?;
                timeout(STORAGE_RPC_TIMEOUT, client.cas_get(key.clone()))
                    .await
                    .map_err(|_| Error::Storage(StorageError::Timeout))?
                    .map_err(map_flaredb_error)?
            };
            let mut expected_version = 0;
            if let Some((ver, val)) = current {
                if !val.is_empty() {
                    expected_version = ver;
                }
            }
            let (success, current_version, new_version) = {
                let mut client = timeout(STORAGE_RPC_TIMEOUT, self.client.lock())
                    .await
                    .map_err(|_| Error::Storage(StorageError::Timeout))?;
                timeout(
                    STORAGE_RPC_TIMEOUT,
                    client.cas(key.clone(), value.clone(), expected_version),
                )
                .await
                .map_err(|_| Error::Storage(StorageError::Timeout))?
                .map_err(map_flaredb_error)?
            };
            if success {
                return Ok(new_version);
            }
            attempts += 1;
            if attempts >= 3 {
                return Err(Error::Storage(StorageError::CasConflict {
                    expected: expected_version,
                    actual: current_version,
                }));
            }
        }
    }

    async fn cas(&self, key: &[u8], expected_version: u64, value: &[u8]) -> Result<CasResult> {
        let mut client = timeout(STORAGE_RPC_TIMEOUT, self.client.lock())
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?;
        let (success, current_version, new_version) = timeout(
            STORAGE_RPC_TIMEOUT,
            client.cas(key.to_vec(), value.to_vec(), expected_version),
        )
        .await
        .map_err(|_| Error::Storage(StorageError::Timeout))?
        .map_err(map_flaredb_error)?;

        if success {
            Ok(CasResult::Success(new_version))
        } else if expected_version == 0 {
            if current_version == 0 {
                Ok(CasResult::NotFound)
            } else {
                Ok(CasResult::Conflict {
                    expected: 0,
                    actual: current_version,
                })
            }
        } else {
            Ok(CasResult::Conflict {
                expected: expected_version,
                actual: current_version,
            })
        }
    }

    async fn delete(&self, key: &[u8]) -> Result<bool> {
        // FlareDB does not expose a delete; use a tombstone (empty value)
        let (current_version, value) = {
            let mut client = timeout(STORAGE_RPC_TIMEOUT, self.client.lock())
                .await
                .map_err(|_| Error::Storage(StorageError::Timeout))?;
            let current = timeout(STORAGE_RPC_TIMEOUT, client.cas_get(key.to_vec()))
                .await
                .map_err(|_| Error::Storage(StorageError::Timeout))?
                .map_err(map_flaredb_error)?;
            match current {
                Some((ver, val)) => (ver, val),
                None => return Ok(false),
            }
        };
        if value.is_empty() {
            return Ok(false);
        }
        let mut client = timeout(STORAGE_RPC_TIMEOUT, self.client.lock())
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?;
        let (success, _, _) = timeout(
            STORAGE_RPC_TIMEOUT,
            client.cas(key.to_vec(), Vec::new(), current_version),
        )
        .await
        .map_err(|_| Error::Storage(StorageError::Timeout))?
        .map_err(map_flaredb_error)?;
        Ok(success)
    }

    async fn scan_prefix(&self, prefix: &[u8], limit: u32) -> Result<Vec<KvPair>> {
        let start = prefix.to_vec();
        let end = prefix_end(prefix);
        let mut client = timeout(STORAGE_RPC_TIMEOUT, self.client.lock())
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?;
        let (entries, _) = timeout(STORAGE_RPC_TIMEOUT, client.cas_scan(start, end, limit))
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?
            .map_err(map_flaredb_error)?;
        Ok(entries
            .into_iter()
            .filter(|(_, val, _)| !val.is_empty())
            .map(|(k, v, ver)| KvPair {
                key: Bytes::from(k),
                value: Bytes::from(v),
                version: ver,
            })
            .collect())
    }

    async fn scan_range(&self, start: &[u8], end: &[u8], limit: u32) -> Result<Vec<KvPair>> {
        let mut client = timeout(STORAGE_RPC_TIMEOUT, self.client.lock())
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?;
        let (entries, _) = timeout(
            STORAGE_RPC_TIMEOUT,
            client.cas_scan(start.to_vec(), end.to_vec(), limit),
        )
        .await
        .map_err(|_| Error::Storage(StorageError::Timeout))?
        .map_err(map_flaredb_error)?;
        Ok(entries
            .into_iter()
            .filter(|(_, val, _)| !val.is_empty())
            .map(|(k, v, ver)| KvPair {
                key: Bytes::from(k),
                value: Bytes::from(v),
                version: ver,
            })
            .collect())
    }

    async fn scan_prefix_paged(
        &self,
        prefix: &[u8],
        start_after: Option<&[u8]>,
        limit: u32,
    ) -> Result<(Vec<KvPair>, Option<Bytes>)> {
        let mut start = prefix.to_vec();
        if let Some(after) = start_after {
            start = after.to_vec();
            start.push(0);
        }
        let end = prefix_end(prefix);
        let mut client = timeout(STORAGE_RPC_TIMEOUT, self.client.lock())
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?;
        let (entries, next) = timeout(STORAGE_RPC_TIMEOUT, client.cas_scan(start, end, limit))
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?
            .map_err(map_flaredb_error)?;
        let kvs = entries
            .into_iter()
            .filter(|(_, val, _)| !val.is_empty())
            .map(|(k, v, ver)| KvPair {
                key: Bytes::from(k),
                value: Bytes::from(v),
                version: ver,
            })
            .collect();
        Ok((kvs, next.map(Bytes::from)))
    }

    async fn scan_range_paged(
        &self,
        start: &[u8],
        end: &[u8],
        start_after: Option<&[u8]>,
        limit: u32,
    ) -> Result<(Vec<KvPair>, Option<Bytes>)> {
        let mut s = start.to_vec();
        if let Some(after) = start_after {
            s = after.to_vec();
            s.push(0);
        }
        let mut client = timeout(STORAGE_RPC_TIMEOUT, self.client.lock())
            .await
            .map_err(|_| Error::Storage(StorageError::Timeout))?;
        let (entries, next) = timeout(
            STORAGE_RPC_TIMEOUT,
            client.cas_scan(s, end.to_vec(), limit),
        )
        .await
        .map_err(|_| Error::Storage(StorageError::Timeout))?
        .map_err(map_flaredb_error)?;
        let kvs = entries
            .into_iter()
            .filter(|(_, val, _)| !val.is_empty())
            .map(|(k, v, ver)| KvPair {
                key: Bytes::from(k),
                value: Bytes::from(v),
                version: ver,
            })
            .collect();
        Ok((kvs, next.map(Bytes::from)))
    }
}

// ============================================================================
// SQL Backend Implementation
// ============================================================================

enum SqlBackendKind {
    Postgres(Pool<Postgres>),
    Sqlite(Pool<Sqlite>),
}

/// SQL backend implementation (Postgres/SQLite)
pub struct SqlBackend {
    backend: SqlBackendKind,
}

impl SqlBackend {
    /// Create a new SQL backend.
    pub async fn new(database_url: String, single_node: bool) -> Result<Self> {
        let url = database_url.trim();
        if url.is_empty() {
            return Err(Error::Storage(StorageError::Backend(
                "database URL is empty".to_string(),
            )));
        }
        if url.starts_with("postgres://") || url.starts_with("postgresql://") {
            let pool = PoolOptions::<Postgres>::new()
                .max_connections(10)
                .connect(url)
                .await
                .map_err(|e| Error::Storage(StorageError::Connection(e.to_string())))?;
            Self::ensure_schema_postgres(&pool).await?;
            return Ok(Self {
                backend: SqlBackendKind::Postgres(pool),
            });
        }
        if url.starts_with("sqlite:") {
            if !single_node {
                return Err(Error::Storage(StorageError::Backend(
                    "SQLite is allowed only in single-node mode".to_string(),
                )));
            }
            if url.contains(":memory:") {
                return Err(Error::Storage(StorageError::Backend(
                    "In-memory SQLite is not allowed".to_string(),
                )));
            }
            let pool = PoolOptions::<Sqlite>::new()
                .max_connections(1)
                .connect(url)
                .await
                .map_err(|e| Error::Storage(StorageError::Connection(e.to_string())))?;
            Self::ensure_schema_sqlite(&pool).await?;
            return Ok(Self {
                backend: SqlBackendKind::Sqlite(pool),
            });
        }
        Err(Error::Storage(StorageError::Backend(
            "Unsupported database URL (use postgres://, postgresql://, or sqlite:)".to_string(),
        )))
    }

    async fn ensure_schema_postgres(pool: &Pool<Postgres>) -> Result<()> {
        sqlx::query(
            "CREATE TABLE IF NOT EXISTS iam_kv (
                key TEXT PRIMARY KEY,
                value BYTEA NOT NULL,
                version BIGINT NOT NULL
            )",
        )
        .execute(pool)
        .await
        .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
        Ok(())
    }

    async fn ensure_schema_sqlite(pool: &Pool<Sqlite>) -> Result<()> {
        sqlx::query(
            "CREATE TABLE IF NOT EXISTS iam_kv (
                key TEXT PRIMARY KEY,
                value BLOB NOT NULL,
                version INTEGER NOT NULL
            )",
        )
        .execute(pool)
        .await
        .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
        Ok(())
    }

    fn key_to_text(key: &[u8]) -> Result<&str> {
        std::str::from_utf8(key).map_err(|e| {
            Error::Storage(StorageError::Backend(format!(
                "Key must be UTF-8 for SQL backend: {}",
                e
            )))
        })
    }

    fn prefix_like(prefix: &[u8]) -> Result<String> {
        Ok(format!("{}%", Self::key_to_text(prefix)?))
    }

    fn row_to_kv(key: String, value: Vec<u8>, version: i64) -> Result<KvPair> {
        let version = u64::try_from(version).map_err(|e| {
            Error::Storage(StorageError::Backend(format!(
                "Invalid version in SQL row: {}",
                e
            )))
        })?;
        Ok(KvPair {
            key: Bytes::from(key.into_bytes()),
            value: Bytes::from(value),
            version,
        })
    }
}

#[async_trait]
impl StorageBackend for SqlBackend {
    async fn get(&self, key: &[u8]) -> Result<Option<(Bytes, u64)>> {
        let key = Self::key_to_text(key)?;
        match &self.backend {
            SqlBackendKind::Postgres(pool) => {
                let row: Option<(Vec<u8>, i64)> =
                    sqlx::query_as("SELECT value, version FROM iam_kv WHERE key = $1")
                        .bind(key)
                        .fetch_optional(pool)
                        .await
                        .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                match row {
                    Some((value, version)) => Ok(Some((
                        Bytes::from(value),
                        u64::try_from(version).map_err(|e| {
                            Error::Storage(StorageError::Backend(format!(
                                "Invalid version in SQL row: {}",
                                e
                            )))
                        })?,
                    ))),
                    None => Ok(None),
                }
            }
            SqlBackendKind::Sqlite(pool) => {
                let row: Option<(Vec<u8>, i64)> =
                    sqlx::query_as("SELECT value, version FROM iam_kv WHERE key = ?1")
                        .bind(key)
                        .fetch_optional(pool)
                        .await
                        .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                match row {
                    Some((value, version)) => Ok(Some((
                        Bytes::from(value),
                        u64::try_from(version).map_err(|e| {
                            Error::Storage(StorageError::Backend(format!(
                                "Invalid version in SQL row: {}",
                                e
                            )))
                        })?,
                    ))),
                    None => Ok(None),
                }
            }
        }
    }

    async fn put(&self, key: &[u8], value: &[u8]) -> Result<u64> {
        let key = Self::key_to_text(key)?;
        match &self.backend {
            SqlBackendKind::Postgres(pool) => {
                let version: i64 = sqlx::query_scalar(
                    "INSERT INTO iam_kv (key, value, version) VALUES ($1, $2, 1)
                     ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value, version = iam_kv.version + 1
                     RETURNING version",
                )
                .bind(key)
                .bind(value)
                .fetch_one(pool)
                .await
                .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                u64::try_from(version).map_err(|e| {
                    Error::Storage(StorageError::Backend(format!(
                        "Invalid version in SQL row: {}",
                        e
                    )))
                })
            }
            SqlBackendKind::Sqlite(pool) => {
                let version: i64 = sqlx::query_scalar(
                    "INSERT INTO iam_kv (key, value, version) VALUES (?1, ?2, 1)
                     ON CONFLICT(key) DO UPDATE SET value = excluded.value, version = iam_kv.version + 1
                     RETURNING version",
                )
                .bind(key)
                .bind(value)
                .fetch_one(pool)
                .await
                .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                u64::try_from(version).map_err(|e| {
                    Error::Storage(StorageError::Backend(format!(
                        "Invalid version in SQL row: {}",
                        e
                    )))
                })
            }
        }
    }

    async fn cas(&self, key: &[u8], expected_version: u64, value: &[u8]) -> Result<CasResult> {
        let key = Self::key_to_text(key)?;
        if expected_version == 0 {
            return match &self.backend {
                SqlBackendKind::Postgres(pool) => {
                    let inserted: Option<i64> = sqlx::query_scalar(
                        "INSERT INTO iam_kv (key, value, version) VALUES ($1, $2, 1)
                         ON CONFLICT DO NOTHING
                         RETURNING version",
                    )
                    .bind(key)
                    .bind(value)
                    .fetch_optional(pool)
                    .await
                    .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                    if let Some(v) = inserted {
                        Ok(CasResult::Success(u64::try_from(v).map_err(|e| {
                            Error::Storage(StorageError::Backend(format!(
                                "Invalid version in SQL row: {}",
                                e
                            )))
                        })?))
                    } else {
                        let actual: Option<i64> =
                            sqlx::query_scalar("SELECT version FROM iam_kv WHERE key = $1")
                                .bind(key)
                                .fetch_optional(pool)
                                .await
                                .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                        match actual {
                            Some(v) => Ok(CasResult::Conflict {
                                expected: 0,
                                actual: u64::try_from(v).map_err(|e| {
                                    Error::Storage(StorageError::Backend(format!(
                                        "Invalid version in SQL row: {}",
                                        e
                                    )))
                                })?,
                            }),
                            None => Ok(CasResult::NotFound),
                        }
                    }
                }
                SqlBackendKind::Sqlite(pool) => {
                    let result = sqlx::query(
                        "INSERT OR IGNORE INTO iam_kv (key, value, version) VALUES (?1, ?2, 1)",
                    )
                    .bind(key)
                    .bind(value)
                    .execute(pool)
                    .await
                    .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                    if result.rows_affected() > 0 {
                        Ok(CasResult::Success(1))
                    } else {
                        let actual: Option<i64> =
                            sqlx::query_scalar("SELECT version FROM iam_kv WHERE key = ?1")
                                .bind(key)
                                .fetch_optional(pool)
                                .await
                                .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                        match actual {
                            Some(v) => Ok(CasResult::Conflict {
                                expected: 0,
                                actual: u64::try_from(v).map_err(|e| {
                                    Error::Storage(StorageError::Backend(format!(
                                        "Invalid version in SQL row: {}",
                                        e
                                    )))
                                })?,
                            }),
                            None => Ok(CasResult::NotFound),
                        }
                    }
                }
            };
        }

        let expected = i64::try_from(expected_version).map_err(|e| {
            Error::Storage(StorageError::Backend(format!(
                "expected_version out of range for SQL backend: {}",
                e
            )))
        })?;
        match &self.backend {
            SqlBackendKind::Postgres(pool) => {
                let updated: Option<i64> = sqlx::query_scalar(
                    "UPDATE iam_kv SET value = $2, version = version + 1
                     WHERE key = $1 AND version = $3
                     RETURNING version",
                )
                .bind(key)
                .bind(value)
                .bind(expected)
                .fetch_optional(pool)
                .await
                .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                if let Some(v) = updated {
                    Ok(CasResult::Success(u64::try_from(v).map_err(|e| {
                        Error::Storage(StorageError::Backend(format!(
                            "Invalid version in SQL row: {}",
                            e
                        )))
                    })?))
                } else {
                    let actual: Option<i64> =
                        sqlx::query_scalar("SELECT version FROM iam_kv WHERE key = $1")
                            .bind(key)
                            .fetch_optional(pool)
                            .await
                            .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                    match actual {
                        Some(v) => Ok(CasResult::Conflict {
                            expected: expected_version,
                            actual: u64::try_from(v).map_err(|e| {
                                Error::Storage(StorageError::Backend(format!(
                                    "Invalid version in SQL row: {}",
                                    e
                                )))
                            })?,
                        }),
                        None => Ok(CasResult::NotFound),
                    }
                }
            }
            SqlBackendKind::Sqlite(pool) => {
                let updated: Option<i64> = sqlx::query_scalar(
                    "UPDATE iam_kv SET value = ?2, version = version + 1
                     WHERE key = ?1 AND version = ?3
                     RETURNING version",
                )
                .bind(key)
                .bind(value)
                .bind(expected)
                .fetch_optional(pool)
                .await
                .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                if let Some(v) = updated {
                    Ok(CasResult::Success(u64::try_from(v).map_err(|e| {
                        Error::Storage(StorageError::Backend(format!(
                            "Invalid version in SQL row: {}",
                            e
                        )))
                    })?))
                } else {
                    let actual: Option<i64> =
                        sqlx::query_scalar("SELECT version FROM iam_kv WHERE key = ?1")
                            .bind(key)
                            .fetch_optional(pool)
                            .await
                            .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                    match actual {
                        Some(v) => Ok(CasResult::Conflict {
                            expected: expected_version,
                            actual: u64::try_from(v).map_err(|e| {
                                Error::Storage(StorageError::Backend(format!(
                                    "Invalid version in SQL row: {}",
                                    e
                                )))
                            })?,
                        }),
                        None => Ok(CasResult::NotFound),
                    }
                }
            }
        }
    }

    async fn delete(&self, key: &[u8]) -> Result<bool> {
        let key = Self::key_to_text(key)?;
        let rows = match &self.backend {
            SqlBackendKind::Postgres(pool) => {
                sqlx::query("DELETE FROM iam_kv WHERE key = $1")
                    .bind(key)
                    .execute(pool)
                    .await
                    .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?
                    .rows_affected()
            }
            SqlBackendKind::Sqlite(pool) => {
                sqlx::query("DELETE FROM iam_kv WHERE key = ?1")
                    .bind(key)
                    .execute(pool)
                    .await
                    .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?
                    .rows_affected()
            }
        };
        Ok(rows > 0)
    }

    async fn scan_prefix(&self, prefix: &[u8], limit: u32) -> Result<Vec<KvPair>> {
        let like = Self::prefix_like(prefix)?;
        match &self.backend {
            SqlBackendKind::Postgres(pool) => {
                let rows: Vec<(String, Vec<u8>, i64)> = sqlx::query_as(
                    "SELECT key, value, version FROM iam_kv WHERE key LIKE $1 ORDER BY key LIMIT $2",
                )
                .bind(like)
                .bind(i64::from(limit))
                .fetch_all(pool)
                .await
                .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                rows.into_iter()
                    .map(|(k, v, ver)| Self::row_to_kv(k, v, ver))
                    .collect()
            }
            SqlBackendKind::Sqlite(pool) => {
                let rows: Vec<(String, Vec<u8>, i64)> = sqlx::query_as(
                    "SELECT key, value, version FROM iam_kv WHERE key LIKE ?1 ORDER BY key LIMIT ?2",
                )
                .bind(like)
                .bind(i64::from(limit))
                .fetch_all(pool)
                .await
                .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                rows.into_iter()
                    .map(|(k, v, ver)| Self::row_to_kv(k, v, ver))
                    .collect()
            }
        }
    }

    async fn scan_range(&self, start: &[u8], end: &[u8], limit: u32) -> Result<Vec<KvPair>> {
        let start = Self::key_to_text(start)?;
        let end = Self::key_to_text(end)?;
        match &self.backend {
            SqlBackendKind::Postgres(pool) => {
                let rows: Vec<(String, Vec<u8>, i64)> = sqlx::query_as(
                    "SELECT key, value, version FROM iam_kv WHERE key >= $1 AND key < $2 ORDER BY key LIMIT $3",
                )
                .bind(start)
                .bind(end)
                .bind(i64::from(limit))
                .fetch_all(pool)
                .await
                .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                rows.into_iter()
                    .map(|(k, v, ver)| Self::row_to_kv(k, v, ver))
                    .collect()
            }
            SqlBackendKind::Sqlite(pool) => {
                let rows: Vec<(String, Vec<u8>, i64)> = sqlx::query_as(
                    "SELECT key, value, version FROM iam_kv WHERE key >= ?1 AND key < ?2 ORDER BY key LIMIT ?3",
                )
                .bind(start)
                .bind(end)
                .bind(i64::from(limit))
                .fetch_all(pool)
                .await
                .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                rows.into_iter()
                    .map(|(k, v, ver)| Self::row_to_kv(k, v, ver))
                    .collect()
            }
        }
    }

    async fn scan_prefix_paged(
        &self,
        prefix: &[u8],
        start_after: Option<&[u8]>,
        limit: u32,
    ) -> Result<(Vec<KvPair>, Option<Bytes>)> {
        if limit == 0 {
            return Ok((Vec::new(), None));
        }
        let like = Self::prefix_like(prefix)?;
        let items = match (&self.backend, start_after) {
            (SqlBackendKind::Postgres(pool), Some(after)) => {
                let after = Self::key_to_text(after)?;
                let rows: Vec<(String, Vec<u8>, i64)> = sqlx::query_as(
                    "SELECT key, value, version FROM iam_kv WHERE key LIKE $1 AND key > $2 ORDER BY key LIMIT $3",
                )
                .bind(like)
                .bind(after)
                .bind(i64::from(limit))
                .fetch_all(pool)
                .await
                .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                rows.into_iter()
                    .map(|(k, v, ver)| Self::row_to_kv(k, v, ver))
                    .collect::<Result<Vec<_>>>()?
            }
            (SqlBackendKind::Postgres(pool), None) => {
                let rows: Vec<(String, Vec<u8>, i64)> = sqlx::query_as(
                    "SELECT key, value, version FROM iam_kv WHERE key LIKE $1 ORDER BY key LIMIT $2",
                )
                .bind(like)
                .bind(i64::from(limit))
                .fetch_all(pool)
                .await
                .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                rows.into_iter()
                    .map(|(k, v, ver)| Self::row_to_kv(k, v, ver))
                    .collect::<Result<Vec<_>>>()?
            }
            (SqlBackendKind::Sqlite(pool), Some(after)) => {
                let after = Self::key_to_text(after)?;
                let rows: Vec<(String, Vec<u8>, i64)> = sqlx::query_as(
                    "SELECT key, value, version FROM iam_kv WHERE key LIKE ?1 AND key > ?2 ORDER BY key LIMIT ?3",
                )
                .bind(like)
                .bind(after)
                .bind(i64::from(limit))
                .fetch_all(pool)
                .await
                .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                rows.into_iter()
                    .map(|(k, v, ver)| Self::row_to_kv(k, v, ver))
                    .collect::<Result<Vec<_>>>()?
            }
            (SqlBackendKind::Sqlite(pool), None) => {
                let rows: Vec<(String, Vec<u8>, i64)> = sqlx::query_as(
                    "SELECT key, value, version FROM iam_kv WHERE key LIKE ?1 ORDER BY key LIMIT ?2",
                )
                .bind(like)
                .bind(i64::from(limit))
                .fetch_all(pool)
                .await
                .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                rows.into_iter()
                    .map(|(k, v, ver)| Self::row_to_kv(k, v, ver))
                    .collect::<Result<Vec<_>>>()?
            }
        };
        let next = if items.len() as u32 == limit {
            items.last().map(|kv| kv.key.clone())
        } else {
            None
        };
        Ok((items, next))
    }

    async fn scan_range_paged(
        &self,
        start: &[u8],
        end: &[u8],
        start_after: Option<&[u8]>,
        limit: u32,
    ) -> Result<(Vec<KvPair>, Option<Bytes>)> {
        if limit == 0 {
            return Ok((Vec::new(), None));
        }
        let start = Self::key_to_text(start)?;
        let end = Self::key_to_text(end)?;
        let items = match (&self.backend, start_after) {
            (SqlBackendKind::Postgres(pool), Some(after)) => {
                let after = Self::key_to_text(after)?;
                let rows: Vec<(String, Vec<u8>, i64)> = sqlx::query_as(
                    "SELECT key, value, version FROM iam_kv WHERE key > $1 AND key < $2 ORDER BY key LIMIT $3",
                )
                .bind(after)
                .bind(end)
                .bind(i64::from(limit))
                .fetch_all(pool)
                .await
                .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                rows.into_iter()
                    .map(|(k, v, ver)| Self::row_to_kv(k, v, ver))
                    .collect::<Result<Vec<_>>>()?
            }
            (SqlBackendKind::Postgres(pool), None) => {
                let rows: Vec<(String, Vec<u8>, i64)> = sqlx::query_as(
                    "SELECT key, value, version FROM iam_kv WHERE key >= $1 AND key < $2 ORDER BY key LIMIT $3",
                )
                .bind(start)
                .bind(end)
                .bind(i64::from(limit))
                .fetch_all(pool)
                .await
                .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                rows.into_iter()
                    .map(|(k, v, ver)| Self::row_to_kv(k, v, ver))
                    .collect::<Result<Vec<_>>>()?
            }
            (SqlBackendKind::Sqlite(pool), Some(after)) => {
                let after = Self::key_to_text(after)?;
                let rows: Vec<(String, Vec<u8>, i64)> = sqlx::query_as(
                    "SELECT key, value, version FROM iam_kv WHERE key > ?1 AND key < ?2 ORDER BY key LIMIT ?3",
                )
                .bind(after)
                .bind(end)
                .bind(i64::from(limit))
                .fetch_all(pool)
                .await
                .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                rows.into_iter()
                    .map(|(k, v, ver)| Self::row_to_kv(k, v, ver))
                    .collect::<Result<Vec<_>>>()?
            }
            (SqlBackendKind::Sqlite(pool), None) => {
                let rows: Vec<(String, Vec<u8>, i64)> = sqlx::query_as(
                    "SELECT key, value, version FROM iam_kv WHERE key >= ?1 AND key < ?2 ORDER BY key LIMIT ?3",
                )
                .bind(start)
                .bind(end)
                .bind(i64::from(limit))
                .fetch_all(pool)
                .await
                .map_err(|e| Error::Storage(StorageError::Backend(e.to_string())))?;
                rows.into_iter()
                    .map(|(k, v, ver)| Self::row_to_kv(k, v, ver))
                    .collect::<Result<Vec<_>>>()?
            }
        };
        let next = if items.len() as u32 == limit {
            items.last().map(|kv| kv.key.clone())
        } else {
            None
        };
        Ok((items, next))
    }
}

// ============================================================================
// In-Memory Backend Implementation (for testing)
// ============================================================================

use std::collections::BTreeMap;
use std::sync::RwLock;

type MemKvMap = BTreeMap<Vec<u8>, (Vec<u8>, u64)>;

/// In-memory backend for testing
pub struct MemoryBackend {
    data: RwLock<MemKvMap>,
    version_counter: RwLock<u64>,
}

impl MemoryBackend {
    /// Create a new in-memory backend
    pub fn new() -> Self {
        Self {
            data: RwLock::new(BTreeMap::new()),
            version_counter: RwLock::new(0),
        }
    }

    fn next_version(&self) -> u64 {
        let mut counter = self.version_counter.write().unwrap();
        *counter += 1;
        *counter
    }
}

impl Default for MemoryBackend {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl StorageBackend for MemoryBackend {
    async fn get(&self, key: &[u8]) -> Result<Option<(Bytes, u64)>> {
        let data = self.data.read().unwrap();
        Ok(data
            .get(key)
            .map(|(v, ver)| (Bytes::copy_from_slice(v), *ver)))
    }

    async fn put(&self, key: &[u8], value: &[u8]) -> Result<u64> {
        let version = self.next_version();
        let mut data = self.data.write().unwrap();
        data.insert(key.to_vec(), (value.to_vec(), version));
        Ok(version)
    }

    async fn cas(&self, key: &[u8], expected_version: u64, value: &[u8]) -> Result<CasResult> {
        let mut data = self.data.write().unwrap();
        match data.get(key) {
            Some((_, current_version)) => {
                if *current_version != expected_version {
                    return Ok(CasResult::Conflict {
                        expected: expected_version,
                        actual: *current_version,
                    });
                }
            }
            None => {
                if expected_version != 0 {
                    return Ok(CasResult::NotFound);
                }
            }
        }
        let version = self.next_version();
        data.insert(key.to_vec(), (value.to_vec(), version));
        Ok(CasResult::Success(version))
    }

    async fn delete(&self, key: &[u8]) -> Result<bool> {
        let mut data = self.data.write().unwrap();
        Ok(data.remove(key).is_some())
    }

    async fn scan_prefix(&self, prefix: &[u8], limit: u32) -> Result<Vec<KvPair>> {
        let data = self.data.read().unwrap();
        let mut results = Vec::new();
        for (k, (v, ver)) in data.range(prefix.to_vec()..) {
            if !k.starts_with(prefix) {
                break;
            }
            results.push(KvPair {
                key: Bytes::copy_from_slice(k),
                value: Bytes::copy_from_slice(v),
                version: *ver,
            });
            if results.len() >= limit as usize {
                break;
            }
        }
        Ok(results)
    }

    async fn scan_range(&self, start: &[u8], end: &[u8], limit: u32) -> Result<Vec<KvPair>> {
        let data = self.data.read().unwrap();
        let mut results = Vec::new();
        for (k, (v, ver)) in data.range(start.to_vec()..end.to_vec()) {
            results.push(KvPair {
                key: Bytes::copy_from_slice(k),
                value: Bytes::copy_from_slice(v),
                version: *ver,
            });
            if results.len() >= limit as usize {
                break;
            }
        }
        Ok(results)
    }

    async fn scan_prefix_paged(
        &self,
        prefix: &[u8],
        start_after: Option<&[u8]>,
        limit: u32,
    ) -> Result<(Vec<KvPair>, Option<Bytes>)> {
        let start_key = match start_after {
            Some(after) => {
                let mut k = after.to_vec();
                k.push(0);
                k
            }
            None => prefix.to_vec(),
        };
        let end_key = prefix_end(prefix);
        let items = self.scan_range(&start_key, &end_key, limit).await?;
        let next = if items.len() as u32 == limit {
            items.last().map(|kv| {
                let mut n = kv.key.to_vec();
                n.push(0);
                Bytes::from(n)
            })
        } else {
            None
        };
        Ok((items, next))
    }

    async fn scan_range_paged(
        &self,
        start: &[u8],
        end: &[u8],
        start_after: Option<&[u8]>,
        limit: u32,
    ) -> Result<(Vec<KvPair>, Option<Bytes>)> {
        let start_key = match start_after {
            Some(after) => {
                let mut k = after.to_vec();
                k.push(0);
                k
            }
            None => start.to_vec(),
        };
        let items = self.scan_range(&start_key, end, limit).await?;
        let next = if items.len() as u32 == limit {
            items.last().map(|kv| {
                let mut n = kv.key.to_vec();
                n.push(0);
                Bytes::from(n)
            })
        } else {
            None
        };
        Ok((items, next))
    }
}

// ============================================================================
// Helper functions for stores
// ============================================================================

/// Helper trait for JSON serialization in stores
pub trait JsonStore: Sync {
    fn backend(&self) -> &Backend;

    /// Get and deserialize a value
    fn get_json<'a, T: DeserializeOwned + Send + 'a>(
        &'a self,
        key: &'a [u8],
    ) -> impl std::future::Future<Output = Result<Option<(T, u64)>>> + Send + 'a {
        async move {
            match self.backend().get(key).await? {
                Some((bytes, version)) => {
                    let value: T = serde_json::from_slice(&bytes)
                        .map_err(|e| Error::Serialization(e.to_string()))?;
                    Ok(Some((value, version)))
                }
                None => Ok(None),
            }
        }
    }

    /// Serialize and put a value
    fn put_json<'a, T: Serialize + Send + Sync + 'a>(
        &'a self,
        key: &'a [u8],
        value: &'a T,
    ) -> impl std::future::Future<Output = Result<u64>> + Send + 'a {
        async move {
            let bytes =
                serde_json::to_vec(value).map_err(|e| Error::Serialization(e.to_string()))?;
            self.backend().put(key, &bytes).await
        }
    }

    /// Serialize and CAS a value
    fn cas_json<'a, T: Serialize + Send + Sync + 'a>(
        &'a self,
        key: &'a [u8],
        expected_version: u64,
        value: &'a T,
    ) -> impl std::future::Future<Output = Result<CasResult>> + Send + 'a {
        async move {
            let bytes =
                serde_json::to_vec(value).map_err(|e| Error::Serialization(e.to_string()))?;
            self.backend().cas(key, expected_version, &bytes).await
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_memory_backend_basic() {
        let backend = MemoryBackend::new();

        // Test put and get
        let version = backend.put(b"key1", b"value1").await.unwrap();
        assert!(version > 0);

        let result = backend.get(b"key1").await.unwrap();
        assert!(result.is_some());
        let (value, ver) = result.unwrap();
        assert_eq!(&value[..], b"value1");
        assert_eq!(ver, version);

        // Test get non-existent
        let result = backend.get(b"nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_memory_backend_cas() {
        let backend = MemoryBackend::new();

        // CAS create (expected version 0)
        let result = backend.cas(b"key1", 0, b"value1").await.unwrap();
        let version = match result {
            CasResult::Success(v) => v,
            _ => panic!("expected success"),
        };

        // CAS update with correct version
        let result = backend.cas(b"key1", version, b"value2").await.unwrap();
        assert!(matches!(result, CasResult::Success(_)));

        // CAS update with wrong version
        let result = backend.cas(b"key1", version, b"value3").await.unwrap();
        assert!(matches!(result, CasResult::Conflict { .. }));

        // CAS create when key exists
        let result = backend.cas(b"key1", 0, b"value4").await.unwrap();
        assert!(matches!(result, CasResult::Conflict { .. }));
    }

    #[tokio::test]
    async fn test_memory_backend_scan_prefix() {
        let backend = MemoryBackend::new();

        backend.put(b"prefix/a", b"1").await.unwrap();
        backend.put(b"prefix/b", b"2").await.unwrap();
        backend.put(b"prefix/c", b"3").await.unwrap();
        backend.put(b"other/x", b"4").await.unwrap();

        let results = backend.scan_prefix(b"prefix/", 10).await.unwrap();
        assert_eq!(results.len(), 3);

        // Test limit
        let results = backend.scan_prefix(b"prefix/", 2).await.unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn test_memory_backend_scan_prefix_paged() {
        let backend = MemoryBackend::new();
        for i in 0..5u8 {
            let key = format!("prefix/{}", i);
            backend.put(key.as_bytes(), &[i]).await.unwrap();
        }

        let (page1, cursor1) = backend
            .scan_prefix_paged(b"prefix/", None, 2)
            .await
            .unwrap();
        assert_eq!(page1.len(), 2);
        assert!(cursor1.is_some());

        let (page2, cursor2) = backend
            .scan_prefix_paged(b"prefix/", cursor1.as_deref(), 2)
            .await
            .unwrap();
        assert_eq!(page2.len(), 2);
        assert!(cursor2.is_some());

        let (page3, cursor3) = backend
            .scan_prefix_paged(b"prefix/", cursor2.as_deref(), 2)
            .await
            .unwrap();
        assert_eq!(page3.len(), 1);
        assert!(cursor3.is_none());

        let collected: Vec<u8> = page1
            .iter()
            .chain(page2.iter())
            .chain(page3.iter())
            .map(|kv| kv.value[0])
            .collect();
        assert_eq!(collected.len(), 5);
        assert!(collected.contains(&0) && collected.contains(&4));
    }

    #[tokio::test]
    async fn test_memory_backend_delete() {
        let backend = MemoryBackend::new();

        backend.put(b"key1", b"value1").await.unwrap();
        assert!(backend.get(b"key1").await.unwrap().is_some());

        let deleted = backend.delete(b"key1").await.unwrap();
        assert!(deleted);
        assert!(backend.get(b"key1").await.unwrap().is_none());

        // Delete non-existent
        let deleted = backend.delete(b"key1").await.unwrap();
        assert!(!deleted);
    }
}
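
// Illustrative addition, not part of the original test suite: a minimal sketch
// showing how the `JsonStore` helper trait is intended to be used, exercised
// against the in-memory backend. The `TestStore` wrapper is hypothetical and
// exists only for this test; `serde_json::Value` is used as the payload so no
// extra derive machinery is assumed beyond the serde_json dependency already
// referenced by `JsonStore`.
#[cfg(test)]
mod json_store_tests {
    use super::*;
    use serde_json::{json, Value};

    /// Hypothetical store wrapper used only to exercise `JsonStore`.
    struct TestStore {
        backend: Backend,
    }

    impl JsonStore for TestStore {
        fn backend(&self) -> &Backend {
            &self.backend
        }
    }

    #[tokio::test]
    async fn test_json_store_roundtrip() {
        let store = TestStore {
            backend: Backend::memory(),
        };

        // put_json serializes the value and returns the version assigned by the backend.
        let doc = json!({ "name": "alice", "count": 1 });
        let version = store.put_json(b"record/alice", &doc).await.unwrap();
        assert!(version > 0);

        // get_json deserializes the stored bytes back into the requested type.
        let (loaded, ver): (Value, u64) = store
            .get_json(b"record/alice")
            .await
            .unwrap()
            .expect("value should exist");
        assert_eq!(loaded, doc);
        assert_eq!(ver, version);

        // cas_json succeeds only when the expected version matches the stored one.
        let updated = json!({ "name": "alice", "count": 2 });
        let result = store
            .cas_json(b"record/alice", ver, &updated)
            .await
            .unwrap();
        assert!(matches!(result, CasResult::Success(_)));

        // Reusing the stale version is reported as a conflict rather than an error.
        let stale = store
            .cas_json(b"record/alice", ver, &updated)
            .await
            .unwrap();
        assert!(matches!(stale, CasResult::Conflict { .. }));
    }
}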