Initial commit

Soma Nakamura 2026-02-13 17:08:17 +09:00
commit 9a5d8ca8ba
Signed by: centra
GPG key ID: 0C09689D20B25ACA
24 changed files with 10049 additions and 0 deletions

.gitignore (vendored, 3 additions)

@@ -0,0 +1,3 @@
target/
state.json
*.log

Cargo.lock (generated, 2695 additions): diff suppressed because it is too large

Cargo.toml (35 additions)

@@ -0,0 +1,35 @@
[package]
name = "lightscale-client"
version = "0.1.0"
edition = "2021"
[dependencies]
anyhow = "1"
base64 = "0.22"
boringtun = { version = "0.7", features = ["device"] }
clap = { version = "4", features = ["derive", "env"] }
dirs = "5"
ed25519-dalek = { version = "2", features = ["rand_core"] }
futures-util = "0.3"
hmac = "0.12"
hickory-proto = "0.24"
hex = "0.4"
ipnet = "2"
md-5 = "0.10"
rand = "0.8"
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
rustls = "0.23"
rtnetlink = "0.20"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
sha2 = "0.10"
sha1 = "0.10"
socket2 = "0.5"
time = { version = "0.3", features = ["serde", "formatting"] }
tokio = { version = "1", features = ["io-util", "macros", "rt-multi-thread", "signal"] }
tokio-rustls = "0.26"
url = "2"
webpki-roots = "0.26"
wireguard-control = "1.7"
x25519-dalek = { version = "2", features = ["static_secrets"] }
netlink-packet-route = "0.28"

README.md (287 additions)

@@ -0,0 +1,287 @@
# lightscale-client
Minimal control-plane client for Lightscale. It registers nodes, sends heartbeats, fetches
netmaps, and can manage the WireGuard data plane (kernel or userspace) with basic NAT traversal.
State is kept in profile-scoped files, so multiple networks can be supported later by
running separate profiles.
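Each profile keeps its own state, so a second network is simply a second profile. For example (the profile name and addresses are illustrative; the individual commands are documented below):
```sh
cargo run -- --profile homelab init http://192.0.2.10:8080
cargo run -- --profile homelab register <token> --node-name nas
```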
## Configure a profile
```sh
cargo run -- --profile default init http://127.0.0.1:8080
```
Multiple control URLs (failover):
```sh
cargo run -- --profile default init http://10.0.0.1:8080,http://10.0.0.1:8081
```
## Register a node
```sh
cargo run -- --profile default register <token> --node-name laptop
```
Register a node with an auth URL flow:
```sh
cargo run -- --profile default register-url <network_id> --node-name laptop
```
The command prints a one-time approval URL. Open it in a browser (or curl it) to approve the node.
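For example, assuming the printed URL is reachable from your shell (the path below is only a placeholder for whatever the command prints):
```sh
curl "http://127.0.0.1:8080/<printed-approval-path>"
```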
## Admin actions
Set an admin token when the control plane is protected (CLI flag or env var):
```sh
export LIGHTSCALE_ADMIN_TOKEN=<token>
```
List nodes in a network (use `--pending` to show only unapproved nodes):
```sh
cargo run -- --profile default admin nodes <network_id> --pending
```
Update a node's name or tags (admin):
```sh
cargo run -- --profile default admin node update <node_id> --name laptop --tags dev,lab
```
Clear tags:
```sh
cargo run -- --profile default admin node update <node_id> --clear-tags
```
Approve a node by ID:
```sh
cargo run -- --profile default admin approve <node_id>
```
Create an enrollment token:
```sh
cargo run -- --profile default admin token create <network_id> --ttl-seconds 3600 --uses 1 --tags lab
```
Revoke an enrollment token:
```sh
cargo run -- --profile default admin token revoke <token>
```
## Heartbeat
```sh
cargo run -- --profile default heartbeat \
--endpoint 203.0.113.1:51820 \
--route 192.168.10.0/24
```
Optionally include your WireGuard listen port so the server can record the public endpoint
it observes on the heartbeat connection:
```sh
cargo run -- --profile default heartbeat --listen-port 51820
```
Use STUN to discover a public endpoint (best effort):
```sh
cargo run -- --profile default heartbeat --stun --stun-server stun.l.google.com:19302
```
Advertise exit node routes:
```sh
cargo run -- --profile default heartbeat --exit-node
```
## Fetch netmap
```sh
cargo run -- --profile default netmap
```
## Show status
```sh
cargo run -- --profile default status
```
Include WireGuard peer info (handshake age + endpoint):
```sh
cargo run -- --profile default status --wg
```
## Configure WireGuard (Linux)
Bring up an interface using the latest netmap:
```sh
sudo cargo run -- --profile default wg-up --listen-port 51820
```
Use boringtun (userspace WireGuard) instead of the kernel module:
```sh
sudo cargo run -- --profile default wg-up --listen-port 51820 --backend boringtun
```
This runs the userspace tunnel inside the client process. Keep the command
running (or use `agent`) to keep the tunnel alive.
Apply advertised subnet/exit routes at the same time:
```sh
sudo cargo run -- --profile default wg-up --listen-port 51820 --apply-routes --accept-exit-node
```
Optionally probe peers to trigger NAT traversal (UDP probe, no ICMP):
```sh
sudo cargo run -- --profile default wg-up --listen-port 51820 --probe-peers
```
Conflicting routes are skipped by default; use `--allow-route-conflicts` to force them.
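For example, to force routes that overlap existing ones:
```sh
sudo cargo run -- --profile default wg-up --listen-port 51820 --apply-routes --allow-route-conflicts
```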
Select a specific exit node by ID or name:
```sh
sudo cargo run -- --profile default wg-up --listen-port 51820 --apply-routes --accept-exit-node \
--exit-node-id <peer_id>
```
Remove the interface:
```sh
sudo cargo run -- --profile default wg-down
```
If you used the boringtun backend, stop the process that created the tunnel
(for example, press `Ctrl+C` in the foreground or stop the agent). The command
below also attempts to remove the interface if it is still present.
```sh
sudo cargo run -- --profile default wg-down --backend boringtun
```
## Run the agent loop
Keep WireGuard and routes updated using long-polling + periodic heartbeats:
```sh
sudo cargo run -- --profile default agent --listen-port 51820 --apply-routes --accept-exit-node \
--heartbeat-interval 30 --longpoll-timeout 30
```
Tune endpoint rotation (seconds before an endpoint is considered stale, and the maximum rotations to try before falling back to a relay):
```sh
sudo cargo run -- --profile default agent --listen-port 51820 \
--endpoint-stale-after 15 --endpoint-max-rotations 2
```
Use boringtun backend in the agent:
```sh
sudo cargo run -- --profile default agent --listen-port 51820 --backend boringtun
```
Enable STUN discovery in the agent:
```sh
sudo cargo run -- --profile default agent --listen-port 51820 --stun \
--stun-server stun.l.google.com:19302
```
Enable stream relay signaling (peer probe via relay):
```sh
sudo cargo run -- --profile default agent --listen-port 51820 --stream-relay
```
With `--stream-relay`, the agent also maintains local relay tunnels that can be
used as a fallback when direct endpoints stop handshaking.
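For example, combining relay signaling with the rotation settings shown earlier so that stale endpoints fall back to the relay tunnels (values illustrative):
```sh
sudo cargo run -- --profile default agent --listen-port 51820 --stream-relay \
  --endpoint-stale-after 15 --endpoint-max-rotations 2
```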
Probe peers when netmap updates arrive (UDP probe to endpoints, no ICMP):
```sh
sudo cargo run -- --profile default agent --listen-port 51820 --probe-peers
```
## Enable subnet/exit routing (Linux)
Configure IP forwarding and (optionally) SNAT for a subnet router or exit node.
This applies nftables rules via the `nft` command (the Nix dev shell installs `libmnl`/`libnftnl`):
```sh
sudo cargo run -- --profile default router enable --interface ls-default --out-interface eth0
```
Disable SNAT to require return routes on the LAN:
```sh
sudo cargo run -- --profile default router enable --interface ls-default --out-interface eth0 --no-snat
```
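Without SNAT, hosts on the LAN (or the LAN gateway) need a return route for the overlay prefix pointing back at the subnet router. A sketch with illustrative addresses:
```sh
# on the LAN gateway: send overlay traffic back via the subnet router's LAN address
sudo ip route add 100.64.0.0/10 via 192.168.10.2
```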
Remove forwarding/NAT rules:
```sh
sudo cargo run -- --profile default router disable --interface ls-default --out-interface eth0
```
## DNS and relay info
Export host-style DNS entries:
```sh
cargo run -- --profile default dns
```
Export DNS info as JSON (debug output):
```sh
cargo run -- --profile default dns --format json --output /tmp/lightscale-dns.json
```
Show relay configuration (STUN/TURN/stream relay/UDP relay):
```sh
cargo run -- --profile default relay
```
## UDP relay (best effort)
Send a test message via the UDP relay:
```sh
cargo run -- --profile default relay-udp send <peer-id> "hello"
```
Listen for relay messages:
```sh
cargo run -- --profile default relay-udp listen
```
## Stream relay (best effort)
Send a test message via the stream relay:
```sh
cargo run -- --profile default relay-stream send <peer-id> "hello"
```
Listen for relay messages:
```sh
cargo run -- --profile default relay-stream listen
```

flake.lock (generated, 61 additions)

@@ -0,0 +1,61 @@
{
"nodes": {
"flake-utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1768564909,
"narHash": "sha256-Kell/SpJYVkHWMvnhqJz/8DqQg2b6PguxVWOuadbHCc=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "e4bae1bd10c9c57b2cf517953ab70060a828ee6f",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {
"inputs": {
"flake-utils": "flake-utils",
"nixpkgs": "nixpkgs"
}
},
"systems": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
}
},
"root": "root",
"version": 7
}

flake.nix (28 additions)

@@ -0,0 +1,28 @@
{
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
flake-utils.url = "github:numtide/flake-utils";
};
outputs = { self, nixpkgs, flake-utils }:
flake-utils.lib.eachDefaultSystem (system:
let
pkgs = import nixpkgs { inherit system; };
in
{
devShells.default = pkgs.mkShell {
buildInputs = [
pkgs.rustc
pkgs.cargo
pkgs.rustfmt
pkgs.clippy
pkgs.rust-analyzer
pkgs.iproute2
pkgs.iputils
pkgs.libmnl
pkgs.libnftnl
pkgs.pkg-config
];
};
});
}
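The dev shell declared above should be usable with the standard flake command (assuming flakes are enabled on your Nix installation):
```sh
nix develop
```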

src/config.rs (60 additions)

@@ -0,0 +1,60 @@
use anyhow::Result;
use serde::{Deserialize, Deserializer, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
#[derive(Clone, Serialize, Deserialize, Default)]
pub struct ClientConfig {
pub profiles: HashMap<String, ProfileConfig>,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct ProfileConfig {
#[serde(default, deserialize_with = "deserialize_control_urls", alias = "control_url")]
pub control_urls: Vec<String>,
#[serde(default)]
pub tls_pinned_sha256: Option<String>,
}
fn deserialize_control_urls<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(untagged)]
enum ControlUrls {
One(String),
Many(Vec<String>),
}
let raw = Option::<ControlUrls>::deserialize(deserializer)?;
let mut urls = match raw {
Some(ControlUrls::One(url)) => vec![url],
Some(ControlUrls::Many(urls)) => urls,
None => Vec::new(),
};
urls.retain(|url| !url.trim().is_empty());
Ok(urls)
}
pub fn default_config_path() -> Option<PathBuf> {
dirs::config_dir().map(|dir| dir.join("lightscale").join("config.json"))
}
pub fn load_config(path: &Path) -> Result<ClientConfig> {
match std::fs::read_to_string(path) {
Ok(contents) => Ok(serde_json::from_str(&contents)?),
Err(_) => Ok(ClientConfig::default()),
}
}
pub fn save_config(path: &Path, config: &ClientConfig) -> Result<()> {
if let Some(parent) = path.parent() {
if !parent.as_os_str().is_empty() {
std::fs::create_dir_all(parent)?;
}
}
let json = serde_json::to_string_pretty(config)?;
std::fs::write(path, json)?;
Ok(())
}

src/control.rs (573 additions)

@@ -0,0 +1,573 @@
use crate::model::{
AclPolicy, AdminNodesResponse, ApproveNodeResponse, AuditLogResponse, CreateTokenRequest,
CreateTokenResponse, EnrollmentToken, HeartbeatRequest, HeartbeatResponse, KeyHistoryResponse,
KeyPolicyResponse, KeyRotationPolicy, KeyRotationRequest, KeyRotationResponse, NetMap,
RegisterRequest, RegisterResponse, RegisterUrlRequest, RegisterUrlResponse, RevokeNodeResponse,
UpdateAclRequest, UpdateAclResponse, UpdateNodeRequest, UpdateNodeResponse,
};
use anyhow::{anyhow, Context, Result};
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::client::WebPkiServerVerifier;
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use rustls::{RootCertStore, SignatureScheme};
use sha2::{Digest, Sha256};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
pub struct ControlClient {
base_urls: Vec<String>,
client: reqwest::Client,
next_index: AtomicUsize,
node_token: Option<String>,
admin_token: Option<String>,
}
impl ControlClient {
pub fn new(
base_urls: Vec<String>,
tls_pin: Option<String>,
node_token: Option<String>,
admin_token: Option<String>,
) -> Result<Self> {
let client = build_http_client(tls_pin)?;
let base_urls = normalize_base_urls(base_urls);
if base_urls.is_empty() {
return Err(anyhow!("no control URL configured"));
}
Ok(Self {
base_urls,
client,
next_index: AtomicUsize::new(0),
node_token,
admin_token,
})
}
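// Try each control URL in turn, starting from the last index that succeeded.
// 5xx responses and connect/timeout/request errors move on to the next URL;
// other failures are returned immediately.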
async fn send_with_failover<F>(&self, build: F) -> Result<reqwest::Response>
where
F: Fn(&reqwest::Client, &str) -> reqwest::RequestBuilder,
{
let total = self.base_urls.len();
let start = self.next_index.load(Ordering::Relaxed) % total;
let mut last_err: Option<anyhow::Error> = None;
for offset in 0..total {
let index = (start + offset) % total;
let base = &self.base_urls[index];
let response = build(&self.client, base).send().await;
match response {
Ok(resp) => {
if resp.status().is_server_error() {
last_err = Some(anyhow!(
"control {} returned {}",
base,
resp.status()
));
continue;
}
self.next_index.store(index, Ordering::Relaxed);
return Ok(resp);
}
Err(err) => {
if should_retry(&err) {
last_err = Some(anyhow!("control {} request failed: {}", base, err));
continue;
}
return Err(anyhow!(err).context(format!(
"control {} request failed",
base
)));
}
}
}
Err(last_err.unwrap_or_else(|| anyhow!("no control servers available")))
}
fn endpoint_at(base: &str, path: &str) -> String {
let base = base.trim_end_matches('/');
format!("{}{}", base, path)
}
fn node_auth(&self) -> Option<&str> {
self.node_token.as_deref()
}
fn admin_auth(&self) -> Option<&str> {
self.admin_token.as_deref()
}
fn node_or_admin_auth(&self) -> Option<&str> {
self.node_token
.as_deref()
.or_else(|| self.admin_token.as_deref())
}
pub async fn register(&self, request: RegisterRequest) -> Result<RegisterResponse> {
let response = self
.send_with_failover(|client, base| {
client
.post(Self::endpoint_at(base, "/v1/register"))
.json(&request)
})
.await?
.error_for_status()
.context("register request failed")?;
Ok(response.json().await?)
}
pub async fn register_url(&self, request: RegisterUrlRequest) -> Result<RegisterUrlResponse> {
let response = self
.send_with_failover(|client, base| {
client
.post(Self::endpoint_at(base, "/v1/register-url"))
.json(&request)
})
.await?
.error_for_status()
.context("register-url request failed")?;
Ok(response.json().await?)
}
pub async fn create_token(
&self,
network_id: &str,
request: CreateTokenRequest,
) -> Result<CreateTokenResponse> {
let response = self
.send_with_failover(|client, base| {
with_bearer(
client
.post(Self::endpoint_at(
base,
&format!("/v1/networks/{}/tokens", network_id),
))
.json(&request),
self.admin_auth(),
)
})
.await?
.error_for_status()
.context("create token request failed")?;
Ok(response.json().await?)
}
pub async fn revoke_token(&self, token_id: &str) -> Result<EnrollmentToken> {
let response = self
.send_with_failover(|client, base| {
with_bearer(
client.post(Self::endpoint_at(
base,
&format!("/v1/tokens/{}/revoke", token_id),
)),
self.admin_auth(),
)
})
.await?
.error_for_status()
.context("revoke token request failed")?;
Ok(response.json().await?)
}
pub async fn approve_node(&self, node_id: &str) -> Result<ApproveNodeResponse> {
let response = self
.send_with_failover(|client, base| {
with_bearer(
client.post(Self::endpoint_at(
base,
&format!("/v1/admin/nodes/{}/approve", node_id),
)),
self.admin_auth(),
)
})
.await?
.error_for_status()
.context("approve node request failed")?;
Ok(response.json().await?)
}
pub async fn admin_nodes(&self, network_id: &str) -> Result<AdminNodesResponse> {
let response = self
.send_with_failover(|client, base| {
with_bearer(
client.get(Self::endpoint_at(
base,
&format!("/v1/admin/networks/{}/nodes", network_id),
)),
self.admin_auth(),
)
})
.await?
.error_for_status()
.context("admin nodes request failed")?;
Ok(response.json().await?)
}
pub async fn update_node(
&self,
node_id: &str,
request: UpdateNodeRequest,
) -> Result<UpdateNodeResponse> {
let response = self
.send_with_failover(|client, base| {
with_bearer(
client
.put(Self::endpoint_at(
base,
&format!("/v1/admin/nodes/{}", node_id),
))
.json(&request),
self.admin_auth(),
)
})
.await?
.error_for_status()
.context("update node request failed")?;
Ok(response.json().await?)
}
pub async fn get_acl(&self, network_id: &str) -> Result<AclPolicy> {
let response = self
.send_with_failover(|client, base| {
with_bearer(
client.get(Self::endpoint_at(
base,
&format!("/v1/networks/{}/acl", network_id),
)),
self.admin_auth(),
)
})
.await?
.error_for_status()
.context("acl policy request failed")?;
Ok(response.json().await?)
}
pub async fn update_acl(
&self,
network_id: &str,
request: UpdateAclRequest,
) -> Result<UpdateAclResponse> {
let response = self
.send_with_failover(|client, base| {
with_bearer(
client
.put(Self::endpoint_at(
base,
&format!("/v1/networks/{}/acl", network_id),
))
.json(&request),
self.admin_auth(),
)
})
.await?
.error_for_status()
.context("acl policy update failed")?;
Ok(response.json().await?)
}
pub async fn get_key_policy(&self, network_id: &str) -> Result<KeyPolicyResponse> {
let response = self
.send_with_failover(|client, base| {
with_bearer(
client.get(Self::endpoint_at(
base,
&format!("/v1/networks/{}/key-policy", network_id),
)),
self.admin_auth(),
)
})
.await?
.error_for_status()
.context("key policy request failed")?;
Ok(response.json().await?)
}
pub async fn update_key_policy(
&self,
network_id: &str,
request: KeyRotationPolicy,
) -> Result<KeyPolicyResponse> {
let response = self
.send_with_failover(|client, base| {
with_bearer(
client
.put(Self::endpoint_at(
base,
&format!("/v1/networks/{}/key-policy", network_id),
))
.json(&request),
self.admin_auth(),
)
})
.await?
.error_for_status()
.context("key policy update failed")?;
Ok(response.json().await?)
}
pub async fn rotate_keys(
&self,
node_id: &str,
request: KeyRotationRequest,
) -> Result<KeyRotationResponse> {
let response = self
.send_with_failover(|client, base| {
with_bearer(
client
.post(Self::endpoint_at(
base,
&format!("/v1/nodes/{}/rotate-keys", node_id),
))
.json(&request),
self.node_or_admin_auth(),
)
})
.await?
.error_for_status()
.context("key rotation failed")?;
Ok(response.json().await?)
}
pub async fn revoke_node(&self, node_id: &str) -> Result<RevokeNodeResponse> {
let response = self
.send_with_failover(|client, base| {
with_bearer(
client.post(Self::endpoint_at(
base,
&format!("/v1/nodes/{}/revoke", node_id),
)),
self.admin_auth(),
)
})
.await?
.error_for_status()
.context("revoke node failed")?;
Ok(response.json().await?)
}
pub async fn node_keys(&self, node_id: &str) -> Result<KeyHistoryResponse> {
let response = self
.send_with_failover(|client, base| {
with_bearer(
client.get(Self::endpoint_at(
base,
&format!("/v1/nodes/{}/keys", node_id),
)),
self.node_or_admin_auth(),
)
})
.await?
.error_for_status()
.context("key history request failed")?;
Ok(response.json().await?)
}
pub async fn audit_log(
&self,
network_id: Option<&str>,
node_id: Option<&str>,
limit: Option<usize>,
) -> Result<AuditLogResponse> {
let mut params = Vec::new();
if let Some(network_id) = network_id {
params.push(("network_id", network_id.to_string()));
}
if let Some(node_id) = node_id {
params.push(("node_id", node_id.to_string()));
}
if let Some(limit) = limit {
params.push(("limit", limit.to_string()));
}
let response = self
.send_with_failover(|client, base| {
with_bearer(
client
.get(Self::endpoint_at(base, "/v1/audit"))
.query(&params),
self.admin_auth(),
)
})
.await?
.error_for_status()
.context("audit log request failed")?;
Ok(response.json().await?)
}
pub async fn heartbeat(&self, request: HeartbeatRequest) -> Result<HeartbeatResponse> {
let response = self
.send_with_failover(|client, base| {
with_bearer(
client
.post(Self::endpoint_at(base, "/v1/heartbeat"))
.json(&request),
self.node_or_admin_auth(),
)
})
.await?
.error_for_status()
.context("heartbeat request failed")?;
Ok(response.json().await?)
}
pub async fn netmap(&self, node_id: &str) -> Result<NetMap> {
let response = self
.send_with_failover(|client, base| {
with_bearer(
client.get(Self::endpoint_at(
base,
&format!("/v1/netmap/{}", node_id),
)),
self.node_or_admin_auth(),
)
})
.await?
.error_for_status()
.context("netmap request failed")?;
Ok(response.json().await?)
}
pub async fn netmap_longpoll(
&self,
node_id: &str,
since: u64,
timeout_seconds: u64,
) -> Result<NetMap> {
let response = self
.send_with_failover(|client, base| {
with_bearer(
client
.get(Self::endpoint_at(
base,
&format!("/v1/netmap/{}/longpoll", node_id),
))
.query(&[
("since", since.to_string()),
("timeout_seconds", timeout_seconds.to_string()),
]),
self.node_or_admin_auth(),
)
})
.await?
.error_for_status()
.context("netmap longpoll request failed")?;
Ok(response.json().await?)
}
}
fn normalize_base_urls(urls: Vec<String>) -> Vec<String> {
let mut unique = Vec::new();
for url in urls {
let trimmed = url.trim().to_string();
if trimmed.is_empty() {
continue;
}
if !unique.contains(&trimmed) {
unique.push(trimmed);
}
}
unique
}
fn should_retry(err: &reqwest::Error) -> bool {
err.is_connect() || err.is_timeout() || err.is_request()
}
fn with_bearer(
request: reqwest::RequestBuilder,
token: Option<&str>,
) -> reqwest::RequestBuilder {
if let Some(token) = token {
request.bearer_auth(token)
} else {
request
}
}
#[derive(Debug)]
struct PinnedServerCertVerifier {
inner: Arc<WebPkiServerVerifier>,
pin: Vec<u8>,
}
impl ServerCertVerifier for PinnedServerCertVerifier {
fn verify_server_cert(
&self,
end_entity: &CertificateDer<'_>,
intermediates: &[CertificateDer<'_>],
server_name: &ServerName<'_>,
ocsp_response: &[u8],
now: UnixTime,
) -> Result<ServerCertVerified, rustls::Error> {
let verified = self
.inner
.verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now)?;
let mut hasher = Sha256::new();
hasher.update(end_entity.as_ref());
let digest = hasher.finalize();
if digest.as_slice() != self.pin.as_slice() {
return Err(rustls::Error::General("tls pin mismatch".to_string()));
}
Ok(verified)
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
self.inner.verify_tls12_signature(message, cert, dss)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
self.inner.verify_tls13_signature(message, cert, dss)
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
self.inner.supported_verify_schemes()
}
}
fn build_http_client(tls_pin: Option<String>) -> Result<reqwest::Client> {
if let Some(pin) = tls_pin {
let expected = decode_pin(&pin)?;
let mut roots = RootCertStore::empty();
roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let verifier = WebPkiServerVerifier::builder(Arc::new(roots.clone()))
.build()
.map_err(|err| anyhow!("failed to build tls verifier: {}", err))?;
let mut config = rustls::ClientConfig::builder()
.with_root_certificates(roots)
.with_no_client_auth();
config
.dangerous()
.set_certificate_verifier(Arc::new(PinnedServerCertVerifier {
inner: verifier,
pin: expected,
}));
Ok(reqwest::Client::builder()
.use_preconfigured_tls(config)
.build()?)
} else {
Ok(reqwest::Client::new())
}
}
fn decode_pin(pin: &str) -> Result<Vec<u8>> {
let normalized: String = pin
.chars()
.filter(|ch| !ch.is_whitespace() && *ch != ':')
.collect();
let bytes = hex::decode(normalized).map_err(|_| anyhow!("invalid tls pin hex"))?;
if bytes.len() != 32 {
return Err(anyhow!("tls pin must be 32 bytes (sha256)"));
}
Ok(bytes)
}

src/dns_server.rs (148 additions)

@@ -0,0 +1,148 @@
use crate::model::NetMap;
use anyhow::{anyhow, Context, Result};
use hickory_proto::op::{Message, MessageType, ResponseCode};
use hickory_proto::rr::rdata::{A, AAAA};
use hickory_proto::rr::{Name, RData, Record, RecordType};
use std::net::{IpAddr, SocketAddr};
use std::sync::{Arc, Mutex};
use tokio::net::UdpSocket;
const DNS_TTL_SECONDS: u32 = 30;
pub fn spawn(addr: SocketAddr, netmap: NetMap) -> Result<Arc<Mutex<NetMap>>> {
let state = Arc::new(Mutex::new(netmap));
let state_task = Arc::clone(&state);
tokio::spawn(async move {
if let Err(err) = serve(addr, state_task).await {
eprintln!("dns server stopped: {}", err);
}
});
Ok(state)
}
pub async fn serve(addr: SocketAddr, state: Arc<Mutex<NetMap>>) -> Result<()> {
let socket = UdpSocket::bind(addr)
.await
.with_context(|| format!("dns listen {} failed", addr))?;
let mut buf = vec![0u8; 512];
loop {
let (len, peer) = socket.recv_from(&mut buf).await?;
let request = match Message::from_vec(&buf[..len]) {
Ok(msg) => msg,
Err(_) => continue,
};
let response = build_response(&request, &state)?;
let out = response.to_vec()?;
let _ = socket.send_to(&out, peer).await;
}
}
pub fn apply_resolver(interface: &str, domain: &str, server: IpAddr) -> Result<()> {
let domain = domain.trim_end_matches('.');
let routed_domain = format!("~{}", domain);
run_resolvectl(&["dns", interface, &server.to_string()])?;
run_resolvectl(&["domain", interface, &routed_domain])?;
Ok(())
}
fn run_resolvectl(args: &[&str]) -> Result<()> {
let output = std::process::Command::new("resolvectl").args(args).output();
let output = match output {
Ok(output) => output,
Err(err) => return Err(anyhow!("resolvectl failed: {}", err)),
};
if output.status.success() {
return Ok(());
}
let stderr = String::from_utf8_lossy(&output.stderr);
Err(anyhow!("resolvectl failed: {}", stderr.trim()))
}
fn build_response(request: &Message, state: &Arc<Mutex<NetMap>>) -> Result<Message> {
let mut response = Message::new();
response.set_id(request.id());
response.set_message_type(MessageType::Response);
response.set_op_code(request.op_code());
response.set_recursion_desired(request.recursion_desired());
response.set_recursion_available(false);
let netmap = state.lock().map_err(|_| anyhow!("dns state poisoned"))?;
let domain = normalize_name(&netmap.network.dns_domain);
let mut answered = false;
let mut any_within_domain = false;
for query in request.queries() {
response.add_query(query.clone());
let name = normalize_name(&query.name().to_ascii());
let within_domain =
name == domain || name.ends_with(&format!(".{}", domain));
if within_domain {
any_within_domain = true;
}
let Some(addrs) = lookup_name(&netmap, &name) else {
continue;
};
for addr in addrs {
match (query.query_type(), addr) {
(RecordType::A, IpAddr::V4(_)) => {
response.add_answer(build_record(query.name(), addr));
answered = true;
}
(RecordType::AAAA, IpAddr::V6(_)) => {
response.add_answer(build_record(query.name(), addr));
answered = true;
}
(RecordType::ANY, IpAddr::V4(_)) => {
response.add_answer(build_record(query.name(), addr));
answered = true;
}
(RecordType::ANY, IpAddr::V6(_)) => {
response.add_answer(build_record(query.name(), addr));
answered = true;
}
_ => {}
}
}
}
let response_code = if answered {
ResponseCode::NoError
} else if any_within_domain {
ResponseCode::NXDomain
} else {
ResponseCode::Refused
};
response.set_response_code(response_code);
response.set_authoritative(true);
Ok(response)
}
fn build_record(name: &Name, addr: IpAddr) -> Record {
let rdata = match addr {
IpAddr::V4(v4) => RData::A(A(v4)),
IpAddr::V6(v6) => RData::AAAA(AAAA(v6)),
};
Record::from_rdata(name.clone(), DNS_TTL_SECONDS, rdata)
}
fn lookup_name(netmap: &NetMap, name: &str) -> Option<Vec<IpAddr>> {
let node_name = normalize_name(&netmap.node.dns_name);
if name == node_name {
return Some(vec![
netmap.node.ipv4.parse().ok()?,
netmap.node.ipv6.parse().ok()?,
]);
}
for peer in &netmap.peers {
let peer_name = normalize_name(&peer.dns_name);
if name == peer_name {
return Some(vec![
peer.ipv4.parse().ok()?,
peer.ipv6.parse().ok()?,
]);
}
}
None
}
fn normalize_name(name: &str) -> String {
name.trim_end_matches('.').to_lowercase()
}

src/firewall.rs (272 additions)

@@ -0,0 +1,272 @@
use anyhow::{anyhow, Result};
#[cfg(target_os = "linux")]
mod imp {
use super::*;
use ipnet::IpNet;
use std::process::Command;
const FILTER_TABLE: &str = "lightscale";
const FILTER_CHAIN: &str = "ls-forward";
const NAT_TABLE: &str = "lightscale-nat";
const NAT_CHAIN: &str = "ls-postrouting";
const MAP_PREROUTING_CHAIN: &str = "ls-map-prerouting";
const MAP_POSTROUTING_CHAIN: &str = "ls-map-postrouting";
pub fn reset_tables() -> Result<()> {
if run_nft(&["list", "table", "inet", FILTER_TABLE]).is_ok() {
run_nft(&["delete", "table", "inet", FILTER_TABLE])?;
}
if run_nft(&["list", "table", "ip", NAT_TABLE]).is_ok() {
run_nft(&["delete", "table", "ip", NAT_TABLE])?;
}
Ok(())
}
pub fn apply_forwarding_rules(wg_interface: &str, out_interface: &str) -> Result<()> {
ensure_filter_table()?;
ensure_filter_chain()?;
run_nft(&["flush", "chain", "inet", FILTER_TABLE, FILTER_CHAIN])?;
run_nft(&[
"add",
"rule",
"inet",
FILTER_TABLE,
FILTER_CHAIN,
"iifname",
wg_interface,
"oifname",
out_interface,
"accept",
])?;
run_nft(&[
"add",
"rule",
"inet",
FILTER_TABLE,
FILTER_CHAIN,
"iifname",
out_interface,
"oifname",
wg_interface,
"ct",
"state",
"established,related",
"accept",
])?;
Ok(())
}
pub fn apply_snat(out_interface: &str) -> Result<()> {
ensure_nat_table()?;
ensure_nat_chain(NAT_CHAIN, "postrouting", "100")?;
run_nft(&["flush", "chain", "ip", NAT_TABLE, NAT_CHAIN])?;
run_nft(&[
"add",
"rule",
"ip",
NAT_TABLE,
NAT_CHAIN,
"oifname",
out_interface,
"masquerade",
])?;
Ok(())
}
pub fn apply_netmap(
wg_interface: &str,
_out_interface: &str,
maps: &[(IpNet, IpNet)],
) -> Result<()> {
if maps.is_empty() {
return Ok(());
}
ensure_nat_table()?;
ensure_nat_chain(MAP_PREROUTING_CHAIN, "prerouting", "-100")?;
ensure_nat_chain(MAP_POSTROUTING_CHAIN, "postrouting", "90")?;
run_nft(&["flush", "chain", "ip", NAT_TABLE, MAP_PREROUTING_CHAIN])?;
run_nft(&["flush", "chain", "ip", NAT_TABLE, MAP_POSTROUTING_CHAIN])?;
for (real, mapped) in maps {
let (real, mapped) = match (real, mapped) {
(IpNet::V4(real), IpNet::V4(mapped)) => (real, mapped),
_ => {
return Err(anyhow!(
"netmap only supports IPv4 prefixes in this build"
))
}
};
let prefix_len = mapped.prefix_len();
let host_mask = ipv4_host_mask(prefix_len);
let mapped_base = mapped.network();
let real_base = real.network();
run_nft(&[
"add",
"rule",
"ip",
NAT_TABLE,
MAP_PREROUTING_CHAIN,
"iifname",
wg_interface,
"ip",
"daddr",
&mapped.to_string(),
"dnat",
"to",
"ip",
"daddr",
"&",
&host_mask.to_string(),
"|",
&real_base.to_string(),
])?;
run_nft(&[
"add",
"rule",
"ip",
NAT_TABLE,
MAP_POSTROUTING_CHAIN,
"oifname",
wg_interface,
"ip",
"saddr",
&real.to_string(),
"snat",
"to",
"ip",
"saddr",
"&",
&host_mask.to_string(),
"|",
&mapped_base.to_string(),
])?;
}
Ok(())
}
fn ensure_filter_table() -> Result<()> {
if run_nft(&["list", "table", "inet", FILTER_TABLE]).is_ok() {
return Ok(());
}
run_nft(&["add", "table", "inet", FILTER_TABLE])?;
Ok(())
}
fn ensure_filter_chain() -> Result<()> {
if run_nft(&["list", "chain", "inet", FILTER_TABLE, FILTER_CHAIN]).is_ok() {
return Ok(());
}
run_nft(&[
"add",
"chain",
"inet",
FILTER_TABLE,
FILTER_CHAIN,
"{",
"type",
"filter",
"hook",
"forward",
"priority",
"10",
";",
"policy",
"drop",
";",
"}",
])?;
Ok(())
}
fn ensure_nat_table() -> Result<()> {
if run_nft(&["list", "table", "ip", NAT_TABLE]).is_ok() {
return Ok(());
}
run_nft(&["add", "table", "ip", NAT_TABLE])?;
Ok(())
}
fn ensure_nat_chain(name: &str, hook: &str, priority: &str) -> Result<()> {
if run_nft(&["list", "chain", "ip", NAT_TABLE, name]).is_ok() {
return Ok(());
}
run_nft(&[
"add",
"chain",
"ip",
NAT_TABLE,
name,
"{",
"type",
"nat",
"hook",
hook,
"priority",
priority,
";",
"policy",
"accept",
";",
"}",
])?;
Ok(())
}
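// Mask covering the host bits for the given prefix length (the inverse of the netmask).
// apply_netmap uses it to keep the host part of an address while swapping the network part.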
fn ipv4_host_mask(prefix_len: u8) -> std::net::Ipv4Addr {
if prefix_len >= 32 {
return std::net::Ipv4Addr::from(0);
}
let mask = if prefix_len == 0 {
u32::MAX
} else {
u32::MAX >> prefix_len
};
std::net::Ipv4Addr::from(mask)
}
fn run_nft(args: &[&str]) -> Result<()> {
let output = Command::new("nft").args(args).output();
let output = match output {
Ok(output) => output,
Err(err) => return Err(anyhow!("nft command failed: {}", err)),
};
if output.status.success() {
return Ok(());
}
let stderr = String::from_utf8_lossy(&output.stderr);
Err(anyhow!(
"nft command failed: {}",
stderr.trim().to_string()
))
}
}
#[cfg(target_os = "linux")]
pub use imp::{apply_forwarding_rules, apply_netmap, apply_snat, reset_tables};
#[cfg(not(target_os = "linux"))]
mod imp {
use super::*;
pub fn reset_tables() -> Result<()> {
Err(anyhow!("router firewall is only supported on linux"))
}
pub fn apply_forwarding_rules(_wg_interface: &str, _out_interface: &str) -> Result<()> {
Err(anyhow!("router firewall is only supported on linux"))
}
pub fn apply_snat(_out_interface: &str) -> Result<()> {
Err(anyhow!("router firewall is only supported on linux"))
}
pub fn apply_netmap(
_wg_interface: &str,
_out_interface: &str,
_maps: &[(ipnet::IpNet, ipnet::IpNet)],
) -> Result<()> {
Err(anyhow!("router firewall is only supported on linux"))
}
}
#[cfg(not(target_os = "linux"))]
pub use imp::{apply_forwarding_rules, apply_netmap, apply_snat, reset_tables};

src/keys.rs (28 additions)

@@ -0,0 +1,28 @@
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use ed25519_dalek::SigningKey;
use rand::rngs::OsRng;
use x25519_dalek::{PublicKey, StaticSecret};
pub struct KeyPair {
pub private_key: String,
pub public_key: String,
}
pub fn generate_machine_keys() -> KeyPair {
let signing = SigningKey::generate(&mut OsRng);
let verifying = signing.verifying_key();
KeyPair {
private_key: STANDARD.encode(signing.to_bytes()),
public_key: STANDARD.encode(verifying.to_bytes()),
}
}
pub fn generate_wg_keys() -> KeyPair {
let secret = StaticSecret::random_from_rng(&mut OsRng);
let public = PublicKey::from(&secret);
KeyPair {
private_key: STANDARD.encode(secret.to_bytes()),
public_key: STANDARD.encode(public.to_bytes()),
}
}

src/l2_relay.rs (108 additions)

@@ -0,0 +1,108 @@
use crate::model::NetMap;
use anyhow::{anyhow, Context, Result};
use socket2::{Domain, Protocol, Socket, Type};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::{Arc, Mutex};
use tokio::net::UdpSocket;
const MDNS_GROUP: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);
const MDNS_PORT: u16 = 5353;
const SSDP_GROUP: Ipv4Addr = Ipv4Addr::new(239, 255, 255, 250);
const SSDP_PORT: u16 = 1900;
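// Peer-to-peer relayed copies travel on (multicast port + RELAY_OFFSET), keeping the
// relay channel off the real mDNS/SSDP ports.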
const RELAY_OFFSET: u16 = 10000;
pub fn spawn(wg_ipv4: Ipv4Addr, netmap: NetMap) -> Result<Arc<Mutex<NetMap>>> {
let state = Arc::new(Mutex::new(netmap));
let mdns_state = Arc::clone(&state);
tokio::spawn(async move {
if let Err(err) = relay_group(MDNS_GROUP, MDNS_PORT, wg_ipv4, mdns_state).await {
eprintln!("l2 relay mdns stopped: {}", err);
}
});
let ssdp_state = Arc::clone(&state);
tokio::spawn(async move {
if let Err(err) = relay_group(SSDP_GROUP, SSDP_PORT, wg_ipv4, ssdp_state).await {
eprintln!("l2 relay ssdp stopped: {}", err);
}
});
Ok(state)
}
async fn relay_group(
group: Ipv4Addr,
port: u16,
wg_ipv4: Ipv4Addr,
state: Arc<Mutex<NetMap>>,
) -> Result<()> {
let relay_port = port.saturating_add(RELAY_OFFSET);
let local = build_multicast_socket(port, group, wg_ipv4)?;
let relay = UdpSocket::bind(SocketAddr::new(
IpAddr::V4(Ipv4Addr::UNSPECIFIED),
relay_port,
))
.await
.with_context(|| format!("l2 relay bind {} failed", relay_port))?;
let mut buf_local = vec![0u8; 2048];
let mut buf_relay = vec![0u8; 2048];
loop {
tokio::select! {
recv = local.recv_from(&mut buf_local) => {
let (len, src) = recv?;
if src.port() == relay_port {
continue;
}
let peers = peers_from_state(&state, wg_ipv4);
for peer in peers {
let target = SocketAddr::new(IpAddr::V4(peer), relay_port);
let _ = relay.send_to(&buf_local[..len], target).await;
}
}
recv = relay.recv_from(&mut buf_relay) => {
let (len, _) = recv?;
let target = SocketAddr::new(IpAddr::V4(group), port);
let _ = local.send_to(&buf_relay[..len], target).await;
}
}
}
}
fn peers_from_state(state: &Arc<Mutex<NetMap>>, self_ip: Ipv4Addr) -> Vec<Ipv4Addr> {
let guard = match state.lock() {
Ok(guard) => guard,
Err(_) => return Vec::new(),
};
guard
.peers
.iter()
.filter_map(|peer| peer.ipv4.parse().ok())
.filter(|ip| *ip != self_ip)
.collect()
}
fn build_multicast_socket(port: u16, group: Ipv4Addr, iface: Ipv4Addr) -> Result<UdpSocket> {
let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))
.context("l2 relay socket create failed")?;
socket
.set_reuse_address(true)
.context("l2 relay reuseaddr failed")?;
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port);
socket.bind(&addr.into()).context("l2 relay bind failed")?;
socket
.join_multicast_v4(&group, &iface)
.context("l2 relay multicast join failed")?;
socket
.set_multicast_loop_v4(true)
.context("l2 relay multicast loop failed")?;
socket
.set_nonblocking(true)
.context("l2 relay socket nonblocking failed")?;
let std_socket: std::net::UdpSocket = socket.into();
UdpSocket::from_std(std_socket).context("l2 relay tokio socket failed")
}
#[allow(dead_code)]
fn ensure_ipv4(value: &str) -> Result<Ipv4Addr> {
value
.parse()
.map_err(|_| anyhow!("invalid ipv4 address {}", value))
}

src/main.rs (2774 additions): diff suppressed because it is too large

src/model.rs (314 additions)

@@ -0,0 +1,314 @@
use serde::{Deserialize, Serialize};
#[derive(Clone, Serialize, Deserialize)]
pub struct Route {
pub prefix: String,
pub kind: RouteKind,
pub enabled: bool,
#[serde(default)]
pub mapped_prefix: Option<String>,
}
#[derive(Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RouteKind {
Subnet,
Exit,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct NetworkInfo {
pub id: String,
pub name: String,
pub overlay_v4: String,
pub overlay_v6: String,
pub dns_domain: String,
#[serde(default)]
pub requires_approval: bool,
#[serde(default)]
pub key_rotation_max_age_seconds: Option<u64>,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct NodeInfo {
pub id: String,
pub name: String,
pub dns_name: String,
pub ipv4: String,
pub ipv6: String,
pub wg_public_key: String,
pub machine_public_key: String,
pub endpoints: Vec<String>,
pub tags: Vec<String>,
pub routes: Vec<Route>,
pub last_seen: i64,
#[serde(default = "default_true")]
pub approved: bool,
#[serde(default)]
pub key_rotation_required: bool,
#[serde(default)]
pub revoked: bool,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct PeerInfo {
pub id: String,
pub name: String,
pub dns_name: String,
pub ipv4: String,
pub ipv6: String,
pub wg_public_key: String,
pub endpoints: Vec<String>,
pub tags: Vec<String>,
pub routes: Vec<Route>,
pub last_seen: i64,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct NetMap {
pub network: NetworkInfo,
pub node: NodeInfo,
pub peers: Vec<PeerInfo>,
pub relay: Option<RelayConfig>,
#[serde(default)]
pub probe_requests: Vec<ProbeRequest>,
pub generated_at: i64,
#[serde(default)]
pub revision: u64,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct ProbeRequest {
pub peer_id: String,
pub endpoints: Vec<String>,
pub ipv4: String,
pub ipv6: String,
pub requested_at: i64,
}
#[derive(Clone, Serialize, Deserialize, Default)]
pub struct RelayConfig {
pub stun_servers: Vec<String>,
pub turn_servers: Vec<String>,
#[serde(default)]
pub stream_relay_servers: Vec<String>,
#[serde(default)]
pub udp_relay_servers: Vec<String>,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct RegisterRequest {
pub token: String,
pub node_name: String,
pub machine_public_key: String,
pub wg_public_key: String,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct RegisterResponse {
pub node_token: String,
pub netmap: NetMap,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct RegisterUrlRequest {
pub network_id: String,
pub node_name: String,
pub machine_public_key: String,
pub wg_public_key: String,
pub ttl_seconds: Option<u64>,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct RegisterUrlResponse {
pub node_id: String,
pub network_id: String,
pub ipv4: String,
pub ipv6: String,
pub auth_path: String,
pub expires_at: i64,
pub node_token: String,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct EnrollmentToken {
pub token: String,
pub expires_at: i64,
pub uses_left: u32,
pub tags: Vec<String>,
pub revoked_at: Option<i64>,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct CreateTokenRequest {
pub ttl_seconds: u64,
pub uses: u32,
pub tags: Vec<String>,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct CreateTokenResponse {
pub token: EnrollmentToken,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct AdminNodesResponse {
pub nodes: Vec<NodeInfo>,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct ApproveNodeResponse {
pub node_id: String,
pub approved: bool,
pub approved_at: Option<i64>,
}
fn default_true() -> bool {
true
}
#[derive(Clone, Serialize, Deserialize)]
pub struct HeartbeatRequest {
pub node_id: String,
pub endpoints: Vec<String>,
pub listen_port: Option<u16>,
pub routes: Vec<Route>,
#[serde(default)]
pub probe: Option<bool>,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct HeartbeatResponse {
pub netmap: NetMap,
}
#[derive(Clone, Serialize, Deserialize, Default)]
pub struct AclPolicy {
#[serde(default)]
pub default_action: AclAction,
#[serde(default)]
pub rules: Vec<AclRule>,
}
#[derive(Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AclAction {
Allow,
Deny,
}
impl Default for AclAction {
fn default() -> Self {
Self::Allow
}
}
#[derive(Clone, Serialize, Deserialize, Default)]
pub struct AclSelector {
#[serde(default)]
pub any: bool,
#[serde(default)]
pub tags: Vec<String>,
#[serde(default)]
pub node_ids: Vec<String>,
#[serde(default)]
pub names: Vec<String>,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct AclRule {
pub action: AclAction,
#[serde(default)]
pub src: AclSelector,
#[serde(default)]
pub dst: AclSelector,
}
#[derive(Clone, Serialize, Deserialize, Default)]
pub struct UpdateAclRequest {
pub policy: AclPolicy,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct UpdateAclResponse {
pub policy: AclPolicy,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct UpdateNodeRequest {
pub name: Option<String>,
pub tags: Option<Vec<String>>,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct UpdateNodeResponse {
pub node: NodeInfo,
}
#[derive(Clone, Serialize, Deserialize, Default)]
pub struct KeyRotationPolicy {
#[serde(default)]
pub max_age_seconds: Option<u64>,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct KeyPolicyResponse {
pub policy: KeyRotationPolicy,
}
#[derive(Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum KeyType {
Machine,
WireGuard,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct KeyRecord {
pub key_type: KeyType,
pub public_key: String,
pub created_at: i64,
#[serde(default)]
pub revoked_at: Option<i64>,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct KeyRotationRequest {
pub machine_public_key: Option<String>,
pub wg_public_key: Option<String>,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct KeyRotationResponse {
pub node_id: String,
pub machine_public_key: String,
pub wg_public_key: String,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct KeyHistoryResponse {
pub node_id: String,
pub keys: Vec<KeyRecord>,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct RevokeNodeResponse {
pub node_id: String,
pub revoked_at: Option<i64>,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct AuditEntry {
pub id: String,
pub timestamp: i64,
pub network_id: Option<String>,
pub node_id: Option<String>,
pub action: String,
#[serde(default)]
pub detail: Option<serde_json::Value>,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct AuditLogResponse {
pub entries: Vec<AuditEntry>,
}

src/netlink.rs (486 additions)

@@ -0,0 +1,486 @@
use anyhow::{anyhow, Context, Result};
use ipnet::IpNet;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::time::Duration;
#[cfg(target_os = "linux")]
mod imp {
use super::*;
use futures_util::stream::TryStreamExt;
use netlink_packet_route::address::AddressAttribute;
use netlink_packet_route::{
link::LinkAttribute,
route::{RouteAddress, RouteAttribute, RouteMessage},
rule::{RuleAttribute, RuleUidRange},
AddressFamily,
};
use rtnetlink::{new_connection, Handle, LinkUnspec, RouteMessageBuilder};
use std::time::Instant;
use tokio::time::sleep;
#[derive(Clone)]
pub struct Netlink {
handle: Handle,
}
#[derive(Debug, Clone)]
pub struct RouteEntry {
pub prefix: IpNet,
pub oif: Option<u32>,
}
#[derive(Debug, Clone)]
pub struct InterfaceAddress {
pub addr: IpAddr,
#[allow(dead_code)]
pub prefix: u8,
}
impl Netlink {
pub async fn new() -> Result<Self> {
let (connection, handle, _) =
new_connection().context("failed to open netlink connection")?;
tokio::spawn(connection);
Ok(Netlink { handle })
}
pub async fn link_index(&self, name: &str) -> Result<Option<u32>> {
let mut links = self
.handle
.link()
.get()
.match_name(name.to_string())
.execute();
if let Some(link) = links.try_next().await? {
return Ok(Some(link.header.index));
}
Ok(None)
}
pub async fn link_name(&self, index: u32) -> Result<Option<String>> {
let mut links = self.handle.link().get().match_index(index).execute();
if let Some(link) = links.try_next().await? {
for attr in link.attributes {
if let LinkAttribute::IfName(name) = attr {
return Ok(Some(name));
}
}
}
Ok(None)
}
pub async fn wait_for_link(&self, name: &str, timeout: Duration) -> Result<u32> {
let start = Instant::now();
loop {
if let Some(index) = self.link_index(name).await? {
return Ok(index);
}
if start.elapsed() > timeout {
return Err(anyhow!("interface {} did not appear", name));
}
sleep(Duration::from_millis(100)).await;
}
}
pub async fn set_link_up(&self, index: u32) -> Result<()> {
let link = LinkUnspec::new_with_index(index).up().build();
self.handle.link().set(link).execute().await?;
Ok(())
}
pub async fn interface_addresses(&self, index: u32) -> Result<Vec<InterfaceAddress>> {
let mut addresses = self
.handle
.address()
.get()
.set_link_index_filter(index)
.execute();
let mut results = Vec::new();
while let Some(msg) = addresses.try_next().await? {
let mut selected = None;
for attr in msg.attributes {
match attr {
AddressAttribute::Local(addr) => {
selected = Some(addr);
break;
}
AddressAttribute::Address(addr) => {
if selected.is_none() {
selected = Some(addr);
}
}
_ => {}
}
}
let Some(addr) = selected else {
continue;
};
results.push(InterfaceAddress {
addr,
prefix: msg.header.prefix_len,
});
}
Ok(results)
}
pub async fn replace_address(
&self,
index: u32,
address: IpAddr,
prefix: u8,
) -> Result<()> {
self.handle
.address()
.add(index, address, prefix)
.replace()
.execute()
.await?;
Ok(())
}
pub async fn replace_route(&self, prefix: IpNet, index: u32) -> Result<()> {
self.replace_route_with_metric(prefix, index, None).await
}
pub async fn replace_route_with_metric(
&self,
prefix: IpNet,
index: u32,
metric: Option<u32>,
) -> Result<()> {
match prefix {
IpNet::V4(net) => {
let mut builder = RouteMessageBuilder::<Ipv4Addr>::new()
.destination_prefix(net.network(), net.prefix_len())
.output_interface(index);
if let Some(metric) = metric {
builder = builder.priority(metric);
}
let route = builder.build();
self.handle.route().add(route).replace().execute().await?;
}
IpNet::V6(net) => {
let mut builder = RouteMessageBuilder::<Ipv6Addr>::new()
.destination_prefix(net.network(), net.prefix_len())
.output_interface(index);
if let Some(metric) = metric {
builder = builder.priority(metric);
}
let route = builder.build();
self.handle.route().add(route).replace().execute().await?;
}
}
Ok(())
}
pub async fn replace_route_with_metric_table(
&self,
prefix: IpNet,
index: u32,
metric: Option<u32>,
table: u32,
) -> Result<()> {
match prefix {
IpNet::V4(net) => {
let mut builder = RouteMessageBuilder::<Ipv4Addr>::new()
.destination_prefix(net.network(), net.prefix_len())
.output_interface(index)
.table_id(table);
if let Some(metric) = metric {
builder = builder.priority(metric);
}
let route = builder.build();
self.handle.route().add(route).replace().execute().await?;
}
IpNet::V6(net) => {
let mut builder = RouteMessageBuilder::<Ipv6Addr>::new()
.destination_prefix(net.network(), net.prefix_len())
.output_interface(index)
.table_id(table);
if let Some(metric) = metric {
builder = builder.priority(metric);
}
let route = builder.build();
self.handle.route().add(route).replace().execute().await?;
}
}
Ok(())
}
pub async fn add_rule_for_prefix(
&self,
prefix: IpNet,
table: u32,
priority: u32,
) -> Result<()> {
match prefix {
IpNet::V4(net) => {
self.handle
.rule()
.add()
.table_id(table)
.priority(priority)
.v4()
.destination_prefix(net.network(), net.prefix_len())
.replace()
.execute()
.await?;
}
IpNet::V6(net) => {
self.handle
.rule()
.add()
.table_id(table)
.priority(priority)
.v6()
.destination_prefix(net.network(), net.prefix_len())
.replace()
.execute()
.await?;
}
}
Ok(())
}
pub async fn add_uid_rule_v4(
&self,
table: u32,
priority: u32,
start: u32,
end: u32,
) -> Result<()> {
let mut req = self
.handle
.rule()
.add()
.table_id(table)
.priority(priority)
.v4()
.replace();
req.message_mut()
.attributes
.push(RuleAttribute::UidRange(RuleUidRange { start, end }));
req.execute().await?;
Ok(())
}
pub async fn add_uid_rule_v6(
&self,
table: u32,
priority: u32,
start: u32,
end: u32,
) -> Result<()> {
let mut req = self
.handle
.rule()
.add()
.table_id(table)
.priority(priority)
.v6()
.replace();
req.message_mut()
.attributes
.push(RuleAttribute::UidRange(RuleUidRange { start, end }));
req.execute().await?;
Ok(())
}
pub async fn delete_link(&self, name: &str) -> Result<()> {
let mut links = self
.handle
.link()
.get()
.match_name(name.to_string())
.execute();
if let Some(link) = links.try_next().await? {
self.handle.link().del(link.header.index).execute().await?;
}
Ok(())
}
pub async fn list_routes(&self) -> Result<Vec<RouteEntry>> {
let mut entries = Vec::new();
entries.extend(self.list_routes_v4().await?);
entries.extend(self.list_routes_v6().await?);
Ok(entries)
}
async fn list_routes_v4(&self) -> Result<Vec<RouteEntry>> {
let mut entries = Vec::new();
let route = RouteMessageBuilder::<Ipv4Addr>::new().build();
let mut routes = self.handle.route().get(route).execute();
while let Some(route) = routes.try_next().await? {
if let Some(entry) = parse_route_message(route) {
entries.push(entry);
}
}
Ok(entries)
}
async fn list_routes_v6(&self) -> Result<Vec<RouteEntry>> {
let mut entries = Vec::new();
let route = RouteMessageBuilder::<Ipv6Addr>::new().build();
let mut routes = self.handle.route().get(route).execute();
while let Some(route) = routes.try_next().await? {
if let Some(entry) = parse_route_message(route) {
entries.push(entry);
}
}
Ok(entries)
}
}
fn parse_route_message(route: RouteMessage) -> Option<RouteEntry> {
let family = route.header.address_family;
let prefix_len = route.header.destination_prefix_length;
let mut destination = None;
let mut oif = None;
for attr in route.attributes {
match attr {
RouteAttribute::Destination(RouteAddress::Inet(addr)) => {
destination = Some(IpAddr::V4(addr));
}
RouteAttribute::Destination(RouteAddress::Inet6(addr)) => {
destination = Some(IpAddr::V6(addr));
}
RouteAttribute::Oif(index) => {
oif = Some(index);
}
_ => {}
}
}
let addr = match (destination, family) {
(Some(addr), _) => addr,
(None, AddressFamily::Inet) => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
(None, AddressFamily::Inet6) => IpAddr::V6(Ipv6Addr::UNSPECIFIED),
_ => return None,
};
let prefix = IpNet::new(addr, prefix_len).ok()?;
Some(RouteEntry { prefix, oif })
}
}
#[cfg(target_os = "linux")]
pub use imp::{InterfaceAddress, Netlink, RouteEntry};
#[cfg(not(target_os = "linux"))]
mod imp {
use super::*;
#[derive(Clone)]
pub struct Netlink;
#[derive(Debug, Clone)]
pub struct RouteEntry {
pub prefix: IpNet,
pub oif: Option<u32>,
}
#[derive(Debug, Clone)]
pub struct InterfaceAddress {
pub addr: IpAddr,
pub prefix: u8,
}
impl Netlink {
pub async fn new() -> Result<Self> {
Err(anyhow!("netlink is only supported on linux"))
}
pub async fn link_index(&self, _name: &str) -> Result<Option<u32>> {
Err(anyhow!("netlink is only supported on linux"))
}
pub async fn link_name(&self, _index: u32) -> Result<Option<String>> {
Err(anyhow!("netlink is only supported on linux"))
}
pub async fn wait_for_link(&self, _name: &str, _timeout: Duration) -> Result<u32> {
Err(anyhow!("netlink is only supported on linux"))
}
pub async fn set_link_up(&self, _index: u32) -> Result<()> {
Err(anyhow!("netlink is only supported on linux"))
}
pub async fn interface_addresses(&self, _index: u32) -> Result<Vec<InterfaceAddress>> {
Err(anyhow!("netlink is only supported on linux"))
}
pub async fn replace_address(
&self,
_index: u32,
_address: IpAddr,
_prefix: u8,
) -> Result<()> {
Err(anyhow!("netlink is only supported on linux"))
}
pub async fn replace_route(&self, _prefix: IpNet, _index: u32) -> Result<()> {
Err(anyhow!("netlink is only supported on linux"))
}
pub async fn replace_route_with_metric(
&self,
_prefix: IpNet,
_index: u32,
_metric: Option<u32>,
) -> Result<()> {
Err(anyhow!("netlink is only supported on linux"))
}
pub async fn replace_route_with_metric_table(
&self,
_prefix: IpNet,
_index: u32,
_metric: Option<u32>,
_table: u32,
) -> Result<()> {
Err(anyhow!("netlink is only supported on linux"))
}
pub async fn add_rule_for_prefix(
&self,
_prefix: IpNet,
_table: u32,
_priority: u32,
) -> Result<()> {
Err(anyhow!("netlink is only supported on linux"))
}
pub async fn add_uid_rule_v4(
&self,
_table: u32,
_priority: u32,
_start: u32,
_end: u32,
) -> Result<()> {
Err(anyhow!("netlink is only supported on linux"))
}
pub async fn add_uid_rule_v6(
&self,
_table: u32,
_priority: u32,
_start: u32,
_end: u32,
) -> Result<()> {
Err(anyhow!("netlink is only supported on linux"))
}
pub async fn delete_link(&self, _name: &str) -> Result<()> {
Err(anyhow!("netlink is only supported on linux"))
}
pub async fn list_routes(&self) -> Result<Vec<RouteEntry>> {
Err(anyhow!("netlink is only supported on linux"))
}
}
}
#[cfg(not(target_os = "linux"))]
pub use imp::{InterfaceAddress, Netlink, RouteEntry};

src/relay_tunnel.rs (178 additions)

@@ -0,0 +1,178 @@
use crate::model::PeerInfo;
use crate::stream_relay;
use anyhow::Result;
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tokio::net::{TcpStream, UdpSocket};
use tokio::task::JoinHandle;
use tokio::time::{sleep, Duration};
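// Magic prefix prepended to tunneled WireGuard datagrams so they can be told apart
// from other stream-relay payloads (see wrap_data / unwrap_data below).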
const DATA_MAGIC: &[u8; 4] = b"LSDP";
const RECONNECT_DELAY: Duration = Duration::from_secs(2);
pub struct RelayTunnelManager {
node_id: String,
servers: Vec<String>,
wg_listen_port: u16,
relay_ip: Option<IpAddr>,
tunnels: HashMap<String, RelayTunnel>,
}
struct RelayTunnel {
local_addr: SocketAddr,
_task: JoinHandle<()>,
}
impl RelayTunnelManager {
pub fn new(
node_id: String,
servers: Vec<String>,
wg_listen_port: u16,
relay_ip: Option<IpAddr>,
) -> Self {
Self {
node_id,
servers,
wg_listen_port,
relay_ip,
tunnels: HashMap::new(),
}
}
pub async fn ensure_for_peers(
&mut self,
peers: &[PeerInfo],
) -> Result<HashMap<String, SocketAddr>> {
let mut endpoints = HashMap::new();
for peer in peers {
let addr = self.ensure_peer(&peer.id).await?;
endpoints.insert(peer.id.clone(), addr);
}
Ok(endpoints)
}
async fn ensure_peer(&mut self, peer_id: &str) -> Result<SocketAddr> {
if let Some(tunnel) = self.tunnels.get(peer_id) {
return Ok(tunnel.local_addr);
}
let mut relay_ip = self
.relay_ip
.unwrap_or(IpAddr::V4(Ipv4Addr::LOCALHOST));
let socket = match UdpSocket::bind(SocketAddr::new(relay_ip, 0)).await {
Ok(socket) => socket,
Err(_) => {
relay_ip = IpAddr::V4(Ipv4Addr::LOCALHOST);
UdpSocket::bind(SocketAddr::new(relay_ip, 0)).await?
}
};
let local_addr = SocketAddr::new(relay_ip, socket.local_addr()?.port());
let node_id = self.node_id.clone();
let servers = self.servers.clone();
let peer_id_owned = peer_id.to_string();
let wg_listen_port = self.wg_listen_port;
let task = tokio::spawn(async move {
run_tunnel(node_id, peer_id_owned, servers, socket, wg_listen_port).await;
});
self.tunnels.insert(
peer_id.to_string(),
RelayTunnel {
local_addr,
_task: task,
},
);
Ok(local_addr)
}
}
async fn run_tunnel(
node_id: String,
peer_id: String,
servers: Vec<String>,
socket: UdpSocket,
wg_listen_port: u16,
) {
if servers.is_empty() {
eprintln!("stream relay tunnel missing servers");
return;
}
let wg_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), wg_listen_port);
let mut buf = vec![0u8; 65535];
let mut server_index: usize = 0;
loop {
let server = servers[server_index % servers.len()].clone();
match TcpStream::connect(&server).await {
Ok(mut stream) => {
if let Err(err) = stream_relay::write_register(&mut stream, &node_id).await {
eprintln!("stream relay register failed: {}", err);
sleep(RECONNECT_DELAY).await;
continue;
}
eprintln!("relay tunnel {} connected to {} for {}", node_id, server, peer_id);
let mut saw_send = false;
let mut saw_recv = false;
loop {
tokio::select! {
recv = socket.recv_from(&mut buf) => {
let (len, _) = match recv {
Ok(data) => data,
Err(_) => break,
};
if !saw_send {
eprintln!("relay tunnel {} -> {} forwarding {} bytes", node_id, peer_id, len);
saw_send = true;
}
let payload = wrap_data(&buf[..len]);
if stream_relay::write_send(&mut stream, &node_id, &peer_id, &payload).await.is_err() {
break;
}
}
deliver = stream_relay::read_deliver(&mut stream) => {
let delivered = match deliver {
Ok(Some(data)) => data,
Ok(None) => continue,
Err(_) => break,
};
if delivered.0 != peer_id {
continue;
}
let payload = match unwrap_data(&delivered.1) {
Some(payload) => payload,
None => continue,
};
if !saw_recv {
eprintln!("relay tunnel {} <- {} received {} bytes", node_id, peer_id, payload.len());
saw_recv = true;
}
let _ = socket.send_to(payload, wg_addr).await;
}
}
}
}
Err(err) => {
eprintln!("stream relay tunnel connect failed: {}", err);
}
}
server_index = server_index.wrapping_add(1);
sleep(RECONNECT_DELAY).await;
}
}
fn wrap_data(payload: &[u8]) -> Vec<u8> {
let mut out = Vec::with_capacity(DATA_MAGIC.len() + payload.len());
out.extend_from_slice(DATA_MAGIC);
out.extend_from_slice(payload);
out
}
fn unwrap_data(payload: &[u8]) -> Option<&[u8]> {
if payload.starts_with(DATA_MAGIC) {
Some(&payload[DATA_MAGIC.len()..])
} else {
None
}
}
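// Illustrative round-trip checks for the LSDP data framing above; these tests are a
// sketch added for clarity and were not part of the original commit.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn wrap_then_unwrap_returns_original_payload() {
        let payload = b"wireguard-packet";
        let wrapped = wrap_data(payload);
        assert_eq!(unwrap_data(&wrapped), Some(&payload[..]));
    }

    #[test]
    fn unwrap_rejects_missing_magic() {
        assert_eq!(unwrap_data(b"nope"), None);
    }
}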

167
src/router.rs Normal file
View file

@ -0,0 +1,167 @@
use crate::firewall;
use crate::netlink::{InterfaceAddress, Netlink, RouteEntry};
use anyhow::{anyhow, Context, Result};
use ipnet::IpNet;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::path::Path;
pub async fn resolve_out_interface(out_interface: Option<String>) -> Result<String> {
if let Some(name) = out_interface {
return Ok(name);
}
default_out_interface().await
}
pub async fn enable_forwarding(wg_interface: &str, out_interface: &str, snat: bool) -> Result<()> {
write_sysctl("/proc/sys/net/ipv4/ip_forward", "1")?;
write_sysctl("/proc/sys/net/ipv6/conf/all/forwarding", "1")?;
let netlink = Netlink::new().await?;
netlink
.link_index(wg_interface)
.await?
.ok_or_else(|| anyhow!("interface {} not found", wg_interface))?;
netlink
.link_index(out_interface)
.await?
.ok_or_else(|| anyhow!("interface {} not found", out_interface))?;
firewall::reset_tables()?;
firewall::apply_forwarding_rules(wg_interface, out_interface)?;
if snat {
firewall::apply_snat(out_interface)?;
}
Ok(())
}
pub async fn disable_forwarding(_wg_interface: &str, _out_interface: &str) -> Result<()> {
firewall::reset_tables()
}
pub async fn apply_route_maps(
wg_interface: &str,
out_interface: &str,
maps: &[(String, String)],
) -> Result<()> {
let mut parsed = Vec::new();
for (real, mapped) in maps {
let real_net: IpNet = real
.parse()
.with_context(|| format!("invalid route map prefix {}", real))?;
let mapped_net: IpNet = mapped
.parse()
.with_context(|| format!("invalid route map prefix {}", mapped))?;
let real_v4 = matches!(real_net, IpNet::V4(_));
let mapped_v4 = matches!(mapped_net, IpNet::V4(_));
if real_v4 != mapped_v4 {
return Err(anyhow!(
"route map ip versions must match ({} vs {})",
real,
mapped
));
}
if real_net.prefix_len() != mapped_net.prefix_len() {
return Err(anyhow!(
"route map prefix lengths must match ({} vs {})",
real,
mapped
));
}
parsed.push((real_net, mapped_net));
}
firewall::apply_netmap(wg_interface, out_interface, &parsed)?;
Ok(())
}
pub async fn interface_ips(out_interface: &str) -> Result<(Option<String>, Option<String>)> {
let netlink = Netlink::new().await?;
let index = netlink
.link_index(out_interface)
.await?
.ok_or_else(|| anyhow!("interface {} not found", out_interface))?;
let addrs = netlink.interface_addresses(index).await?;
let mut v4 = None;
let mut v6 = None;
for InterfaceAddress { addr, .. } in addrs {
match addr {
IpAddr::V4(ip) => {
if v4.is_none() && is_usable_ipv4(ip) {
v4 = Some(ip.to_string());
}
}
IpAddr::V6(ip) => {
if v6.is_none() && is_usable_ipv6(ip) {
v6 = Some(ip.to_string());
}
}
}
}
Ok((v4, v6))
}
async fn default_out_interface() -> Result<String> {
let netlink = Netlink::new().await?;
let routes = netlink.list_routes().await?;
let index = find_default_oif(&routes).ok_or_else(|| anyhow!("failed to detect default route interface"))?;
let name = netlink
.link_name(index)
.await?
.ok_or_else(|| anyhow!("default route interface not found"))?;
Ok(name)
}
fn find_default_oif(routes: &[RouteEntry]) -> Option<u32> {
for entry in routes {
if let IpNet::V4(net) = entry.prefix {
if net.prefix_len() == 0 {
if let Some(oif) = entry.oif {
return Some(oif);
}
}
}
}
for entry in routes {
if let IpNet::V6(net) = entry.prefix {
if net.prefix_len() == 0 {
if let Some(oif) = entry.oif {
return Some(oif);
}
}
}
}
None
}
fn write_sysctl(path: &str, value: &str) -> Result<()> {
if let Some(parent) = Path::new(path).parent() {
std::fs::create_dir_all(parent).ok();
}
std::fs::write(path, value)
.with_context(|| format!("failed to write sysctl {}", path))?;
Ok(())
}
fn is_usable_ipv4(ip: Ipv4Addr) -> bool {
if ip.is_loopback() {
return false;
}
let octets = ip.octets();
if octets[0] == 169 && octets[1] == 254 {
return false;
}
true
}
fn is_usable_ipv6(ip: Ipv6Addr) -> bool {
if ip.is_loopback() {
return false;
}
let seg0 = ip.segments()[0];
if (seg0 & 0xffc0) == 0xfe80 {
return false;
}
true
}
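// Illustrative tests (not in the original commit) for the address filters above:
// loopback and link-local addresses should be rejected, routable ones accepted.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn loopback_and_link_local_v4_are_unusable() {
        assert!(!is_usable_ipv4(Ipv4Addr::new(127, 0, 0, 1)));
        assert!(!is_usable_ipv4(Ipv4Addr::new(169, 254, 1, 1)));
        assert!(is_usable_ipv4(Ipv4Addr::new(192, 0, 2, 10)));
    }

    #[test]
    fn link_local_v6_is_unusable() {
        assert!(!is_usable_ipv6("fe80::1".parse().unwrap()));
        assert!(is_usable_ipv6("2001:db8::1".parse().unwrap()));
    }
}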

420
src/routes.rs Normal file
View file

@ -0,0 +1,420 @@
use crate::model::{NetMap, Route, RouteKind};
use crate::netlink::{Netlink, RouteEntry};
use anyhow::{anyhow, Result};
use ipnet::IpNet;
use std::collections::{HashMap, HashSet};
pub struct RouteApplyConfig {
pub interface: String,
pub accept_exit_node: bool,
pub exit_node_id: Option<String>,
pub exit_node_name: Option<String>,
pub exit_node_policy: ExitNodePolicy,
pub exit_node_tag: Option<String>,
pub exit_node_metric_base: u32,
pub exit_node_uid_range: Option<UidRange>,
pub allow_conflicts: bool,
pub route_table: Option<u32>,
pub route_rule_priority: u32,
pub exit_rule_priority: u32,
pub exit_uid_rule_priority: u32,
}
#[derive(Clone, Copy, Debug)]
pub struct UidRange {
pub start: u32,
pub end: u32,
}
#[derive(Clone, Copy, Debug)]
pub enum ExitNodePolicy {
First,
Latest,
Multi,
}
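/// Installs routes advertised in the netmap on `cfg.interface`: subnet routes are
/// added unless they conflict with existing kernel routes (or with routes applied
/// earlier in the same pass), and exit-node routes are added only for peers chosen
/// by `select_exit_peers`, optionally into a dedicated routing table with
/// prefix-based or UID-based policy rules.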
pub async fn apply_advertised_routes(netmap: &NetMap, cfg: &RouteApplyConfig) -> Result<()> {
let netlink = Netlink::new().await?;
let interface_index = netlink
.link_index(&cfg.interface)
.await?
.ok_or_else(|| anyhow!("interface {} not found", cfg.interface))?;
let existing_routes = netlink.list_routes().await?;
let selected_exit_peers = select_exit_peers(netmap, cfg);
let selected_exit_ids: HashSet<String> =
selected_exit_peers.iter().map(|peer| peer.peer_id.clone()).collect();
let selected_exit_metrics: HashMap<String, u32> = selected_exit_peers
.iter()
.filter_map(|peer| peer.metric.map(|metric| (peer.peer_id.clone(), metric)))
.collect();
let exit_requested = cfg.exit_node_id.is_some() || cfg.exit_node_name.is_some();
let tag_filtered = cfg.exit_node_tag.is_some();
if exit_requested && selected_exit_peers.is_empty() {
eprintln!("requested exit node not found; skipping exit routes");
}
if tag_filtered && selected_exit_peers.is_empty() {
eprintln!("exit node tag filter matched no peers; skipping exit routes");
}
let allow_exit_routes = if exit_requested || tag_filtered {
!selected_exit_peers.is_empty()
} else {
true
};
let allow_multiple_exit = matches!(cfg.exit_node_policy, ExitNodePolicy::Multi);
let mut exit_v4_applied = false;
let mut exit_v6_applied = false;
let mut conflict_count = 0;
let mut skipped_exit = false;
let mut applied_routes: Vec<IpNet> = Vec::new();
let mut exit_uid_rule_v4 = false;
let mut exit_uid_rule_v6 = false;
for peer in &netmap.peers {
let is_exit_peer = selected_exit_ids.is_empty() || selected_exit_ids.contains(&peer.id);
let exit_metric = selected_exit_metrics.get(&peer.id).cloned();
for route in &peer.routes {
if !route.enabled {
continue;
}
let apply_prefix = match route_apply_prefix(route) {
Ok(prefix) => prefix,
Err(err) => {
eprintln!(
"skipping route {} for peer {}: {}",
route.prefix, peer.id, err
);
continue;
}
};
match route.kind {
RouteKind::Subnet => {
if route_conflicts(apply_prefix, &existing_routes, interface_index)
|| route_conflicts_with_applied(apply_prefix, &applied_routes)
{
conflict_count += 1;
if !cfg.allow_conflicts {
continue;
}
}
let net = apply_route(
apply_prefix,
interface_index,
&netlink,
None,
cfg.route_table,
)
.await?;
applied_routes.push(net);
if let Some(table) = cfg.route_table {
netlink
.add_rule_for_prefix(net, table, cfg.route_rule_priority)
.await?;
}
}
RouteKind::Exit => {
if !cfg.accept_exit_node || !allow_exit_routes {
continue;
}
if !is_exit_peer {
skipped_exit = true;
continue;
}
if is_ipv6(apply_prefix) {
if exit_v6_applied && !allow_multiple_exit {
continue;
}
let net = apply_route(
apply_prefix,
interface_index,
&netlink,
exit_metric,
cfg.route_table,
)
.await?;
applied_routes.push(net);
exit_v6_applied = true;
if let Some(table) = cfg.route_table {
if let Some(uid_range) = cfg.exit_node_uid_range {
if !exit_uid_rule_v6 {
netlink
.add_uid_rule_v6(
table,
cfg.exit_uid_rule_priority,
uid_range.start,
uid_range.end,
)
.await?;
exit_uid_rule_v6 = true;
}
} else {
netlink
.add_rule_for_prefix(net, table, cfg.exit_rule_priority)
.await?;
}
}
} else {
if exit_v4_applied && !allow_multiple_exit {
continue;
}
let net = apply_route(
apply_prefix,
interface_index,
&netlink,
exit_metric,
cfg.route_table,
)
.await?;
applied_routes.push(net);
exit_v4_applied = true;
if let Some(table) = cfg.route_table {
if let Some(uid_range) = cfg.exit_node_uid_range {
if !exit_uid_rule_v4 {
netlink
.add_uid_rule_v4(
table,
cfg.exit_uid_rule_priority,
uid_range.start,
uid_range.end,
)
.await?;
exit_uid_rule_v4 = true;
}
} else {
netlink
.add_rule_for_prefix(net, table, cfg.exit_rule_priority)
.await?;
}
}
}
}
}
}
}
if conflict_count > 0 {
eprintln!(
"skipped {} conflicting route(s) (use --allow-route-conflicts to force)",
conflict_count
);
}
if skipped_exit {
eprintln!(
"exit node selection active; routes from other exit nodes were skipped"
);
}
Ok(())
}
pub fn selected_exit_peer_ids(netmap: &NetMap, cfg: &RouteApplyConfig) -> HashSet<String> {
if !cfg.accept_exit_node {
return HashSet::new();
}
let selected = select_exit_peers(netmap, cfg);
let exit_requested = cfg.exit_node_id.is_some() || cfg.exit_node_name.is_some();
let tag_filtered = cfg.exit_node_tag.is_some();
let allow_exit_routes = if exit_requested || tag_filtered {
!selected.is_empty()
} else {
true
};
if !allow_exit_routes {
return HashSet::new();
}
selected
.into_iter()
.map(|peer| peer.peer_id)
.collect()
}
fn route_apply_prefix(route: &Route) -> Result<&str> {
let Some(mapped) = route.mapped_prefix.as_deref() else {
return Ok(&route.prefix);
};
let real_net: IpNet = route.prefix.parse()?;
let mapped_net: IpNet = mapped.parse()?;
let real_v4 = matches!(real_net, IpNet::V4(_));
let mapped_v4 = matches!(mapped_net, IpNet::V4(_));
if real_v4 != mapped_v4 {
return Err(anyhow!("mapped prefix ip version mismatch"));
}
if real_net.prefix_len() != mapped_net.prefix_len() {
return Err(anyhow!("mapped prefix length mismatch"));
}
Ok(mapped)
}
struct ExitPeerSelection {
peer_id: String,
metric: Option<u32>,
}
fn select_exit_peers(netmap: &NetMap, cfg: &RouteApplyConfig) -> Vec<ExitPeerSelection> {
let mut candidates: Vec<&crate::model::PeerInfo> = netmap
.peers
.iter()
.filter(|peer| {
peer.routes
.iter()
.any(|route| matches!(route.kind, RouteKind::Exit))
})
.collect();
if let Some(tag) = cfg.exit_node_tag.as_ref() {
candidates.retain(|peer| peer.tags.iter().any(|peer_tag| peer_tag == tag));
}
if let Some(id) = cfg.exit_node_id.as_ref() {
return candidates
.into_iter()
.find(|peer| &peer.id == id)
.map(|peer| vec![ExitPeerSelection {
peer_id: peer.id.clone(),
metric: None,
}])
.unwrap_or_default();
}
if let Some(name) = cfg.exit_node_name.as_ref() {
return candidates
.into_iter()
.find(|peer| peer.name == *name)
.map(|peer| vec![ExitPeerSelection {
peer_id: peer.id.clone(),
metric: None,
}])
.unwrap_or_default();
}
match cfg.exit_node_policy {
ExitNodePolicy::Latest => {
candidates.sort_by_key(|peer| peer.last_seen);
candidates
.last()
.map(|peer| ExitPeerSelection {
peer_id: peer.id.clone(),
metric: None,
})
.into_iter()
.collect()
}
ExitNodePolicy::Multi => candidates
.into_iter()
.enumerate()
.map(|(idx, peer)| ExitPeerSelection {
peer_id: peer.id.clone(),
metric: Some(cfg.exit_node_metric_base.saturating_add(idx as u32)),
})
.collect(),
ExitNodePolicy::First => candidates
.into_iter()
.next()
.map(|peer| ExitPeerSelection {
peer_id: peer.id.clone(),
metric: None,
})
.into_iter()
.collect(),
}
}
async fn apply_route(
prefix: &str,
interface_index: u32,
netlink: &Netlink,
metric: Option<u32>,
table: Option<u32>,
) -> Result<IpNet> {
let net: IpNet = prefix.parse()?;
match table {
Some(table) => {
netlink
.replace_route_with_metric_table(net, interface_index, metric, table)
.await?;
}
None => {
netlink
.replace_route_with_metric(net, interface_index, metric)
.await?;
}
}
Ok(net)
}
fn route_conflicts(prefix: &str, existing: &[RouteEntry], interface_index: u32) -> bool {
let Ok(net) = prefix.parse::<IpNet>() else {
return false;
};
existing.iter().any(|route| {
if route.oif == Some(interface_index) {
return false;
}
if route.prefix.prefix_len() == 0 {
return false;
}
nets_overlap(&net, &route.prefix)
})
}
fn nets_overlap(a: &IpNet, b: &IpNet) -> bool {
match (a, b) {
(IpNet::V4(a4), IpNet::V4(b4)) => ranges_overlap(v4_range(a4), v4_range(b4)),
(IpNet::V6(a6), IpNet::V6(b6)) => ranges_overlap(v6_range(a6), v6_range(b6)),
_ => false,
}
}
fn v4_range(net: &ipnet::Ipv4Net) -> (u64, u64) {
let base = u64::from(u32::from(net.network()));
let host_bits = 32u32.saturating_sub(net.prefix_len() as u32);
let end = if host_bits == 32 {
u64::from(u32::MAX)
} else {
base + ((1u64 << host_bits) - 1)
};
(base, end)
}
fn v6_range(net: &ipnet::Ipv6Net) -> (u128, u128) {
let base = u128::from(net.network());
let host_bits = 128u32.saturating_sub(net.prefix_len() as u32);
let end = if host_bits == 128 {
u128::MAX
} else {
base + ((1u128 << host_bits) - 1)
};
(base, end)
}
fn ranges_overlap<T: Ord>(a: (T, T), b: (T, T)) -> bool {
a.0 <= b.1 && b.0 <= a.1
}
fn route_conflicts_with_applied(prefix: &str, applied: &[IpNet]) -> bool {
let Ok(net) = prefix.parse::<IpNet>() else {
return false;
};
applied.iter().any(|other| nets_overlap(&net, other))
}
fn is_ipv6(prefix: &str) -> bool {
prefix.contains(':')
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn overlaps_detected_for_subnets() {
let a: IpNet = "10.0.0.0/24".parse().unwrap();
let b: IpNet = "10.0.0.128/25".parse().unwrap();
assert!(nets_overlap(&a, &b));
}
#[test]
fn applied_conflict_detects_overlap() {
let applied: Vec<IpNet> = vec!["10.1.0.0/24".parse().unwrap()];
assert!(route_conflicts_with_applied("10.1.0.128/25", &applied));
}
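    // Additional illustrative checks (not in the original commit) for the range helpers.
    #[test]
    fn disjoint_subnets_do_not_overlap() {
        let a: IpNet = "10.0.0.0/24".parse().unwrap();
        let b: IpNet = "10.0.1.0/24".parse().unwrap();
        assert!(!nets_overlap(&a, &b));
    }
    #[test]
    fn v4_range_covers_full_subnet() {
        // A /24 spans exactly 256 addresses, so end - start should be 255.
        let net: ipnet::Ipv4Net = "10.0.0.0/24".parse().unwrap();
        let (start, end) = v4_range(&net);
        assert_eq!(end - start, 255);
    }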
}

49
src/state.rs Normal file
View file

@ -0,0 +1,49 @@
use crate::model::NetMap;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
#[derive(Clone, Serialize, Deserialize)]
pub struct ClientState {
pub profile: String,
pub network_id: String,
pub node_id: String,
pub node_name: String,
pub machine_private_key: String,
pub machine_public_key: String,
pub wg_private_key: String,
pub wg_public_key: String,
#[serde(default)]
pub node_token: Option<String>,
pub ipv4: String,
pub ipv6: String,
pub last_netmap: Option<NetMap>,
pub updated_at: i64,
}
pub fn default_state_dir(profile: &str) -> Option<PathBuf> {
// Keep state isolated per profile to allow future multi-network support.
dirs::data_dir().map(|dir| dir.join("lightscale").join(profile))
}
pub fn state_path(state_dir: &Path) -> PathBuf {
state_dir.join("state.json")
}
pub fn load_state(path: &Path) -> Result<Option<ClientState>> {
match std::fs::read_to_string(path) {
Ok(contents) => Ok(Some(serde_json::from_str(&contents)?)),
Err(_) => Ok(None),
}
}
pub fn save_state(path: &Path, state: &ClientState) -> Result<()> {
if let Some(parent) = path.parent() {
if !parent.as_os_str().is_empty() {
std::fs::create_dir_all(parent)?;
}
}
let json = serde_json::to_string_pretty(state)?;
std::fs::write(path, json)?;
Ok(())
}
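// Illustrative round-trip check (not in the original commit): state written by
// save_state should load back with the same contents. The temp-file name is arbitrary.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn save_then_load_round_trips() {
        let path = std::env::temp_dir().join("lightscale-state-roundtrip-test.json");
        let state = ClientState {
            profile: "default".into(),
            network_id: "net-1".into(),
            node_id: "node-1".into(),
            node_name: "laptop".into(),
            machine_private_key: "mpriv".into(),
            machine_public_key: "mpub".into(),
            wg_private_key: "wgpriv".into(),
            wg_public_key: "wgpub".into(),
            node_token: None,
            ipv4: "100.64.0.1".into(),
            ipv6: "fd00::1".into(),
            last_netmap: None,
            updated_at: 0,
        };
        save_state(&path, &state).unwrap();
        let loaded = load_state(&path).unwrap().unwrap();
        assert_eq!(loaded.node_id, state.node_id);
        assert_eq!(loaded.ipv4, state.ipv4);
        let _ = std::fs::remove_file(&path);
    }
}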

96
src/stream_relay.rs Normal file
View file

@ -0,0 +1,96 @@
use anyhow::{anyhow, Result};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
const MAGIC: &[u8; 4] = b"LSR2";
const TYPE_REGISTER: u8 = 1;
const TYPE_SEND: u8 = 2;
const TYPE_DELIVER: u8 = 3;
const HEADER_LEN: usize = 8;
const MAX_ID_LEN: usize = 64;
const MAX_FRAME_LEN: usize = 64 * 1024;
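// Wire format: each frame is a 4-byte big-endian length followed by a body of
// MAGIC (4) | type (1) | from_len (1) | to_len (1) | reserved (1) | from_id | to_id | payload,
// with the body capped at MAX_FRAME_LEN bytes.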
pub async fn write_register<W: AsyncWrite + Unpin>(stream: &mut W, node_id: &str) -> Result<()> {
let packet = build_packet(TYPE_REGISTER, node_id, "", &[])?;
write_frame(stream, &packet).await
}
pub async fn write_send<W: AsyncWrite + Unpin>(
stream: &mut W,
from_id: &str,
to_id: &str,
payload: &[u8],
) -> Result<()> {
let packet = build_packet(TYPE_SEND, from_id, to_id, payload)?;
write_frame(stream, &packet).await
}
pub async fn read_deliver<R: AsyncRead + Unpin>(
stream: &mut R,
) -> Result<Option<(String, Vec<u8>)>> {
let frame = read_frame(stream).await?;
Ok(parse_deliver(&frame))
}
async fn read_frame<R: AsyncRead + Unpin>(stream: &mut R) -> Result<Vec<u8>> {
let mut len_buf = [0u8; 4];
stream.read_exact(&mut len_buf).await?;
let len = u32::from_be_bytes(len_buf) as usize;
if len == 0 || len > MAX_FRAME_LEN {
return Err(anyhow!("invalid frame length {}", len));
}
let mut buf = vec![0u8; len];
stream.read_exact(&mut buf).await?;
Ok(buf)
}
async fn write_frame<W: AsyncWrite + Unpin>(stream: &mut W, body: &[u8]) -> Result<()> {
if body.is_empty() || body.len() > MAX_FRAME_LEN {
return Err(anyhow!("invalid frame length {}", body.len()));
}
let len = body.len() as u32;
stream.write_all(&len.to_be_bytes()).await?;
stream.write_all(body).await?;
Ok(())
}
fn parse_deliver(buf: &[u8]) -> Option<(String, Vec<u8>)> {
if buf.len() < HEADER_LEN {
return None;
}
if &buf[0..4] != MAGIC {
return None;
}
let msg_type = buf[4];
if msg_type != TYPE_DELIVER {
return None;
}
let from_len = buf[5] as usize;
let to_len = buf[6] as usize;
if from_len > MAX_ID_LEN || to_len > MAX_ID_LEN || to_len != 0 {
return None;
}
let offset = HEADER_LEN;
if buf.len() < offset + from_len + to_len {
return None;
}
let from_end = offset + from_len;
let from_id = std::str::from_utf8(&buf[offset..from_end]).ok()?.to_string();
let payload = buf[from_end..].to_vec();
Some((from_id, payload))
}
fn build_packet(msg_type: u8, from_id: &str, to_id: &str, payload: &[u8]) -> Result<Vec<u8>> {
if from_id.len() > MAX_ID_LEN || to_id.len() > MAX_ID_LEN {
return Err(anyhow!("relay id too long"));
}
let mut buf = Vec::with_capacity(HEADER_LEN + from_id.len() + to_id.len() + payload.len());
buf.extend_from_slice(MAGIC);
buf.push(msg_type);
buf.push(from_id.len() as u8);
buf.push(to_id.len() as u8);
buf.push(0);
buf.extend_from_slice(from_id.as_bytes());
buf.extend_from_slice(to_id.as_bytes());
buf.extend_from_slice(payload);
Ok(buf)
}
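// Illustrative round-trip check (not in the original commit) for the DELIVER parsing
// above, using build_packet to produce a frame body with an empty `to` field.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn deliver_frames_round_trip() {
        let body = build_packet(TYPE_DELIVER, "peer-a", "", b"payload").unwrap();
        let parsed = parse_deliver(&body).expect("deliver frame should parse");
        assert_eq!(parsed.0, "peer-a");
        assert_eq!(parsed.1, b"payload".to_vec());
    }
}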

170
src/stun.rs Normal file
View file

@ -0,0 +1,170 @@
use anyhow::{anyhow, Context, Result};
use rand::RngCore;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs, UdpSocket};
use std::time::Duration;
const MAGIC_COOKIE: u32 = 0x2112A442;
const BINDING_REQUEST: u16 = 0x0001;
const BINDING_SUCCESS: u16 = 0x0101;
const ATTR_MAPPED_ADDRESS: u16 = 0x0001;
const ATTR_XOR_MAPPED_ADDRESS: u16 = 0x0020;
pub fn discover_endpoint(servers: &[String], bind_port: u16, timeout: Duration) -> Result<SocketAddr> {
let mut last_err: Option<anyhow::Error> = None;
for server in servers {
match discover_endpoint_one(server, bind_port, timeout) {
Ok(addr) => return Ok(addr),
Err(err) => last_err = Some(err),
}
}
Err(last_err.unwrap_or_else(|| anyhow!("no stun servers provided")))
}
fn discover_endpoint_one(server: &str, bind_port: u16, timeout: Duration) -> Result<SocketAddr> {
let server_addr = resolve_server(server)?;
let bind_addr = match server_addr {
SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), bind_port),
SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), bind_port),
};
let socket = UdpSocket::bind(bind_addr).context("failed to bind stun socket")?;
socket
.set_read_timeout(Some(timeout))
.context("failed to set stun timeout")?;
let (transaction_id, request) = build_binding_request();
socket
.send_to(&request, server_addr)
.context("failed to send stun request")?;
let mut buf = [0u8; 1024];
let (len, from) = socket.recv_from(&mut buf).context("stun recv failed")?;
if from != server_addr {
return Err(anyhow!("stun response from unexpected address"));
}
parse_binding_response(&buf[..len], &transaction_id)
}
fn resolve_server(server: &str) -> Result<SocketAddr> {
server
.to_socket_addrs()
.context("failed to resolve stun server")?
.next()
.ok_or_else(|| anyhow!("stun server resolution returned no addresses"))
}
fn build_binding_request() -> ([u8; 12], [u8; 20]) {
let mut transaction_id = [0u8; 12];
rand::thread_rng().fill_bytes(&mut transaction_id);
let mut buf = [0u8; 20];
buf[0..2].copy_from_slice(&BINDING_REQUEST.to_be_bytes());
buf[2..4].copy_from_slice(&0u16.to_be_bytes());
buf[4..8].copy_from_slice(&MAGIC_COOKIE.to_be_bytes());
buf[8..20].copy_from_slice(&transaction_id);
(transaction_id, buf)
}
fn parse_binding_response(buf: &[u8], transaction_id: &[u8; 12]) -> Result<SocketAddr> {
if buf.len() < 20 {
return Err(anyhow!("stun response too short"));
}
let msg_type = u16::from_be_bytes([buf[0], buf[1]]);
if msg_type != BINDING_SUCCESS {
return Err(anyhow!("unexpected stun response type {:04x}", msg_type));
}
let msg_len = u16::from_be_bytes([buf[2], buf[3]]) as usize;
if buf.len() < 20 + msg_len {
return Err(anyhow!("stun response length mismatch"));
}
if buf[4..8] != MAGIC_COOKIE.to_be_bytes() {
return Err(anyhow!("stun response missing magic cookie"));
}
if buf[8..20] != transaction_id[..] {
return Err(anyhow!("stun transaction id mismatch"));
}
let mut offset = 20;
let end = 20 + msg_len;
while offset + 4 <= end {
let attr_type = u16::from_be_bytes([buf[offset], buf[offset + 1]]);
let attr_len = u16::from_be_bytes([buf[offset + 2], buf[offset + 3]]) as usize;
offset += 4;
if offset + attr_len > end {
break;
}
let attr = &buf[offset..offset + attr_len];
if attr_type == ATTR_XOR_MAPPED_ADDRESS {
if let Some(addr) = parse_xor_mapped(attr, transaction_id) {
return Ok(addr);
}
} else if attr_type == ATTR_MAPPED_ADDRESS {
if let Some(addr) = parse_mapped(attr) {
return Ok(addr);
}
}
offset += (attr_len + 3) & !3;
}
Err(anyhow!("stun response missing mapped address"))
}
fn parse_mapped(attr: &[u8]) -> Option<SocketAddr> {
if attr.len() < 4 {
return None;
}
let family = attr[1];
let port = u16::from_be_bytes([attr[2], attr[3]]);
match family {
0x01 => {
if attr.len() < 8 {
return None;
}
let addr = Ipv4Addr::new(attr[4], attr[5], attr[6], attr[7]);
Some(SocketAddr::new(IpAddr::V4(addr), port))
}
0x02 => {
if attr.len() < 20 {
return None;
}
let mut octets = [0u8; 16];
octets.copy_from_slice(&attr[4..20]);
Some(SocketAddr::new(IpAddr::V6(Ipv6Addr::from(octets)), port))
}
_ => None,
}
}
fn parse_xor_mapped(attr: &[u8], transaction_id: &[u8; 12]) -> Option<SocketAddr> {
if attr.len() < 4 {
return None;
}
let family = attr[1];
let port = u16::from_be_bytes([attr[2], attr[3]]) ^ ((MAGIC_COOKIE >> 16) as u16);
match family {
0x01 => {
if attr.len() < 8 {
return None;
}
let xaddr = u32::from_be_bytes([attr[4], attr[5], attr[6], attr[7]]) ^ MAGIC_COOKIE;
Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::from(xaddr)), port))
}
0x02 => {
if attr.len() < 20 {
return None;
}
let mut xor = [0u8; 16];
xor[0..4].copy_from_slice(&MAGIC_COOKIE.to_be_bytes());
xor[4..16].copy_from_slice(transaction_id);
let mut addr = [0u8; 16];
for i in 0..16 {
addr[i] = attr[4 + i] ^ xor[i];
}
Some(SocketAddr::new(IpAddr::V6(Ipv6Addr::from(addr)), port))
}
_ => None,
}
}
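// Illustrative check (not in the original commit): XOR-MAPPED-ADDRESS decoding should
// undo the XOR encoding defined in RFC 5389; the address and port are arbitrary.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn xor_mapped_ipv4_decodes() {
        let transaction_id = [0u8; 12];
        let ip = Ipv4Addr::new(192, 0, 2, 1);
        let port: u16 = 3478;
        let mut attr = vec![0u8, 0x01];
        attr.extend_from_slice(&(port ^ ((MAGIC_COOKIE >> 16) as u16)).to_be_bytes());
        attr.extend_from_slice(&(u32::from(ip) ^ MAGIC_COOKIE).to_be_bytes());
        let parsed = parse_xor_mapped(&attr, &transaction_id).expect("attribute should parse");
        assert_eq!(parsed, SocketAddr::new(IpAddr::V4(ip), port));
    }
}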

495
src/turn.rs Normal file
View file

@ -0,0 +1,495 @@
use anyhow::{anyhow, Context, Result};
use hmac::{Hmac, Mac};
use md5::{Digest as Md5Digest, Md5};
use rand::RngCore;
use sha1::Sha1;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
use std::time::Duration;
use tokio::net::UdpSocket;
use tokio::time::timeout;
type HmacSha1 = Hmac<Sha1>;
const MAGIC_COOKIE: u32 = 0x2112A442;
const MSG_ALLOCATE_REQUEST: u16 = 0x0003;
const MSG_ALLOCATE_SUCCESS: u16 = 0x0103;
const MSG_ALLOCATE_ERROR: u16 = 0x0113;
const MSG_CREATE_PERMISSION_REQUEST: u16 = 0x0008;
const MSG_CREATE_PERMISSION_SUCCESS: u16 = 0x0108;
const MSG_SEND_INDICATION: u16 = 0x0016;
const MSG_DATA_INDICATION: u16 = 0x0017;
const ATTR_USERNAME: u16 = 0x0006;
const ATTR_REALM: u16 = 0x0014;
const ATTR_NONCE: u16 = 0x0015;
const ATTR_REQUESTED_TRANSPORT: u16 = 0x0019;
const ATTR_XOR_RELAYED_ADDRESS: u16 = 0x0016;
const ATTR_XOR_MAPPED_ADDRESS: u16 = 0x0020;
const ATTR_XOR_PEER_ADDRESS: u16 = 0x0012;
const ATTR_DATA: u16 = 0x0013;
const ATTR_ERROR_CODE: u16 = 0x0009;
const ATTR_MESSAGE_INTEGRITY: u16 = 0x0008;
#[derive(Clone, Debug)]
pub struct TurnCredentials {
pub username: String,
pub password: String,
}
#[derive(Debug)]
pub struct TurnAllocation {
pub socket: UdpSocket,
pub server: SocketAddr,
pub relay_addr: SocketAddr,
#[allow(dead_code)]
pub mapped_addr: Option<SocketAddr>,
username: Option<String>,
realm: Option<String>,
nonce: Option<String>,
key: Option<Vec<u8>>,
}
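/// Allocates a relay address on a TURN server over UDP: an unauthenticated Allocate
/// request is sent first, and on a 401/438 response the request is retried with
/// long-term credentials (MD5 of "username:realm:password") and a MESSAGE-INTEGRITY
/// attribute.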
pub async fn allocate(
server: &str,
creds: Option<&TurnCredentials>,
timeout_duration: Duration,
) -> Result<TurnAllocation> {
let server_addr = resolve_server(server)?;
let bind_addr = match server_addr {
SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
};
let socket = UdpSocket::bind(bind_addr)
.await
.context("failed to bind turn socket")?;
let (transaction_id, request) = build_allocate_request(None, None, None);
socket.send_to(&request, server_addr).await?;
let response = recv_message(&socket, server_addr, timeout_duration).await?;
let parsed = parse_message(&response, Some(&transaction_id))?;
match parsed.msg_type {
MSG_ALLOCATE_SUCCESS => {
let (relay_addr, mapped_addr) = extract_addresses(&parsed, &transaction_id)?;
return Ok(TurnAllocation {
socket,
server: server_addr,
relay_addr,
mapped_addr,
username: None,
realm: None,
nonce: None,
key: None,
});
}
MSG_ALLOCATE_ERROR => {
let error_code = extract_error_code(&parsed);
if error_code == Some(401) || error_code == Some(438) {
let creds = creds.ok_or_else(|| anyhow!("turn auth required"))?;
let realm = extract_string(&parsed, ATTR_REALM)
.ok_or_else(|| anyhow!("turn realm missing"))?;
let nonce = extract_string(&parsed, ATTR_NONCE)
.ok_or_else(|| anyhow!("turn nonce missing"))?;
let key = build_long_term_key(creds, &realm)?;
let (transaction_id, request) = build_allocate_request(
Some(creds.username.as_str()),
Some(realm.as_str()),
Some(nonce.as_str()),
);
let request = add_message_integrity(request, &key)?;
socket.send_to(&request, server_addr).await?;
let response = recv_message(&socket, server_addr, timeout_duration).await?;
let parsed = parse_message(&response, Some(&transaction_id))?;
if parsed.msg_type != MSG_ALLOCATE_SUCCESS {
return Err(anyhow!("turn allocate failed after auth"));
}
let (relay_addr, mapped_addr) = extract_addresses(&parsed, &transaction_id)?;
return Ok(TurnAllocation {
socket,
server: server_addr,
relay_addr,
mapped_addr,
username: Some(creds.username.clone()),
realm: Some(realm),
nonce: Some(nonce),
key: Some(key),
});
}
            return Err(anyhow!("turn allocate failed (error code {:?})", error_code));
        }
        _ => {}
    }
    Err(anyhow!("turn allocate failed: unexpected response {:#06x}", parsed.msg_type))
}
pub async fn create_permission(
allocation: &mut TurnAllocation,
peer: SocketAddr,
timeout_duration: Duration,
) -> Result<()> {
let (transaction_id, mut request) = build_create_permission_request(allocation, peer)?;
request = maybe_add_integrity(allocation, request)?;
allocation
.socket
.send_to(&request, allocation.server)
.await?;
let response = recv_message(&allocation.socket, allocation.server, timeout_duration).await?;
let parsed = parse_message(&response, Some(&transaction_id))?;
if parsed.msg_type == MSG_CREATE_PERMISSION_SUCCESS {
return Ok(());
}
Err(anyhow!("turn create permission failed"))
}
pub async fn send_data(
    allocation: &mut TurnAllocation,
    peer: SocketAddr,
    data: &[u8],
) -> Result<()> {
    // Send indications are not acknowledged, so the transaction id is unused.
    let (_transaction_id, request) = build_send_indication(peer, data)?;
    allocation
        .socket
        .send_to(&request, allocation.server)
        .await?;
    Ok(())
}
pub async fn recv_data(
allocation: &mut TurnAllocation,
timeout_duration: Option<Duration>,
) -> Result<Option<(SocketAddr, Vec<u8>)>> {
let mut buf = vec![0u8; 2048];
let result = if let Some(timeout_duration) = timeout_duration {
timeout(timeout_duration, allocation.socket.recv_from(&mut buf)).await?
} else {
allocation.socket.recv_from(&mut buf).await
};
let (len, from) = result?;
if from != allocation.server {
return Ok(None);
}
let parsed = parse_message(&buf[..len], None)?;
if parsed.msg_type != MSG_DATA_INDICATION {
return Ok(None);
}
let transaction_id = parsed.transaction_id;
let peer = extract_xor_address(&parsed, ATTR_XOR_PEER_ADDRESS, &transaction_id)
.ok_or_else(|| anyhow!("turn data indication missing peer address"))?;
let data = extract_bytes(&parsed, ATTR_DATA).unwrap_or_default();
Ok(Some((peer, data)))
}
fn resolve_server(server: &str) -> Result<SocketAddr> {
server
.to_socket_addrs()
.context("failed to resolve turn server")?
.next()
.ok_or_else(|| anyhow!("turn server resolution returned no addresses"))
}
fn random_transaction_id() -> [u8; 12] {
let mut id = [0u8; 12];
rand::thread_rng().fill_bytes(&mut id);
id
}
fn build_allocate_request(
username: Option<&str>,
realm: Option<&str>,
nonce: Option<&str>,
) -> ([u8; 12], Vec<u8>) {
let transaction_id = random_transaction_id();
let mut attrs = Vec::new();
attrs.push(Attribute::new(
ATTR_REQUESTED_TRANSPORT,
vec![17, 0, 0, 0],
));
if let Some(username) = username {
attrs.push(Attribute::new(ATTR_USERNAME, username.as_bytes().to_vec()));
}
if let Some(realm) = realm {
attrs.push(Attribute::new(ATTR_REALM, realm.as_bytes().to_vec()));
}
if let Some(nonce) = nonce {
attrs.push(Attribute::new(ATTR_NONCE, nonce.as_bytes().to_vec()));
}
let msg = build_message(MSG_ALLOCATE_REQUEST, &transaction_id, attrs);
(transaction_id, msg)
}
fn build_create_permission_request(
allocation: &TurnAllocation,
peer: SocketAddr,
) -> Result<([u8; 12], Vec<u8>)> {
let transaction_id = random_transaction_id();
let mut attrs = Vec::new();
let xor_peer = encode_xor_address(peer, &transaction_id)?;
attrs.push(Attribute::new(ATTR_XOR_PEER_ADDRESS, xor_peer));
if let Some(username) = allocation.username.as_ref() {
attrs.push(Attribute::new(ATTR_USERNAME, username.as_bytes().to_vec()));
}
if let Some(realm) = allocation.realm.as_ref() {
attrs.push(Attribute::new(ATTR_REALM, realm.as_bytes().to_vec()));
}
if let Some(nonce) = allocation.nonce.as_ref() {
attrs.push(Attribute::new(ATTR_NONCE, nonce.as_bytes().to_vec()));
}
let msg = build_message(MSG_CREATE_PERMISSION_REQUEST, &transaction_id, attrs);
Ok((transaction_id, msg))
}
fn build_send_indication(peer: SocketAddr, data: &[u8]) -> Result<([u8; 12], Vec<u8>)> {
let transaction_id = random_transaction_id();
let xor_peer = encode_xor_address(peer, &transaction_id)?;
let attrs = vec![
Attribute::new(ATTR_XOR_PEER_ADDRESS, xor_peer),
Attribute::new(ATTR_DATA, data.to_vec()),
];
let msg = build_message(MSG_SEND_INDICATION, &transaction_id, attrs);
Ok((transaction_id, msg))
}
fn build_message(msg_type: u16, transaction_id: &[u8; 12], attrs: Vec<Attribute>) -> Vec<u8> {
let mut body = Vec::new();
for attr in attrs {
attr.write(&mut body);
}
let length = body.len() as u16;
let mut buf = Vec::with_capacity(20 + body.len());
buf.extend_from_slice(&msg_type.to_be_bytes());
buf.extend_from_slice(&length.to_be_bytes());
buf.extend_from_slice(&MAGIC_COOKIE.to_be_bytes());
buf.extend_from_slice(transaction_id);
buf.extend_from_slice(&body);
buf
}
fn add_message_integrity(mut msg: Vec<u8>, key: &[u8]) -> Result<Vec<u8>> {
    // RFC 5389 section 15.4: bump the message length to cover the MESSAGE-INTEGRITY
    // attribute (4-byte header + 20-byte HMAC-SHA1), then compute the HMAC over the
    // message up to, but not including, the attribute itself.
    let current_len = u16::from_be_bytes([msg[2], msg[3]]);
    let total_len = current_len.saturating_add(24);
    msg[2..4].copy_from_slice(&total_len.to_be_bytes());
    let mut mac = HmacSha1::new_from_slice(key).map_err(|_| anyhow!("invalid hmac key"))?;
    mac.update(&msg);
    let digest = mac.finalize().into_bytes();
    msg.extend_from_slice(&ATTR_MESSAGE_INTEGRITY.to_be_bytes());
    msg.extend_from_slice(&(20u16).to_be_bytes());
    msg.extend_from_slice(&digest);
    Ok(msg)
}
fn maybe_add_integrity(allocation: &TurnAllocation, msg: Vec<u8>) -> Result<Vec<u8>> {
if let Some(key) = allocation.key.as_ref() {
add_message_integrity(msg, key)
} else {
Ok(msg)
}
}
fn build_long_term_key(creds: &TurnCredentials, realm: &str) -> Result<Vec<u8>> {
let mut hasher = Md5::new();
let data = format!("{}:{}:{}", creds.username, realm, creds.password);
hasher.update(data.as_bytes());
Ok(hasher.finalize().to_vec())
}
async fn recv_message(
socket: &UdpSocket,
server: SocketAddr,
timeout_duration: Duration,
) -> Result<Vec<u8>> {
let mut buf = vec![0u8; 2048];
let (len, from) = timeout(timeout_duration, socket.recv_from(&mut buf)).await??;
if from != server {
return Err(anyhow!("unexpected turn response source"));
}
buf.truncate(len);
Ok(buf)
}
#[derive(Clone)]
struct Attribute {
ty: u16,
value: Vec<u8>,
}
impl Attribute {
fn new(ty: u16, value: Vec<u8>) -> Self {
Self { ty, value }
}
fn write(&self, buf: &mut Vec<u8>) {
buf.extend_from_slice(&self.ty.to_be_bytes());
buf.extend_from_slice(&(self.value.len() as u16).to_be_bytes());
buf.extend_from_slice(&self.value);
let padding = (4 - (self.value.len() % 4)) % 4;
if padding > 0 {
buf.extend_from_slice(&vec![0u8; padding]);
}
}
}
struct ParsedMessage {
msg_type: u16,
transaction_id: [u8; 12],
attrs: Vec<Attribute>,
}
fn parse_message(buf: &[u8], expected_id: Option<&[u8; 12]>) -> Result<ParsedMessage> {
if buf.len() < 20 {
return Err(anyhow!("turn message too short"));
}
let msg_type = u16::from_be_bytes([buf[0], buf[1]]);
let length = u16::from_be_bytes([buf[2], buf[3]]) as usize;
if buf[4..8] != MAGIC_COOKIE.to_be_bytes() {
return Err(anyhow!("turn message missing magic cookie"));
}
let mut transaction_id = [0u8; 12];
transaction_id.copy_from_slice(&buf[8..20]);
if let Some(expected) = expected_id {
if &transaction_id != expected {
return Err(anyhow!("turn transaction id mismatch"));
}
}
if buf.len() < 20 + length {
return Err(anyhow!("turn message length mismatch"));
}
let mut attrs = Vec::new();
let mut offset = 20;
let end = 20 + length;
while offset + 4 <= end {
let ty = u16::from_be_bytes([buf[offset], buf[offset + 1]]);
let len = u16::from_be_bytes([buf[offset + 2], buf[offset + 3]]) as usize;
offset += 4;
if offset + len > end {
break;
}
let value = buf[offset..offset + len].to_vec();
attrs.push(Attribute { ty, value });
offset += (len + 3) & !3;
}
Ok(ParsedMessage {
msg_type,
transaction_id,
attrs,
})
}
fn extract_error_code(parsed: &ParsedMessage) -> Option<u16> {
for attr in &parsed.attrs {
if attr.ty != ATTR_ERROR_CODE || attr.value.len() < 4 {
continue;
}
let class = attr.value[2] & 0x07;
let number = attr.value[3];
return Some((class as u16) * 100 + number as u16);
}
None
}
fn extract_string(parsed: &ParsedMessage, attr_type: u16) -> Option<String> {
for attr in &parsed.attrs {
if attr.ty == attr_type {
return String::from_utf8(attr.value.clone()).ok();
}
}
None
}
fn extract_bytes(parsed: &ParsedMessage, attr_type: u16) -> Option<Vec<u8>> {
for attr in &parsed.attrs {
if attr.ty == attr_type {
return Some(attr.value.clone());
}
}
None
}
fn extract_addresses(
parsed: &ParsedMessage,
transaction_id: &[u8; 12],
) -> Result<(SocketAddr, Option<SocketAddr>)> {
let relay = extract_xor_address(parsed, ATTR_XOR_RELAYED_ADDRESS, transaction_id)
.ok_or_else(|| anyhow!("turn allocate missing relay address"))?;
let mapped = extract_xor_address(parsed, ATTR_XOR_MAPPED_ADDRESS, transaction_id);
Ok((relay, mapped))
}
fn extract_xor_address(
parsed: &ParsedMessage,
attr_type: u16,
transaction_id: &[u8; 12],
) -> Option<SocketAddr> {
for attr in &parsed.attrs {
if attr.ty != attr_type {
continue;
}
if let Some(addr) = decode_xor_address(&attr.value, transaction_id) {
return Some(addr);
}
}
None
}
fn decode_xor_address(value: &[u8], transaction_id: &[u8; 12]) -> Option<SocketAddr> {
if value.len() < 4 {
return None;
}
let family = value[1];
let port = u16::from_be_bytes([value[2], value[3]]) ^ ((MAGIC_COOKIE >> 16) as u16);
match family {
0x01 => {
if value.len() < 8 {
return None;
}
let xaddr = u32::from_be_bytes([value[4], value[5], value[6], value[7]]) ^ MAGIC_COOKIE;
Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::from(xaddr)), port))
}
0x02 => {
if value.len() < 20 {
return None;
}
let mut xor = [0u8; 16];
xor[0..4].copy_from_slice(&MAGIC_COOKIE.to_be_bytes());
xor[4..16].copy_from_slice(transaction_id);
let mut addr = [0u8; 16];
for i in 0..16 {
addr[i] = value[4 + i] ^ xor[i];
}
Some(SocketAddr::new(IpAddr::V6(Ipv6Addr::from(addr)), port))
}
_ => None,
}
}
fn encode_xor_address(addr: SocketAddr, transaction_id: &[u8; 12]) -> Result<Vec<u8>> {
let mut value = Vec::new();
value.push(0);
match addr {
SocketAddr::V4(addr) => {
value.push(0x01);
let port = addr.port() ^ ((MAGIC_COOKIE >> 16) as u16);
value.extend_from_slice(&port.to_be_bytes());
let xaddr = u32::from(*addr.ip()) ^ MAGIC_COOKIE;
value.extend_from_slice(&xaddr.to_be_bytes());
}
SocketAddr::V6(addr) => {
value.push(0x02);
let port = addr.port() ^ ((MAGIC_COOKIE >> 16) as u16);
value.extend_from_slice(&port.to_be_bytes());
let mut xor = [0u8; 16];
xor[0..4].copy_from_slice(&MAGIC_COOKIE.to_be_bytes());
xor[4..16].copy_from_slice(transaction_id);
let mut ip_bytes = addr.ip().octets();
for i in 0..16 {
ip_bytes[i] ^= xor[i];
}
value.extend_from_slice(&ip_bytes);
}
}
Ok(value)
}
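// Illustrative round-trip checks (not in the original commit) for the XOR address
// encode/decode helpers above; the addresses and transaction ids are arbitrary.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn xor_address_round_trips_v4() {
        let transaction_id = [7u8; 12];
        let addr: SocketAddr = "198.51.100.7:51820".parse().unwrap();
        let encoded = encode_xor_address(addr, &transaction_id).unwrap();
        assert_eq!(decode_xor_address(&encoded, &transaction_id), Some(addr));
    }

    #[test]
    fn xor_address_round_trips_v6() {
        let transaction_id = [3u8; 12];
        let addr: SocketAddr = "[2001:db8::1]:3478".parse().unwrap();
        let encoded = encode_xor_address(addr, &transaction_id).unwrap();
        assert_eq!(decode_xor_address(&encoded, &transaction_id), Some(addr));
    }
}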

73
src/udp_relay.rs Normal file
View file

@ -0,0 +1,73 @@
use anyhow::{anyhow, Result};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
const MAGIC: &[u8; 4] = b"LSR1";
const TYPE_REGISTER: u8 = 1;
const TYPE_SEND: u8 = 2;
const TYPE_DELIVER: u8 = 3;
const HEADER_LEN: usize = 8;
const MAX_ID_LEN: usize = 64;
pub fn build_register(node_id: &str) -> Result<Vec<u8>> {
build_packet(TYPE_REGISTER, node_id, "", &[])
}
pub fn build_send(from_id: &str, to_id: &str, payload: &[u8]) -> Result<Vec<u8>> {
build_packet(TYPE_SEND, from_id, to_id, payload)
}
pub fn parse_deliver(buf: &[u8]) -> Option<(String, Vec<u8>)> {
if buf.len() < HEADER_LEN {
return None;
}
if &buf[0..4] != MAGIC {
return None;
}
let msg_type = buf[4];
if msg_type != TYPE_DELIVER {
return None;
}
let from_len = buf[5] as usize;
let to_len = buf[6] as usize;
if from_len > MAX_ID_LEN || to_len > MAX_ID_LEN || to_len != 0 {
return None;
}
let offset = HEADER_LEN;
if buf.len() < offset + from_len + to_len {
return None;
}
let from_end = offset + from_len;
let from_id = std::str::from_utf8(&buf[offset..from_end]).ok()?.to_string();
let payload = buf[from_end..].to_vec();
Some((from_id, payload))
}
pub fn resolve_server(server: &str) -> Result<SocketAddr> {
server
.to_socket_addrs()?
.next()
.ok_or_else(|| anyhow!("relay server resolution returned no addresses"))
}
pub fn bind_addr_for(server: &SocketAddr) -> SocketAddr {
match server {
SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
}
}
fn build_packet(msg_type: u8, from_id: &str, to_id: &str, payload: &[u8]) -> Result<Vec<u8>> {
if from_id.len() > MAX_ID_LEN || to_id.len() > MAX_ID_LEN {
return Err(anyhow!("relay id too long"));
}
let mut buf = Vec::with_capacity(HEADER_LEN + from_id.len() + to_id.len() + payload.len());
buf.extend_from_slice(MAGIC);
buf.push(msg_type);
buf.push(from_id.len() as u8);
buf.push(to_id.len() as u8);
buf.push(0);
buf.extend_from_slice(from_id.as_bytes());
buf.extend_from_slice(to_id.as_bytes());
buf.extend_from_slice(payload);
Ok(buf)
}
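// Illustrative round-trip check (not in the original commit): a DELIVER packet built
// with build_packet should be accepted by parse_deliver.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn deliver_packets_round_trip() {
        let packet = build_packet(TYPE_DELIVER, "peer-a", "", b"payload").unwrap();
        let (from, payload) = parse_deliver(&packet).expect("deliver packet should parse");
        assert_eq!(from, "peer-a");
        assert_eq!(payload, b"payload".to_vec());
    }
}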

529
src/wg.rs Normal file
View file

@ -0,0 +1,529 @@
use crate::model::{NetMap, Route, RouteKind};
use crate::netlink::Netlink;
use crate::routes;
use crate::state::ClientState;
use anyhow::{anyhow, Context, Result};
use boringtun::device::{DeviceConfig, DeviceHandle};
use ipnet::IpNet;
use std::collections::{HashMap, HashSet};
use std::net::{IpAddr, SocketAddr};
use std::path::{Path, PathBuf};
use std::sync::{Mutex, OnceLock};
use std::time::{Duration, Instant};
use tokio::time::sleep;
use wireguard_control::{
Backend as WgBackend, Device, DeviceUpdate, InterfaceName, Key, PeerConfigBuilder,
};
#[derive(Clone, Copy, Debug)]
pub enum Backend {
Kernel,
Boringtun,
}
pub struct WgConfig {
pub interface: String,
pub listen_port: u16,
pub backend: Backend,
}
static BORINGTUN_HANDLES: OnceLock<Mutex<HashMap<String, DeviceHandle>>> = OnceLock::new();
#[derive(Default)]
pub struct EndpointTracker {
peers: HashMap<String, PeerEndpointState>,
}
#[derive(Default)]
struct PeerEndpointState {
next_index: usize,
rotations: usize,
relay_active: bool,
}
pub async fn apply(
netmap: &NetMap,
state: &ClientState,
cfg: &WgConfig,
routes_cfg: Option<&routes::RouteApplyConfig>,
) -> Result<()> {
let netlink = Netlink::new().await?;
let index = match cfg.backend {
Backend::Kernel => apply_kernel(netmap, state, cfg, routes_cfg, &netlink).await?,
Backend::Boringtun => apply_boringtun(netmap, state, cfg, routes_cfg, &netlink).await?,
};
add_peer_routes(netmap, index, &netlink).await?;
Ok(())
}
pub async fn remove(interface: &str, backend: Backend) -> Result<()> {
let netlink = Netlink::new().await?;
match backend {
Backend::Kernel => {
netlink.delete_link(interface).await?;
}
Backend::Boringtun => {
stop_boringtun(interface);
let socket_path = userspace_socket_path(interface);
let _ = std::fs::remove_file(&socket_path);
netlink.delete_link(interface).await?;
}
}
Ok(())
}
pub fn probe_peers(netmap: &NetMap, timeout_seconds: u64) -> Result<()> {
let mut v4_socket: Option<std::net::UdpSocket> = None;
let mut v6_socket: Option<std::net::UdpSocket> = None;
for peer in &netmap.peers {
let mut probed = false;
for endpoint in &peer.endpoints {
match endpoint.parse::<SocketAddr>() {
Ok(addr) => {
probe_addr(&mut v4_socket, &mut v6_socket, addr, timeout_seconds);
probed = true;
}
Err(_) => {
eprintln!("probe skipped invalid endpoint {} for {}", endpoint, peer.id);
}
}
}
if !probed {
probe_ip(&mut v4_socket, &mut v6_socket, &peer.ipv4, timeout_seconds);
probe_ip(&mut v4_socket, &mut v6_socket, &peer.ipv6, timeout_seconds);
}
}
Ok(())
}
fn backend_for(backend: Backend) -> WgBackend {
match backend {
Backend::Kernel => WgBackend::Kernel,
Backend::Boringtun => WgBackend::Userspace,
}
}
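/// Rotates through a peer's advertised endpoints while its WireGuard handshake is
/// stale (no handshake yet, or older than `stale_after`); after `max_rotations`
/// unsuccessful attempts, or when no direct endpoints are known, the peer is switched
/// to its relay endpoint from `relay_endpoints`, if one is available.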
pub fn refresh_peer_endpoints(
netmap: &NetMap,
cfg: &WgConfig,
tracker: &mut EndpointTracker,
relay_endpoints: &HashMap<String, SocketAddr>,
stale_after: Duration,
max_rotations: usize,
) -> Result<()> {
let iface: InterfaceName = cfg
.interface
.parse()
.context("invalid interface name")?;
let backend = backend_for(cfg.backend);
let device = Device::get(&iface, backend).context("wireguard device query failed")?;
let mut peer_info = HashMap::new();
for info in device.peers {
peer_info.insert(info.config.public_key.to_base64(), info);
}
let max_rotations = max_rotations.max(1);
let mut update = DeviceUpdate::new();
let mut changed = false;
let mut desired_endpoints: HashMap<String, SocketAddr> = HashMap::new();
for peer in &netmap.peers {
let endpoints: Vec<SocketAddr> = peer
.endpoints
.iter()
.filter_map(|endpoint| endpoint.parse().ok())
.collect();
let info = peer_info.get(&peer.wg_public_key);
let handshake_stale = match info.and_then(|info| info.stats.last_handshake_time) {
Some(ts) => ts.elapsed().map(|age| age > stale_after).unwrap_or(true),
None => true,
};
if !handshake_stale {
let state = tracker.peers.entry(peer.id.clone()).or_default();
state.rotations = 0;
state.relay_active = false;
continue;
}
let state = tracker.peers.entry(peer.id.clone()).or_default();
let current_endpoint = info.and_then(|info| info.config.endpoint);
let mut desired_endpoint = None;
if state.relay_active {
if let Some(relay) = relay_endpoints.get(&peer.id) {
desired_endpoint = Some(*relay);
}
} else if !endpoints.is_empty() {
let idx = state.next_index % endpoints.len();
desired_endpoint = Some(endpoints[idx]);
state.next_index = (idx + 1) % endpoints.len();
state.rotations = state.rotations.saturating_add(1);
if state.rotations >= max_rotations && relay_endpoints.contains_key(&peer.id) {
state.relay_active = true;
state.rotations = 0;
}
} else if let Some(relay) = relay_endpoints.get(&peer.id) {
state.relay_active = true;
state.rotations = 0;
desired_endpoint = Some(*relay);
}
if let Some(desired) = desired_endpoint {
if Some(desired) != current_endpoint {
changed = true;
if backend == WgBackend::Userspace {
desired_endpoints.insert(peer.id.clone(), desired);
} else {
let peer_key = Key::from_base64(&peer.wg_public_key)
.with_context(|| format!("invalid peer public key {}", peer.id))?;
update = update.add_peer(
PeerConfigBuilder::new(&peer_key)
.set_endpoint(desired)
.set_persistent_keepalive_interval(25),
);
}
}
}
}
if changed {
if backend == WgBackend::Userspace {
let mut full_update = DeviceUpdate::new().replace_peers();
for peer in &netmap.peers {
let info = peer_info.get(&peer.wg_public_key);
let mut builder = if let Some(info) = info {
PeerConfigBuilder::from_peer_config(info.config.clone())
} else {
build_peer_builder_from_netmap(peer)?
};
if let Some(desired) = desired_endpoints.get(&peer.id) {
builder = builder
.set_endpoint(*desired)
.set_persistent_keepalive_interval(25);
}
full_update = full_update.add_peer(builder);
}
full_update
.apply(&iface, backend)
.context("wireguard endpoint refresh failed")?;
} else {
update
.apply(&iface, backend)
.context("wireguard endpoint refresh failed")?;
}
}
Ok(())
}
async fn apply_kernel(
netmap: &NetMap,
state: &ClientState,
cfg: &WgConfig,
routes_cfg: Option<&routes::RouteApplyConfig>,
netlink: &Netlink,
) -> Result<u32> {
apply_wireguard_config(netmap, state, cfg, routes_cfg, WgBackend::Kernel)?;
let index = netlink
.wait_for_link(&cfg.interface, Duration::from_secs(3))
.await?;
configure_addresses(netlink, index, state).await?;
netlink.set_link_up(index).await?;
Ok(index)
}
async fn apply_boringtun(
netmap: &NetMap,
state: &ClientState,
cfg: &WgConfig,
routes_cfg: Option<&routes::RouteApplyConfig>,
netlink: &Netlink,
) -> Result<u32> {
ensure_boringtun(&cfg.interface)?;
wait_for_userspace_socket(&cfg.interface, Duration::from_secs(3)).await?;
apply_wireguard_config(netmap, state, cfg, routes_cfg, WgBackend::Userspace)?;
let index = netlink
.wait_for_link(&cfg.interface, Duration::from_secs(3))
.await?;
configure_addresses(netlink, index, state).await?;
netlink.set_link_up(index).await?;
Ok(index)
}
async fn configure_addresses(
netlink: &Netlink,
index: u32,
state: &ClientState,
) -> Result<()> {
let ipv4 = parse_ip(&state.ipv4, "ipv4")?;
netlink.replace_address(index, ipv4, 32).await?;
let ipv6 = parse_ip(&state.ipv6, "ipv6")?;
netlink.replace_address(index, ipv6, 128).await?;
Ok(())
}
fn apply_wireguard_config(
netmap: &NetMap,
state: &ClientState,
cfg: &WgConfig,
routes_cfg: Option<&routes::RouteApplyConfig>,
backend: WgBackend,
) -> Result<()> {
let iface: InterfaceName = cfg
.interface
.parse()
.context("invalid interface name")?;
let update = build_device_update(netmap, state, cfg, routes_cfg)?;
update
.apply(&iface, backend)
.context("wireguard config apply failed")?;
Ok(())
}
fn build_device_update(
netmap: &NetMap,
state: &ClientState,
cfg: &WgConfig,
routes_cfg: Option<&routes::RouteApplyConfig>,
) -> Result<DeviceUpdate> {
let private_key = Key::from_base64(&state.wg_private_key)
.context("invalid wireguard private key")?;
let mut update = DeviceUpdate::new()
.set_private_key(private_key)
.set_listen_port(cfg.listen_port)
.replace_peers();
let selected_exit_ids = routes_cfg
.map(|cfg| routes::selected_exit_peer_ids(netmap, cfg))
.unwrap_or_default();
for peer in &netmap.peers {
let peer_key =
Key::from_base64(&peer.wg_public_key).context("invalid peer public key")?;
let ipv4: IpAddr = peer.ipv4.parse().context("invalid peer ipv4")?;
let ipv6: IpAddr = peer.ipv6.parse().context("invalid peer ipv6")?;
let mut allowed: HashSet<IpNet> = HashSet::new();
allowed.insert(IpNet::new(ipv4, 32).context("invalid peer ipv4 prefix")?);
allowed.insert(IpNet::new(ipv6, 128).context("invalid peer ipv6 prefix")?);
let allow_exit = selected_exit_ids.contains(&peer.id);
for route in &peer.routes {
if !route.enabled {
continue;
}
let net = match route_allowed_prefix(route) {
Ok(net) => net,
Err(err) => {
eprintln!(
"skipping allowed ip for route {} on peer {}: {}",
route.prefix, peer.id, err
);
continue;
}
};
match route.kind {
RouteKind::Subnet => {
allowed.insert(net);
}
RouteKind::Exit => {
if allow_exit {
allowed.insert(net);
}
}
}
}
let mut builder = PeerConfigBuilder::new(&peer_key).replace_allowed_ips();
for net in allowed {
builder = add_allowed_ip(builder, net);
}
if let Some(addr) = peer
.endpoints
.iter()
.find_map(|endpoint| endpoint.parse::<SocketAddr>().ok())
{
builder = builder
.set_endpoint(addr)
.set_persistent_keepalive_interval(25);
} else if !peer.endpoints.is_empty() {
eprintln!("no valid endpoint for peer {}", peer.id);
}
update = update.add_peer(builder);
}
Ok(update)
}
async fn add_peer_routes(netmap: &NetMap, index: u32, netlink: &Netlink) -> Result<()> {
for peer in &netmap.peers {
let ipv4: IpAddr = peer.ipv4.parse().context("invalid peer ipv4")?;
let ipv6: IpAddr = peer.ipv6.parse().context("invalid peer ipv6")?;
let v4 = IpNet::new(ipv4, 32).context("invalid peer ipv4 prefix")?;
let v6 = IpNet::new(ipv6, 128).context("invalid peer ipv6 prefix")?;
netlink.replace_route(v4, index).await?;
netlink.replace_route(v6, index).await?;
}
Ok(())
}
fn route_allowed_prefix(route: &Route) -> Result<IpNet> {
let Some(mapped) = route.mapped_prefix.as_deref() else {
return route.prefix.parse().context("invalid route prefix");
};
let real_net: IpNet = route.prefix.parse().context("invalid route prefix")?;
let mapped_net: IpNet = mapped.parse().context("invalid mapped prefix")?;
let real_v4 = matches!(real_net, IpNet::V4(_));
let mapped_v4 = matches!(mapped_net, IpNet::V4(_));
if real_v4 != mapped_v4 {
return Err(anyhow!("mapped prefix ip version mismatch"));
}
if real_net.prefix_len() != mapped_net.prefix_len() {
return Err(anyhow!("mapped prefix length mismatch"));
}
Ok(mapped_net)
}
fn add_allowed_ip(builder: PeerConfigBuilder, net: IpNet) -> PeerConfigBuilder {
match net {
IpNet::V4(v4) => builder.add_allowed_ip(IpAddr::V4(v4.network()), v4.prefix_len()),
IpNet::V6(v6) => builder.add_allowed_ip(IpAddr::V6(v6.network()), v6.prefix_len()),
}
}
fn build_peer_builder_from_netmap(peer: &crate::model::PeerInfo) -> Result<PeerConfigBuilder> {
let peer_key = Key::from_base64(&peer.wg_public_key)
.with_context(|| format!("invalid peer public key {}", peer.id))?;
let ipv4: IpAddr = peer.ipv4.parse().context("invalid peer ipv4")?;
let ipv6: IpAddr = peer.ipv6.parse().context("invalid peer ipv6")?;
let mut allowed: HashSet<IpNet> = HashSet::new();
allowed.insert(IpNet::new(ipv4, 32).context("invalid peer ipv4 prefix")?);
allowed.insert(IpNet::new(ipv6, 128).context("invalid peer ipv6 prefix")?);
for route in &peer.routes {
if !route.enabled {
continue;
}
if let RouteKind::Subnet = route.kind {
if let Ok(net) = route_allowed_prefix(route) {
allowed.insert(net);
}
}
}
let mut builder = PeerConfigBuilder::new(&peer_key).replace_allowed_ips();
for net in allowed {
builder = add_allowed_ip(builder, net);
}
if let Some(addr) = peer
.endpoints
.iter()
.find_map(|endpoint| endpoint.parse::<SocketAddr>().ok())
{
builder = builder
.set_endpoint(addr)
.set_persistent_keepalive_interval(25);
}
Ok(builder)
}
fn ensure_boringtun(interface: &str) -> Result<()> {
let handles = BORINGTUN_HANDLES.get_or_init(|| Mutex::new(HashMap::new()));
let mut map = handles.lock().unwrap();
if map.contains_key(interface) {
return Ok(());
}
let config = DeviceConfig::default();
let handle = DeviceHandle::new(interface, config).context("boringtun init failed")?;
map.insert(interface.to_string(), handle);
Ok(())
}
fn stop_boringtun(interface: &str) {
if let Some(handles) = BORINGTUN_HANDLES.get() {
let mut map = handles.lock().unwrap();
map.remove(interface);
}
}
async fn wait_for_userspace_socket(interface: &str, timeout: Duration) -> Result<()> {
let start = Instant::now();
let path = userspace_socket_path(interface);
loop {
if path.exists() {
return Ok(());
}
if start.elapsed() > timeout {
return Err(anyhow!("userspace wg socket {} did not appear", path.display()));
}
sleep(Duration::from_millis(100)).await;
}
}
fn userspace_socket_path(interface: &str) -> PathBuf {
Path::new("/var/run/wireguard").join(format!("{interface}.sock"))
}
fn parse_ip(address: &str, label: &str) -> Result<IpAddr> {
let ip: IpAddr = address.parse().with_context(|| format!("invalid {}", label))?;
match (label, ip) {
("ipv4", IpAddr::V4(_)) => Ok(ip),
("ipv6", IpAddr::V6(_)) => Ok(ip),
_ => Err(anyhow!("unexpected {} address: {}", label, address)),
}
}
fn probe_ip(
v4_socket: &mut Option<std::net::UdpSocket>,
v6_socket: &mut Option<std::net::UdpSocket>,
address: &str,
timeout_seconds: u64,
) {
let ip: IpAddr = match address.parse() {
Ok(ip) => ip,
Err(_) => {
eprintln!("probe failed for {} (invalid address)", address);
return;
}
};
let target = std::net::SocketAddr::new(ip, 9);
probe_addr(v4_socket, v6_socket, target, timeout_seconds);
}
fn probe_addr(
v4_socket: &mut Option<std::net::UdpSocket>,
v6_socket: &mut Option<std::net::UdpSocket>,
target: SocketAddr,
timeout_seconds: u64,
) {
let socket = match target {
SocketAddr::V4(_) => {
if v4_socket.is_none() {
match std::net::UdpSocket::bind("0.0.0.0:0") {
Ok(sock) => {
let _ = sock.set_write_timeout(Some(Duration::from_secs(timeout_seconds.max(1))));
*v4_socket = Some(sock);
}
Err(_) => {
eprintln!("probe failed for {} (udp bind)", target);
return;
}
}
}
v4_socket.as_ref().unwrap()
}
SocketAddr::V6(_) => {
if v6_socket.is_none() {
match std::net::UdpSocket::bind("[::]:0") {
Ok(sock) => {
let _ = sock.set_write_timeout(Some(Duration::from_secs(timeout_seconds.max(1))));
*v6_socket = Some(sock);
}
Err(_) => {
eprintln!("probe failed for {} (udp bind)", target);
return;
}
}
}
v6_socket.as_ref().unwrap()
}
};
if socket.send_to(b"lightscale-probe", target).is_err() {
eprintln!("probe failed for {} (udp send)", target);
}
}
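// Illustrative checks (not in the original commit) for the address-family validation
// performed by parse_ip above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_ip_enforces_address_family() {
        assert!(parse_ip("100.64.0.1", "ipv4").is_ok());
        assert!(parse_ip("fd00::1", "ipv6").is_ok());
        assert!(parse_ip("fd00::1", "ipv4").is_err());
        assert!(parse_ip("100.64.0.1", "ipv6").is_err());
    }
}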