photoncloud-monorepo/flaredb/crates/flaredb-client/tests/test_rpc_connect.rs
centra 5c6eb04a46 T036: Add VM cluster deployment configs for nixos-anywhere
- netboot-base.nix with SSH key auth
- Launch scripts for node01/02/03
- Node configuration.nix and disko.nix
- Nix modules for first-boot automation

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-11 09:59:19 +09:00

329 lines
9.5 KiB
Rust

use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc,
};
use flaredb_client::RdbClient;
use flaredb_proto::kvrpc::kv_cas_server::KvCas;
use flaredb_proto::kvrpc::kv_cas_server::KvCasServer;
use flaredb_proto::kvrpc::kv_raw_server::KvRaw;
use flaredb_proto::kvrpc::kv_raw_server::KvRawServer;
use flaredb_proto::kvrpc::{
CasRequest, CasResponse, DeleteRequest, DeleteResponse, GetRequest, GetResponse, RawDeleteRequest,
RawDeleteResponse, RawGetRequest, RawGetResponse, RawPutRequest, RawPutResponse, RawScanRequest,
RawScanResponse, ScanRequest, ScanResponse,
};
use flaredb_proto::pdpb::pd_server::Pd;
use flaredb_proto::pdpb::pd_server::PdServer;
use flaredb_proto::pdpb::tso_server::Tso;
use flaredb_proto::pdpb::tso_server::TsoServer;
use flaredb_proto::pdpb::{
GetRegionRequest, GetRegionResponse, ListRegionsRequest, ListRegionsResponse, Region,
RegisterStoreRequest, RegisterStoreResponse, Store, TsoRequest, TsoResponse,
};
use tokio::net::TcpListener;
use tokio::sync::{oneshot, Mutex};
use tokio_stream::wrappers::TcpListenerStream;
use tonic::transport::Server;
use tonic::{Request, Response, Status};
/// In-memory stand-in for a FlareDB KV node, used by this test only.
///
/// Every clone shares the same two maps through `Arc`, so the `KvRaw` and `KvCas`
/// servers built from clones of one instance see the same data.
#[derive(Default, Clone)]
struct TestKvService {
    /// Raw keyspace: key -> value, read and written by the `KvRaw` impl.
    raw: Arc<Mutex<HashMap<Vec<u8>, Vec<u8>>>>,
    /// Versioned keyspace: key -> (version, value), read and written by the `KvCas` impl.
    cas: Arc<Mutex<HashMap<Vec<u8>, (u64, Vec<u8>)>>>,
}
#[tonic::async_trait]
impl KvRaw for TestKvService {
async fn raw_put(
&self,
request: Request<RawPutRequest>,
) -> Result<Response<RawPutResponse>, Status> {
let req = request.into_inner();
let mut raw = self.raw.lock().await;
raw.insert(req.key, req.value);
Ok(Response::new(RawPutResponse { success: true }))
}
async fn raw_get(
&self,
request: Request<RawGetRequest>,
) -> Result<Response<RawGetResponse>, Status> {
let req = request.into_inner();
let raw = self.raw.lock().await;
if let Some(val) = raw.get(&req.key) {
Ok(Response::new(RawGetResponse {
found: true,
value: val.clone(),
}))
} else {
Ok(Response::new(RawGetResponse {
found: false,
value: Vec::new(),
}))
}
}
async fn raw_scan(
&self,
_request: Request<RawScanRequest>,
) -> Result<Response<RawScanResponse>, Status> {
Ok(Response::new(RawScanResponse {
keys: vec![],
values: vec![],
has_more: false,
next_key: vec![],
}))
}
async fn raw_delete(
&self,
request: Request<RawDeleteRequest>,
) -> Result<Response<RawDeleteResponse>, Status> {
let req = request.into_inner();
let mut raw = self.raw.lock().await;
let existed = raw.remove(&req.key).is_some();
Ok(Response::new(RawDeleteResponse {
success: true,
existed,
}))
}
}
#[tonic::async_trait]
impl KvCas for TestKvService {
async fn compare_and_swap(
&self,
request: Request<CasRequest>,
) -> Result<Response<CasResponse>, Status> {
let req = request.into_inner();
let mut cas = self.cas.lock().await;
let (current_version, _) = cas.get(&req.key).cloned().unwrap_or((0, Vec::new()));
if current_version != req.expected_version {
return Ok(Response::new(CasResponse {
success: false,
current_version,
new_version: 0,
}));
}
let new_version = current_version + 1;
cas.insert(req.key, (new_version, req.value));
Ok(Response::new(CasResponse {
success: true,
current_version,
new_version,
}))
}
async fn get(&self, request: Request<GetRequest>) -> Result<Response<GetResponse>, Status> {
let req = request.into_inner();
let cas = self.cas.lock().await;
if let Some((ver, val)) = cas.get(&req.key) {
Ok(Response::new(GetResponse {
found: true,
value: val.clone(),
version: *ver,
}))
} else {
Ok(Response::new(GetResponse {
found: false,
value: Vec::new(),
version: 0,
}))
}
}
async fn scan(&self, _request: Request<ScanRequest>) -> Result<Response<ScanResponse>, Status> {
Ok(Response::new(ScanResponse {
entries: vec![],
has_more: false,
next_key: vec![],
}))
}
async fn delete(
&self,
request: Request<DeleteRequest>,
) -> Result<Response<DeleteResponse>, Status> {
let req = request.into_inner();
let mut cas = self.cas.lock().await;
let (current_version, existed) = if let Some((ver, _)) = cas.remove(&req.key) {
(ver, true)
} else {
(0, false)
};
Ok(Response::new(DeleteResponse {
success: true,
existed,
current_version,
}))
}
}
/// Placement-driver stub that knows about exactly one region and one store.
#[derive(Clone)]
struct TestPdService {
    /// The single region returned by `get_region` and `list_regions`.
    region: Region,
    /// Store reported as the region's leader; its id is also echoed by `register_store`.
    leader: Store,
}
#[tonic::async_trait]
impl Pd for TestPdService {
    /// Accepts any registration, answering with the leader's store id and cluster id 1.
    async fn register_store(
        &self,
        _request: Request<RegisterStoreRequest>,
    ) -> Result<Response<RegisterStoreResponse>, Status> {
        let reply = RegisterStoreResponse {
            cluster_id: 1,
            store_id: self.leader.id,
        };
        Ok(Response::new(reply))
    }

    /// Ignores the requested key and always answers with the one configured region/leader.
    async fn get_region(
        &self,
        _request: Request<GetRegionRequest>,
    ) -> Result<Response<GetRegionResponse>, Status> {
        let reply = GetRegionResponse {
            leader: Some(self.leader.clone()),
            region: Some(self.region.clone()),
        };
        Ok(Response::new(reply))
    }

    /// Lists the single configured region together with its leader store.
    async fn list_regions(
        &self,
        _request: Request<ListRegionsRequest>,
    ) -> Result<Response<ListRegionsResponse>, Status> {
        let reply = ListRegionsResponse {
            stores: vec![self.leader.clone()],
            regions: vec![self.region.clone()],
        };
        Ok(Response::new(reply))
    }
}
/// Timestamp-oracle stub backed by a single shared atomic counter.
#[derive(Default, Clone)]
struct TestTsoService {
    /// Number of timestamps handed out so far; starts at 0 and is shared between clones.
    counter: Arc<AtomicU64>,
}
#[tonic::async_trait]
impl Tso for TestTsoService {
    /// Reserves a batch of consecutive timestamps and returns the first one.
    ///
    /// The batch size is the requested `count`, raised to at least 1. Timestamps start
    /// at 1 because the counter starts at 0 and the reply is `previous + 1`.
    async fn get_timestamp(
        &self,
        request: Request<TsoRequest>,
    ) -> Result<Response<TsoResponse>, Status> {
        let batch = request.into_inner().count.max(1) as u64;
        let previous = self.counter.fetch_add(batch, Ordering::AcqRel);
        let reply = TsoResponse {
            count: batch as u32,
            timestamp: previous + 1,
        };
        Ok(Response::new(reply))
    }
}
/// Spawns an in-process gRPC server exposing `KvRaw` and `KvCas`, both backed by
/// `service`, on an ephemeral port of 127.0.0.1.
///
/// Returns the bound address, a shutdown sender, and the server task's handle. The
/// server stops gracefully once the sender is used or dropped, because `rx.await`
/// completes in either case.
///
/// # Errors
/// Returns an error if binding the listener or reading its local address fails.
///
/// # Panics
/// The spawned task panics if the server fails while serving; awaiting the returned
/// handle then yields a `JoinError`.
async fn start_kv_server(
    service: TestKvService,
) -> Result<
    (SocketAddr, oneshot::Sender<()>, tokio::task::JoinHandle<()>),
    Box<dyn std::error::Error>,
> {
    let listener = TcpListener::bind("127.0.0.1:0").await?;
    let addr = listener.local_addr()?;
    let incoming = TcpListenerStream::new(listener);
    let (tx, rx) = oneshot::channel();
    // Both services share the same backing maps. One clone suffices because `service`
    // itself is not used again and can be moved into the second server.
    let raw_service = service.clone();
    let cas_service = service;
    let handle = tokio::spawn(async move {
        Server::builder()
            .add_service(KvRawServer::new(raw_service))
            .add_service(KvCasServer::new(cas_service))
            .serve_with_incoming_shutdown(incoming, async {
                // Ok = explicit shutdown signal, Err = sender dropped; both mean stop.
                let _ = rx.await;
            })
            .await
            .expect("test KV gRPC server failed while serving");
    });
    Ok((addr, tx, handle))
}
/// Spawns an in-process gRPC server exposing `Tso` and `Pd` on an ephemeral port of
/// 127.0.0.1.
///
/// The PD service reports `region`, led by `leader`. The TSO service starts from a
/// fresh counter.
///
/// Returns the bound address, a shutdown sender, and the server task's handle. The
/// server stops gracefully once the sender is used or dropped, because `rx.await`
/// completes in either case.
///
/// # Errors
/// Returns an error if binding the listener or reading its local address fails.
///
/// # Panics
/// The spawned task panics if the server fails while serving; awaiting the returned
/// handle then yields a `JoinError`.
async fn start_pd_server(
    region: Region,
    leader: Store,
) -> Result<
    (SocketAddr, oneshot::Sender<()>, tokio::task::JoinHandle<()>),
    Box<dyn std::error::Error>,
> {
    let listener = TcpListener::bind("127.0.0.1:0").await?;
    let addr = listener.local_addr()?;
    let incoming = TcpListenerStream::new(listener);
    let (tx, rx) = oneshot::channel();
    let tso_service = TestTsoService::default();
    let pd_service = TestPdService { region, leader };
    let handle = tokio::spawn(async move {
        Server::builder()
            .add_service(TsoServer::new(tso_service))
            .add_service(PdServer::new(pd_service))
            .serve_with_incoming_shutdown(incoming, async {
                // Ok = explicit shutdown signal, Err = sender dropped; both mean stop.
                let _ = rx.await;
            })
            .await
            .expect("test PD/TSO gRPC server failed while serving");
    });
    Ok((addr, tx, handle))
}
/// End-to-end smoke test of `RdbClient` against in-process fake servers.
///
/// It starts a KV server and a PD/TSO server on ephemeral ports and connects the
/// client through PD. It then checks TSO, a raw put/get round trip, and CAS semantics:
/// the first write succeeds, a stale retry is rejected, and the stored value is readable.
#[tokio::test(flavor = "multi_thread")]
async fn test_rpc_connect() -> Result<(), Box<dyn std::error::Error>> {
    // Bring up the fake cluster: one KV store (id 1) leading one region.
    let (kv_addr, kv_shutdown, kv_handle) = start_kv_server(TestKvService::default()).await?;
    let leader = Store {
        id: 1,
        addr: kv_addr.to_string(),
    };
    // Empty start/end keys — presumably an unbounded key range; confirm against PD docs.
    let region = Region {
        id: 1,
        start_key: Vec::new(),
        end_key: Vec::new(),
        peers: vec![1],
        leader_id: 1,
    };
    let (pd_addr, pd_shutdown, pd_handle) = start_pd_server(region, leader).await?;

    let mut client = RdbClient::connect_with_pd(kv_addr.to_string(), pd_addr.to_string()).await?;

    // TSO: the stub hands out timestamps starting at 1.
    let timestamp = client.get_tso().await?;
    assert!(timestamp > 0);

    // Raw KV round trip.
    client.raw_put(b"k1".to_vec(), b"v1".to_vec()).await?;
    let fetched = client.raw_get(b"k1".to_vec()).await?;
    assert_eq!(fetched, Some(b"v1".to_vec()));

    // CAS against a missing key (version 0) succeeds and bumps the version to 1.
    let cas_key = b"cas_key".to_vec();
    let (swapped, observed, new_version) = client.cas(cas_key.clone(), b"v1".to_vec(), 0).await?;
    assert!(swapped);
    assert_eq!(observed, 0);
    assert_eq!(new_version, 1);

    // Retrying with the now-stale version 0 is rejected and reports version 1.
    let (stale_swapped, stale_observed, _) =
        client.cas(cas_key.clone(), b"v2".to_vec(), 0).await?;
    assert!(!stale_swapped);
    assert_eq!(stale_observed, 1);

    // The stored value is still the first write.
    let stored = client.cas_get(cas_key).await?;
    assert_eq!(stored, Some((1, b"v1".to_vec())));

    // Signal both servers first, then join them (KV before PD).
    for shutdown in [kv_shutdown, pd_shutdown] {
        let _ = shutdown.send(());
    }
    for handle in [kv_handle, pd_handle] {
        handle.await?;
    }
    Ok(())
}