use std::collections::HashMap; use std::net::SocketAddr; use std::sync::{ atomic::{AtomicU64, Ordering}, Arc, }; use flaredb_client::RdbClient; use flaredb_proto::kvrpc::kv_cas_server::KvCas; use flaredb_proto::kvrpc::kv_cas_server::KvCasServer; use flaredb_proto::kvrpc::kv_raw_server::KvRaw; use flaredb_proto::kvrpc::kv_raw_server::KvRawServer; use flaredb_proto::kvrpc::{ CasRequest, CasResponse, GetRequest, GetResponse, RawGetRequest, RawGetResponse, RawPutRequest, RawPutResponse, RawScanRequest, RawScanResponse, ScanRequest, ScanResponse, }; use flaredb_proto::pdpb::pd_server::Pd; use flaredb_proto::pdpb::pd_server::PdServer; use flaredb_proto::pdpb::tso_server::Tso; use flaredb_proto::pdpb::tso_server::TsoServer; use flaredb_proto::pdpb::{ GetRegionRequest, GetRegionResponse, ListRegionsRequest, ListRegionsResponse, Region, RegisterStoreRequest, RegisterStoreResponse, Store, TsoRequest, TsoResponse, }; use tokio::net::TcpListener; use tokio::sync::{oneshot, Mutex}; use tokio_stream::wrappers::TcpListenerStream; use tonic::transport::Server; use tonic::{Request, Response, Status}; #[derive(Clone, Default)] struct TestKvService { raw: Arc, Vec>>>, cas: Arc, (u64, Vec)>>>, } #[tonic::async_trait] impl KvRaw for TestKvService { async fn raw_put( &self, request: Request, ) -> Result, Status> { let req = request.into_inner(); let mut raw = self.raw.lock().await; raw.insert(req.key, req.value); Ok(Response::new(RawPutResponse { success: true })) } async fn raw_get( &self, request: Request, ) -> Result, Status> { let req = request.into_inner(); let raw = self.raw.lock().await; if let Some(val) = raw.get(&req.key) { Ok(Response::new(RawGetResponse { found: true, value: val.clone(), })) } else { Ok(Response::new(RawGetResponse { found: false, value: Vec::new(), })) } } async fn raw_scan( &self, _request: Request, ) -> Result, Status> { Ok(Response::new(RawScanResponse { keys: vec![], values: vec![], has_more: false, next_key: vec![], })) } } #[tonic::async_trait] impl KvCas for TestKvService { async fn compare_and_swap( &self, request: Request, ) -> Result, Status> { let req = request.into_inner(); let mut cas = self.cas.lock().await; let (current_version, _) = cas.get(&req.key).cloned().unwrap_or((0, Vec::new())); if current_version != req.expected_version { return Ok(Response::new(CasResponse { success: false, current_version, new_version: 0, })); } let new_version = current_version + 1; cas.insert(req.key, (new_version, req.value)); Ok(Response::new(CasResponse { success: true, current_version, new_version, })) } async fn get(&self, request: Request) -> Result, Status> { let req = request.into_inner(); let cas = self.cas.lock().await; if let Some((ver, val)) = cas.get(&req.key) { Ok(Response::new(GetResponse { found: true, value: val.clone(), version: *ver, })) } else { Ok(Response::new(GetResponse { found: false, value: Vec::new(), version: 0, })) } } async fn scan(&self, _request: Request) -> Result, Status> { Ok(Response::new(ScanResponse { entries: vec![], has_more: false, next_key: vec![], })) } } #[derive(Clone)] struct TestPdService { region: Region, leader: Store, } #[tonic::async_trait] impl Pd for TestPdService { async fn register_store( &self, _request: Request, ) -> Result, Status> { Ok(Response::new(RegisterStoreResponse { store_id: self.leader.id, cluster_id: 1, })) } async fn get_region( &self, _request: Request, ) -> Result, Status> { Ok(Response::new(GetRegionResponse { region: Some(self.region.clone()), leader: Some(self.leader.clone()), })) } async fn list_regions( &self, _request: Request, ) -> Result, Status> { Ok(Response::new(ListRegionsResponse { regions: vec![self.region.clone()], stores: vec![self.leader.clone()], })) } } #[derive(Clone, Default)] struct TestTsoService { counter: Arc, } #[tonic::async_trait] impl Tso for TestTsoService { async fn get_timestamp( &self, request: Request, ) -> Result, Status> { let count = request.into_inner().count.max(1) as u64; let start = self.counter.fetch_add(count, Ordering::AcqRel) + 1; Ok(Response::new(TsoResponse { timestamp: start, count: count as u32, })) } } async fn start_kv_server( service: TestKvService, ) -> Result< (SocketAddr, oneshot::Sender<()>, tokio::task::JoinHandle<()>), Box, > { let listener = TcpListener::bind("127.0.0.1:0").await?; let addr = listener.local_addr()?; let incoming = TcpListenerStream::new(listener); let (tx, rx) = oneshot::channel(); let raw_service = service.clone(); let cas_service = service.clone(); let handle = tokio::spawn(async move { Server::builder() .add_service(KvRawServer::new(raw_service)) .add_service(KvCasServer::new(cas_service)) .serve_with_incoming_shutdown(incoming, async { let _ = rx.await; }) .await .unwrap(); }); Ok((addr, tx, handle)) } async fn start_pd_server( region: Region, leader: Store, ) -> Result< (SocketAddr, oneshot::Sender<()>, tokio::task::JoinHandle<()>), Box, > { let listener = TcpListener::bind("127.0.0.1:0").await?; let addr = listener.local_addr()?; let incoming = TcpListenerStream::new(listener); let (tx, rx) = oneshot::channel(); let tso_service = TestTsoService::default(); let pd_service = TestPdService { region, leader }; let handle = tokio::spawn(async move { Server::builder() .add_service(TsoServer::new(tso_service)) .add_service(PdServer::new(pd_service)) .serve_with_incoming_shutdown(incoming, async { let _ = rx.await; }) .await .unwrap(); }); Ok((addr, tx, handle)) } #[tokio::test(flavor = "multi_thread")] async fn test_rpc_connect() -> Result<(), Box> { let kv_service = TestKvService::default(); let (kv_addr, kv_shutdown, kv_handle) = start_kv_server(kv_service).await?; let leader = Store { id: 1, addr: kv_addr.to_string(), }; let region = Region { id: 1, start_key: Vec::new(), end_key: Vec::new(), peers: vec![1], leader_id: 1, }; let (pd_addr, pd_shutdown, pd_handle) = start_pd_server(region, leader).await?; let mut client = RdbClient::connect_with_pd(kv_addr.to_string(), pd_addr.to_string()).await?; let ts = client.get_tso().await?; assert!(ts > 0); client.raw_put(b"k1".to_vec(), b"v1".to_vec()).await?; let got = client.raw_get(b"k1".to_vec()).await?; assert_eq!(got, Some(b"v1".to_vec())); let (ok, current, new_version) = client.cas(b"cas_key".to_vec(), b"v1".to_vec(), 0).await?; assert!(ok); assert_eq!(current, 0); assert_eq!(new_version, 1); let (ok2, current2, _) = client.cas(b"cas_key".to_vec(), b"v2".to_vec(), 0).await?; assert!(!ok2); assert_eq!(current2, 1); let cas_val = client.cas_get(b"cas_key".to_vec()).await?; assert_eq!(cas_val, Some((1, b"v1".to_vec()))); let _ = kv_shutdown.send(()); let _ = pd_shutdown.send(()); kv_handle.await?; pd_handle.await?; Ok(()) }