move proxies to sub-processes

This commit is contained in:
Robin Appelman 2025-11-12 21:01:46 +01:00
commit e672e11f09
9 changed files with 217 additions and 199 deletions

View file

@ -3,6 +3,8 @@ use serde::{Deserialize, Deserializer};
use std::ffi::OsString;
use std::fmt::{Display, Formatter};
use std::path::Path;
use std::str::FromStr;
use thiserror::Error;
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct NamespaceName(String);
@ -38,6 +40,12 @@ impl AsRef<Path> for NamespaceName {
}
}
impl PartialEq<&str> for NamespaceName {
fn eq(&self, other: &&str) -> bool {
self.0 == *other
}
}
impl From<NamespaceName> for String {
fn from(value: NamespaceName) -> Self {
value.0
@ -83,6 +91,23 @@ impl<'de> Deserialize<'de> for NamespaceName {
}
}
impl FromStr for NamespaceName {
type Err = InvalidNamespaceNameError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if !validate_name(s) {
return Err(InvalidNamespaceNameError { name: s.into() });
}
Ok(NamespaceName(s.into()))
}
}
#[derive(Debug, Error)]
#[error("invalid name for namespace: '{name}'")]
pub struct InvalidNamespaceNameError {
name: String,
}
/// Check if a name follows the portable filename character set
fn validate_name(name: &str) -> bool {
if name.is_empty() {

View file

@ -4,6 +4,7 @@ use std::fmt::{Display, Formatter};
use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf;
use std::str::FromStr;
use thiserror::Error;
#[derive(Debug, PartialEq, Clone, Hash, Eq)]
pub enum ForwardSource {
@ -77,17 +78,8 @@ impl<'de> Deserialize<'de> for ForwardSource {
where
E: Error,
{
if v.starts_with('/') {
Ok(ForwardSource::Unix(v.into()))
} else {
if let Ok(port) = u16::from_str(v) {
return self.visit_u16(port);
}
let addr = v
.parse()
.map_err(|_| E::invalid_value(Unexpected::Str(v), &self))?;
Ok(ForwardSource::Ip(addr))
}
v.parse()
.map_err(|_| E::invalid_value(Unexpected::Str(v), &self))
}
}
@ -95,6 +87,31 @@ impl<'de> Deserialize<'de> for ForwardSource {
}
}
impl FromStr for ForwardSource {
type Err = InvalidForwardSource;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.starts_with('/') {
Ok(ForwardSource::Unix(s.into()))
} else {
if let Ok(port) = u16::from_str(s) {
let ip = IpAddr::from([0, 0, 0, 0]);
return Ok(ForwardSource::Ip(SocketAddr::from((ip, port))));
}
let addr = s
.parse()
.map_err(|_| InvalidForwardSource { forward: s.into() })?;
Ok(ForwardSource::Ip(addr))
}
}
}
#[derive(Debug, Error)]
#[error("forward source '{forward}' is not a valid unix path or socket address")]
pub struct InvalidForwardSource {
forward: String,
}
#[test]
fn test_de() {
use serde_test::{assert_de_tokens, assert_de_tokens_error, Token};

View file

@ -3,6 +3,7 @@ use serde::{Deserialize, Deserializer};
use std::fmt::{Display, Formatter};
use std::net::{IpAddr, SocketAddr};
use std::str::FromStr;
use thiserror::Error;
#[derive(Debug, PartialEq, Clone, Hash, Eq)]
pub struct ForwardTarget {
@ -70,13 +71,8 @@ impl<'de> Deserialize<'de> for ForwardTarget {
where
E: Error,
{
if let Ok(port) = u16::from_str(v) {
return self.visit_u16(port);
}
let addr = v
.parse()
.map_err(|_| E::invalid_value(Unexpected::Str(v), &self))?;
Ok(ForwardTarget { addr })
v.parse()
.map_err(|_| E::invalid_value(Unexpected::Str(v), &self))
}
}
@ -84,6 +80,29 @@ impl<'de> Deserialize<'de> for ForwardTarget {
}
}
impl FromStr for ForwardTarget {
type Err = InvalidForwardTarget;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if let Ok(port) = u16::from_str(s) {
let ip = IpAddr::from([127, 0, 0, 1]);
return Ok(ForwardTarget {
addr: SocketAddr::from((ip, port)),
});
}
let addr = s
.parse()
.map_err(|_| InvalidForwardTarget { forward: s.into() })?;
Ok(ForwardTarget { addr })
}
}
#[derive(Debug, Error)]
#[error("forward source '{forward}' is not a valid socket address")]
pub struct InvalidForwardTarget {
forward: String,
}
#[test]
fn test_de() {
use serde_test::{assert_de_tokens, assert_de_tokens_error, Token};

View file

@ -4,7 +4,6 @@ use crate::proxy::{ActiveProxy, ProxyError};
use futures::FutureExt;
use futures::StreamExt;
use futures_concurrency::stream::Merge;
use humansize::{BINARY, SizeFormatter};
use main_error::MainResult;
use sd_notify::{NotifyState, notify};
use std::io::Error as IoError;
@ -78,13 +77,9 @@ async fn daemon_async(mut config: Config) -> Result<(), DaemonError> {
println!("{}:", namespace.name());
for proxy in &namespace.proxies {
println!(
" {} => {} {} connections ({} active), {} sent to namespace, {} received from namespace",
" {} => {}",
proxy.source,
proxy.destination,
proxy.stats.total_connections(),
proxy.stats.open_connections(),
SizeFormatter::new(proxy.stats.written(), BINARY),
SizeFormatter::new(proxy.stats.read(), BINARY),
);
}
}

View file

@ -1,9 +1,10 @@
use std::path::{PathBuf};
use clap::{Parser, Subcommand};
use main_error::MainResult;
use crate::config::Config;
use crate::config::{Config, ForwardSource, ForwardTarget, NamespaceName};
use crate::daemon::daemon;
use crate::down::down;
use crate::proxy::proxy;
use crate::up::up;
mod config;
@ -38,6 +39,17 @@ enum Commands {
Down,
/// Signal a running daemon to reload it's configuration
Reload,
/// Create the configured namespaces
Proxy {
/// Namespace to proxy connections from, use "parent" to listen in the namespace this command was spawned into
source_namespace: NamespaceName,
/// Source address to listen on
source: ForwardSource,
/// Namespace to proxy connections to
target_namespace: NamespaceName,
/// Target address to connect to
target: ForwardTarget,
},
}
fn main() -> MainResult {
@ -56,7 +68,10 @@ fn main() -> MainResult {
Commands::Down => {
down()
}
Commands::Reload => reload()
Commands::Reload => reload(),
Commands::Proxy {source, target, source_namespace, target_namespace} => {
proxy(source_namespace, target_namespace, source, target)
},
}
}

View file

@ -3,17 +3,19 @@ mod tcp;
use crate::config::{ForwardConfig, ForwardSource, ForwardTarget, NamespaceName};
use crate::proxy::tcp::Proxy;
use futures::future::AbortHandle;
use main_error::MainResult;
use nix::sched::{CloneFlags, setns};
use nix::sys::signal::{SIGINT, kill};
use nix::unistd::{Gid, Pid, Uid, setgid, setuid};
use std::fs::{File, remove_file};
use std::io::Error as IoError;
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::process::{Child, Command};
use std::thread::spawn;
use syscalls::{Sysno, syscall};
use thiserror::Error;
use tokio::runtime::Builder;
use tokio::signal::ctrl_c;
use tracing::error;
use uzers::{get_group_by_name, get_user_by_name};
@ -33,13 +35,14 @@ pub enum ProxyError {
},
#[error("Failed to open namespace file {}: {error:#}", path.display())]
OpenNamespace { path: PathBuf, error: IoError },
#[error("Failed to spawn proxy: {error:#}")]
Spawn { error: std::io::Error },
}
pub struct ActiveProxy {
pub source: ForwardSource,
pub destination: ForwardTarget,
abort: AbortHandle,
pub stats: ProxyStats,
child: Option<Child>,
}
impl ActiveProxy {
@ -47,89 +50,52 @@ impl ActiveProxy {
config: &ForwardConfig,
namespace: &NamespaceName,
) -> Result<ActiveProxy, ProxyError> {
let stats = ProxyStats::default();
let (abort, abort_reg) = AbortHandle::new_pair();
let destination = config.target.clone();
let run_stats = stats.clone();
let ns_handle = open_namespace(format!("/var/run/netns/{namespace}"))?;
let self_ns_handle = open_namespace("/proc/self/ns/net")?;
let (listen_namespace, target_namespace) = if config.reverse {
(Some(ns_handle), self_ns_handle)
(namespace.as_ref(), "parent")
} else {
(None, ns_handle)
("parent", namespace.as_ref())
};
let nobody_uid = get_user_by_name("nobody")
.map(|user| user.uid())
.unwrap_or(65534);
let nobody_gid = get_group_by_name("nobody")
.map(|group| group.gid())
.unwrap_or(65534);
let mut command = Command::new("/proc/self/exe");
command
.arg("proxy")
.arg(listen_namespace)
.arg(config.source.to_string())
.arg(target_namespace)
.arg(config.target.to_string());
let source = config.source.clone();
spawn(move || {
let rt = match Builder::new_current_thread().enable_io().build() {
Ok(rt) => rt,
Err(error) => {
error!(%error, "Error setting up tokio runtime");
return;
}
};
rt.block_on(async {
if let Some(listen_namespace) = listen_namespace {
if let Err(error) = setns(listen_namespace, CloneFlags::CLONE_NEWNET) {
error!(%error, "Failed to join listen network namespace for proxy");
return;
}
}
let proxy = match Proxy::listen(&source) {
Ok(proxy) => proxy,
Err(error) => {
error!(%error, "Failed to listen to {source}");
return;
}
};
if let Err(error) = setns(target_namespace, CloneFlags::CLONE_NEWNET) {
error!(%error, "Failed to join target network namespace for proxy");
return;
}
// raw syscall since the libc `setuid`/`setgui` set it for the entire process
unsafe {
if let Err(error) = syscall!(Sysno::setgid, nobody_gid)
.and_then(|_| syscall!(Sysno::setuid, nobody_uid))
{
error!(%error, "Failed drop privileges for proxy thread");
}
}
proxy.run(destination, abort_reg, run_stats).await;
});
});
let child = command
.spawn()
.map_err(|error| ProxyError::Spawn { error })?;
Ok(ActiveProxy {
source: config.source.clone(),
destination: config.target.clone(),
abort,
stats,
child: Some(child),
})
}
}
impl Drop for ActiveProxy {
fn drop(&mut self) {
if let ForwardSource::Unix(path) = &self.source {
if let Err(error) = remove_file(path) {
error!(%error, "failed to remove unix socket");
}
let mut child = self.child.take().unwrap();
if let Err(error) = kill(Pid::from_raw(child.id() as i32), Some(SIGINT)) {
error!(%error, "failed to signal proxy process to stop");
}
self.abort.abort();
let source = self.source.clone();
spawn(move || {
if let Err(error) = child.wait() {
error!(%error, "failed to wait for proxy process");
}
// we do this here, since the proxy process won't have permissions for it anymore
if let ForwardSource::Unix(path) = &source {
if let Err(error) = remove_file(path) {
error!(%error, "failed to remove unix socket");
}
}
});
}
}
@ -139,53 +105,6 @@ impl PartialEq<ForwardConfig> for ActiveProxy {
}
}
#[derive(Default, Clone)]
pub struct ProxyStats {
open_connections: Arc<AtomicU64>,
total_connections: Arc<AtomicU64>,
/// Bytes proxied source -> destination
bytes_written: Arc<AtomicU64>,
/// Bytes proxied destination -> source
bytes_read: Arc<AtomicU64>,
}
impl ProxyStats {
pub fn open_connection(&self) {
self.open_connections.fetch_add(1, Ordering::Relaxed);
self.total_connections.fetch_add(1, Ordering::Relaxed);
}
pub fn close_connection(&self) {
self.open_connections.fetch_sub(1, Ordering::Relaxed);
}
pub fn open_connections(&self) -> u64 {
self.open_connections.load(Ordering::Relaxed)
}
pub fn total_connections(&self) -> u64 {
self.total_connections.load(Ordering::Relaxed)
}
pub fn add_read(&self, bytes: u64) {
self.bytes_read.fetch_add(bytes, Ordering::Relaxed);
}
pub fn add_written(&self, bytes: u64) {
self.bytes_written.fetch_add(bytes, Ordering::Relaxed);
}
/// Bytes proxied destination -> source
pub fn read(&self) -> u64 {
self.bytes_read.load(Ordering::Relaxed)
}
/// Bytes proxied source -> destination
pub fn written(&self) -> u64 {
self.bytes_written.load(Ordering::Relaxed)
}
}
fn open_namespace(path: impl AsRef<Path>) -> Result<File, ProxyError> {
let path = path.as_ref();
File::open(path).map_err(|error| ProxyError::OpenNamespace {
@ -193,3 +112,80 @@ fn open_namespace(path: impl AsRef<Path>) -> Result<File, ProxyError> {
path: path.into(),
})
}
pub fn proxy(
source_namespace: NamespaceName,
target_namespace: NamespaceName,
source: ForwardSource,
target: ForwardTarget,
) -> MainResult {
let (abort, abort_reg) = AbortHandle::new_pair();
let target_namespace = if target_namespace == "parent" {
open_namespace("/proc/self/ns/net")?
} else {
open_namespace(format!("/var/run/netns/{target_namespace}"))?
};
let listen_namespace = if source_namespace == "parent" {
None
} else {
Some(open_namespace(format!(
"/var/run/netns/{source_namespace}"
))?)
};
let nobody_uid = Uid::from(
get_user_by_name("nobody")
.map(|user| user.uid())
.unwrap_or(65534),
);
let nobody_gid = Gid::from(
get_group_by_name("nobody")
.map(|group| group.gid())
.unwrap_or(65534),
);
let rt = match Builder::new_current_thread().enable_io().build() {
Ok(rt) => rt,
Err(error) => {
error!(%error, "Error setting up tokio runtime");
return Err(error.into());
}
};
rt.block_on(async {
if let Some(listen_namespace) = listen_namespace {
if let Err(error) = setns(listen_namespace, CloneFlags::CLONE_NEWNET) {
error!(%error, "Failed to join listen network namespace for proxy");
return Err(error.into());
}
}
let proxy = match Proxy::listen(&source) {
Ok(proxy) => proxy,
Err(error) => {
error!(%error, "Failed to listen to {source}");
return Err(error.into());
}
};
if let Err(error) = setns(target_namespace, CloneFlags::CLONE_NEWNET) {
error!(%error, "Failed to join target network namespace for proxy");
return Err(error.into());
}
if let Err(error) = setgid(nobody_gid).and_then(|_| setuid(nobody_uid)) {
error!(%error, "Failed to drop privileges");
}
tokio::spawn(async move {
let _ = ctrl_c().await;
abort.abort();
});
proxy.run(target, abort_reg).await;
Ok(())
})
}

View file

@ -1,6 +1,6 @@
/// Loosely based on https://github.com/fooker/netns-proxy/blob/main/src/tcp.rs
use crate::config::{ForwardTarget, ForwardSource};
use crate::proxy::{ProxyError, ProxyStats};
use crate::proxy::{ProxyError};
use futures::TryStreamExt;
use futures::stream::{AbortRegistration, Abortable};
use std::fs::{remove_file, set_permissions};
@ -61,14 +61,13 @@ impl Proxy {
})
}
pub async fn run(self, target: ForwardTarget, abort: AbortRegistration, stats: ProxyStats) {
let proxy_stats = stats.clone();
pub async fn run(self, target: ForwardTarget, abort: AbortRegistration,) {
match self.socket {
ProxyListener::Tcp(socket) => {
run_tcp(socket, target.addr, abort, proxy_stats).await
run_tcp(socket, target.addr, abort).await
}
ProxyListener::Unix(socket) => {
run_unix(socket, target.addr, abort, proxy_stats).await
run_unix(socket, target.addr, abort).await
}
}
}
@ -78,16 +77,14 @@ async fn run_tcp(
socket: TcpListener,
target: SocketAddr,
abort: AbortRegistration,
stats: ProxyStats,
) {
let accepts = TcpListenerStream::new(socket).map_err(|error| ProxyError::Accept { error });
let mut accepts = pin!(Abortable::new(accepts, abort));
while let Some(client) = accepts.next().await {
stats.open_connection();
let result: Result<(), ProxyError> = async {
let client = client?;
let remote = client.peer_addr().ok();
proxy_stream(client, target, remote, stats.clone()).await
proxy_stream(client, target, remote).await
}
.await;
@ -101,15 +98,13 @@ async fn run_unix(
socket: UnixListener,
target: SocketAddr,
abort: AbortRegistration,
stats: ProxyStats,
) {
let accepts = UnixListenerStream::new(socket).map_err(|error| ProxyError::Accept { error });
let mut accepts = pin!(Abortable::new(accepts, abort));
while let Some(client) = accepts.next().await {
stats.open_connection();
let result: Result<(), ProxyError> = async {
let client = client?;
proxy_stream(client, target, None, stats.clone()).await
proxy_stream(client, target, None).await
}
.await;
@ -119,12 +114,11 @@ async fn run_unix(
}
}
#[instrument(skip(stream, proxy))]
#[instrument(skip(stream))]
async fn proxy_stream<Stream>(
mut stream: Stream,
target: SocketAddr,
remote: Option<SocketAddr>,
proxy: ProxyStats,
) -> Result<(), ProxyError>
where
Stream: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
@ -143,10 +137,7 @@ where
spawn(async move {
match copy_bidirectional(&mut stream, &mut upstream).await {
Ok((written, read)) => {
proxy.add_written(written);
proxy.add_read(read);
proxy.close_connection();
Ok(_) => {
trace!("Upstream connection closed");
}
Err(err) => {