it seems to be working now

This commit is contained in:
Robin Appelman 2025-10-31 01:46:47 +01:00
commit b4cf0acb44
11 changed files with 398 additions and 64 deletions

View file

@ -86,7 +86,7 @@ impl<'de> Deserialize<'de> for ForwardDestination {
#[test]
fn test_de() {
use serde_test::{Token, assert_de_tokens, assert_de_tokens_error};
use serde_test::{assert_de_tokens, assert_de_tokens_error, Token};
let addr_str = "127.0.0.1:80";
let addr = SocketAddr::from_str("127.0.0.1:80").unwrap();

View file

@ -1,7 +1,7 @@
use serde::de::{Error, Unexpected, Visitor};
use serde::{Deserialize, Deserializer};
use std::fmt::{Display, Formatter};
use std::path::Path;
use serde::{Deserialize, Deserializer};
use serde::de::{Error, Unexpected, Visitor};
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct NamespaceName(String);
@ -41,8 +41,7 @@ impl<'de> Deserialize<'de> for NamespaceName {
type Value = NamespaceName;
fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
formatter
.write_str("A valid namespace name")
formatter.write_str("A valid namespace name")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
@ -57,7 +56,7 @@ impl<'de> Deserialize<'de> for NamespaceName {
fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
where
E: Error
E: Error,
{
if !validate_name(&v) {
return Err(E::invalid_value(Unexpected::Str(&v), &self));
@ -75,17 +74,18 @@ fn validate_name(name: &str) -> bool {
if name.is_empty() {
return false;
}
name.bytes().all(|b| b.is_ascii_alphanumeric() || [b'_', b'.', b'-'].contains(&b))
name.bytes()
.all(|b| b.is_ascii_alphanumeric() || [b'_', b'.', b'-'].contains(&b))
}
#[test]
fn test_de() {
use serde_test::{Token, assert_de_tokens, assert_de_tokens_error};
use serde_test::{assert_de_tokens, assert_de_tokens_error, Token};
assert_de_tokens(&NamespaceName("foo".into()), &[Token::String("foo")]);
assert_de_tokens_error::<NamespaceName>(
&[Token::String("foo/bar")],
"invalid value: integer `-80`, expected Either a port as integer, or a string containing a socket address",
"invalid value: string \"foo/bar\", expected A valid namespace name",
);
}

View file

@ -97,7 +97,7 @@ impl<'de> Deserialize<'de> for ForwardSource {
#[test]
fn test_de() {
use serde_test::{Token, assert_de_tokens, assert_de_tokens_error};
use serde_test::{assert_de_tokens, assert_de_tokens_error, Token};
let addr_str = "127.0.0.1:80";
let addr = SocketAddr::from_str("127.0.0.1:80").unwrap();

68
src/daemon/link.rs Normal file
View file

@ -0,0 +1,68 @@
use neli::consts::nl::NlmF;
use neli::consts::rtnl::Ifla;
use neli::consts::rtnl::RtAddrFamily;
use neli::consts::rtnl::Rtm;
use neli::consts::socket::NlFamily;
use neli::err::RouterError;
use neli::nl::NlPayload;
use neli::router::synchronous::NlRouter;
use neli::rtnl::Ifinfomsg;
use neli::rtnl::IfinfomsgBuilder;
use neli::utils::Groups;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum LinkError {
#[error("failed to communicate with netlink")]
Netlink,
#[error("failed to code netlink response")]
Parse,
}
impl<T, P> From<RouterError<T, P>> for LinkError {
fn from(_value: RouterError<T, P>) -> Self {
LinkError::Netlink
}
}
/// Set a link to UP
pub fn up(link_name: &str) -> Result<(), LinkError> {
// I honestly don't really know how this code works
// It's mostly a copy from one of neli's examples and seems to do what it needs to
let (rtnl, _) = NlRouter::connect(NlFamily::Route, None, Groups::empty())?;
rtnl.enable_ext_ack(true)?;
rtnl.enable_strict_checking(true)?;
let ifinfomsg = IfinfomsgBuilder::default()
.ifi_family(RtAddrFamily::Inet)
.build()
.unwrap();
let recv = rtnl.send::<_, _, Rtm, Ifinfomsg>(
Rtm::Getlink,
NlmF::DUMP | NlmF::ACK,
NlPayload::Payload(ifinfomsg),
)?;
for response in recv {
if let Some(payload) = response?.get_payload() {
let name = payload
.rtattrs()
.get_attr_handle()
.get_attr_payload_as_with_len::<String>(Ifla::Ifname)
.map_err(|_| LinkError::Parse)?;
if name == link_name {
let up_msg = IfinfomsgBuilder::default()
.ifi_family(RtAddrFamily::Inet)
.ifi_index(*payload.ifi_index())
.up()
.build()
.unwrap();
rtnl.send::<_, _, Rtm, Ifinfomsg>(
Rtm::Setlink,
NlmF::ACK,
NlPayload::Payload(up_msg),
)?;
}
}
}
Ok(())
}

View file

@ -1,5 +1,6 @@
mod namespace;
mod proxy;
pub mod link;
use crate::config::{Config, ForwardConfig, NamespaceConfig, NamespaceName};
use crate::daemon::namespace::{NamespaceError, NetNs};
@ -163,7 +164,7 @@ impl ActiveNamespace {
for new in &config.forward {
if !self.has_forward(new) {
self.proxies.push(ActiveProxy::new(new)?);
self.proxies.push(ActiveProxy::new(new, &config.name)?);
}
}

View file

@ -1,11 +1,12 @@
use crate::config::NamespaceName;
use nix::errno::Errno;
use nix::mount::{MsFlags, mount, umount};
use nix::sched::{CloneFlags, unshare};
use std::fs::{File, create_dir_all, remove_file};
use nix::mount::{mount, umount, MsFlags};
use nix::sched::{clone, CloneFlags};
use nix::sys::signal::Signal;
use nix::sys::wait::{waitpid, WaitStatus};
use std::fs::{create_dir_all, remove_file, File};
use std::io::{Error as IoError, ErrorKind};
use std::path::{Path, PathBuf};
use std::thread::{JoinHandle, spawn};
use thiserror::Error;
use tracing::{error, info};
@ -34,19 +35,48 @@ impl NetNs {
Err(e) => return Err(NamespaceError::from_create(name, e)),
}
let handle: JoinHandle<Result<(), NamespaceError>> = spawn(move || {
unshare(CloneFlags::CLONE_NEWNET).map_err(NamespaceError::Unshare)?;
mount(
Some("/proc/self/ns/net"),
&mount_path,
Option::<&str>::None,
MsFlags::MS_BIND,
Option::<&str>::None,
let mut stack = vec![0; 8 * 1024 * 1024];
let pid = unsafe {
clone(
Box::new(move || {
if let Err(e) = mount(
Some("/proc/self/ns/net"),
&mount_path,
Option::<&str>::None,
MsFlags::MS_BIND,
Option::<&str>::None,
) {
return e as i32 as isize;
}
if let Err(error) = super::link::up("lo") {
error!(%error, "error setting link up");
return 1;
}
0
}),
&mut stack,
CloneFlags::CLONE_NEWNET,
Some(Signal::SIGCHLD as i32),
)
.map_err(NamespaceError::from_mount)?;
Ok(())
});
handle.join().unwrap()?;
}
.map_err(NamespaceError::Clone)?;
match waitpid(pid, None).map_err(NamespaceError::Wait)? {
WaitStatus::Exited(_, exit) => {
if exit > 0 {
if let Err(error) = remove_file(&path) {
error!(%error, path = %path.display(), "Failed to remove namespace file after mount failure");
}
return Err(NamespaceError::from_mount(Errno::from_raw(exit)));
}
}
status => {
error!(?status, "unexpected wait status");
}
}
Ok(NetNs {
name: name.clone(),
path,
@ -79,10 +109,12 @@ pub enum NamespaceError {
Parent(IoError),
#[error("Failed to create namespace file {}: {error:#}", path.display())]
Create { path: PathBuf, error: IoError },
#[error("Unexpected error while creating new network namespace: {0:}")]
Unshare(Errno),
#[error("Unexpected error while creating new network namespace with clone: {0:}")]
Clone(Errno),
#[error("Unexpected error while binding new network namespace: {0:}")]
Bind(Errno),
#[error("Unexpected error while waiting for network namespace thread: {0:}")]
Wait(Errno),
}
impl NamespaceError {

View file

@ -1,14 +1,18 @@
mod tcp;
use std::fs::remove_file;
use crate::config::{ForwardConfig, ForwardDestination, ForwardSource};
use crate::config::{ForwardConfig, ForwardDestination, ForwardSource, NamespaceName};
use crate::daemon::proxy::tcp::Proxy;
use futures::future::AbortHandle;
use nix::sched::{CloneFlags, setns};
use std::fs::{File, remove_file};
use std::io::Error as IoError;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::thread::spawn;
use thiserror::Error;
use tokio::runtime::Builder;
use tracing::error;
#[derive(Debug, Error)]
@ -25,6 +29,8 @@ pub enum ProxyError {
destination: SocketAddr,
error: IoError,
},
#[error("Failed to open namespace file {}: {error:#}", path.display())]
OpenNamespace { path: PathBuf, error: IoError },
}
pub struct ActiveProxy {
@ -35,9 +41,44 @@ pub struct ActiveProxy {
}
impl ActiveProxy {
pub fn new(config: &ForwardConfig) -> Result<ActiveProxy, ProxyError> {
pub fn new(
config: &ForwardConfig,
namespace: &NamespaceName,
) -> Result<ActiveProxy, ProxyError> {
let proxy = Proxy::listen(config.source.clone())?;
Ok(proxy.run(config.destination.clone()))
let stats = ProxyStats::default();
let (abort, abort_reg) = AbortHandle::new_pair();
let destination = config.destination.clone();
let run_stats = stats.clone();
let ns_path = PathBuf::from(format!("/var/run/netns/{namespace}"));
let ns_handle = File::open(&ns_path).map_err(|error| ProxyError::OpenNamespace {
error,
path: ns_path,
})?;
spawn(move || match setns(ns_handle, CloneFlags::CLONE_NEWNET) {
Ok(_) => {
let rt = match Builder::new_current_thread().enable_io().build() {
Ok(rt) => rt,
Err(error) => {
error!(%error, "Error setting up tokio runtime");
return;
}
};
rt.block_on(proxy.run(destination, abort_reg, run_stats));
}
Err(error) => {
error!(%error, "Failed to join network namespace for proxy");
}
});
Ok(ActiveProxy {
source: config.source.clone(),
destination: config.destination.clone(),
abort,
stats,
})
}
}

View file

@ -1,11 +1,12 @@
use std::fs::remove_file;
/// Loosely based on https://github.com/fooker/netns-proxy/blob/main/src/tcp.rs
use crate::config::{ForwardDestination, ForwardSource};
use crate::daemon::proxy::{ActiveProxy, ProxyError, ProxyStats};
use crate::daemon::proxy::{ProxyError, ProxyStats};
use futures::TryStreamExt;
use futures::stream::{AbortHandle, AbortRegistration, Abortable};
use futures::stream::{AbortRegistration, Abortable};
use std::fs::{remove_file, set_permissions};
use std::io::Error as IoError;
use std::net::SocketAddr;
use std::os::unix::fs::PermissionsExt;
use std::pin::pin;
use tokio::io::{AsyncRead, AsyncWrite, copy_bidirectional};
use tokio::net::{TcpListener, TcpSocket, TcpStream, UnixListener};
@ -16,7 +17,6 @@ use tracing::{Level, debug, error, instrument, span, trace, warn};
#[derive(Debug)]
pub struct Proxy {
source: ForwardSource,
socket: ProxyListener,
}
@ -41,44 +41,35 @@ impl Proxy {
let socket = match &bind {
ForwardSource::Unix(path) => {
let _ = remove_file(path);
UnixListener::bind(path).map(ProxyListener::Unix)
},
UnixListener::bind(path).map(|listener| {
if let Err(error) = set_permissions(path, PermissionsExt::from_mode(0o666)) {
error!(%error, "failed to set socket permissions");
}
ProxyListener::Unix(listener)
})
}
ForwardSource::Ip(addr) => bind_tcp(*addr).map(ProxyListener::Tcp),
}
.map_err(|error| ProxyError::Bind {
address: bind.clone(),
address: bind,
error,
})?;
debug!("Created TCP socket");
Ok(Self {
source: bind,
socket,
})
}
pub fn run(self, target: ForwardDestination) -> ActiveProxy {
let (abort_handle, abort) = AbortHandle::new_pair();
let destination = target.clone();
let stats = ProxyStats::default();
pub async fn run(self, target: ForwardDestination, abort: AbortRegistration, stats: ProxyStats) {
let proxy_stats = stats.clone();
spawn(async move {
match self.socket {
ProxyListener::Tcp(socket) => {
run_tcp(socket, target.addr, abort, proxy_stats).await
}
ProxyListener::Unix(socket) => {
run_unix(socket, target.addr, abort, proxy_stats).await
}
match self.socket {
ProxyListener::Tcp(socket) => {
run_tcp(socket, target.addr, abort, proxy_stats).await
}
ProxyListener::Unix(socket) => {
run_unix(socket, target.addr, abort, proxy_stats).await
}
});
ActiveProxy {
source: self.source,
destination,
abort: abort_handle,
stats,
}
}
}