basic netns management

This commit is contained in:
Robin Appelman 2025-10-30 18:16:28 +01:00
commit a4c7b3c1c9
17 changed files with 1555 additions and 0 deletions

126
src/config/destination.rs Normal file
View file

@ -0,0 +1,126 @@
use serde::de::{Error, Unexpected, Visitor};
use serde::{Deserialize, Deserializer};
use std::fmt::{Display, Formatter};
use std::net::{IpAddr, SocketAddr};
use std::str::FromStr;
#[derive(Debug, PartialEq, Clone, Hash, Eq)]
pub struct ForwardDestination {
addr: SocketAddr,
}
impl From<ForwardDestination> for SocketAddr {
fn from(value: ForwardDestination) -> Self {
value.addr
}
}
impl Display for ForwardDestination {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.addr)
}
}
impl<'de> Deserialize<'de> for ForwardDestination {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ForwardDestinationVisitor;
impl<'de> Visitor<'de> for ForwardDestinationVisitor {
type Value = ForwardDestination;
fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
formatter
.write_str("Either a port as integer, or a string containing a socket address")
}
fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
where
E: Error,
{
let v = v
.try_into()
.map_err(|_| E::invalid_value(Unexpected::Signed(v), &self))?;
self.visit_u16(v)
}
fn visit_u16<E>(self, v: u16) -> Result<Self::Value, E>
where
E: Error,
{
let ip = IpAddr::from([127, 0, 0, 1]);
Ok(ForwardDestination {
addr: SocketAddr::from((ip, v)),
})
}
fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
where
E: Error,
{
let v = v
.try_into()
.map_err(|_| E::invalid_value(Unexpected::Unsigned(v), &self))?;
self.visit_u16(v)
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: Error,
{
if let Ok(port) = u16::from_str(v) {
return self.visit_u16(port);
}
let addr = v
.parse()
.map_err(|_| E::invalid_value(Unexpected::Str(&v), &self))?;
Ok(ForwardDestination { addr })
}
}
deserializer.deserialize_any(ForwardDestinationVisitor)
}
}
#[test]
fn test_de() {
use serde_test::{Token, assert_de_tokens, assert_de_tokens_error};
let addr_str = "127.0.0.1:80";
let addr = SocketAddr::from_str("127.0.0.1:80").unwrap();
fn port_addr(port: u16) -> ForwardDestination {
ForwardDestination {
addr: SocketAddr::new(IpAddr::from([127, 0, 0, 1]), port),
}
}
assert_de_tokens(&ForwardDestination { addr }, &[Token::String(addr_str)]);
assert_de_tokens(&ForwardDestination { addr }, &[Token::Str(addr_str)]);
assert_de_tokens(&port_addr(80), &[Token::Str("80")]);
assert_de_tokens(&port_addr(80), &[Token::U8(80)]);
assert_de_tokens(&port_addr(80), &[Token::U16(80)]);
assert_de_tokens(&port_addr(80), &[Token::U64(80)]);
assert_de_tokens(&port_addr(80), &[Token::I8(80)]);
assert_de_tokens(&port_addr(80), &[Token::I16(80)]);
assert_de_tokens(&port_addr(80), &[Token::I64(80)]);
assert_de_tokens_error::<ForwardDestination>(
&[Token::I64(-80)],
"invalid value: integer `-80`, expected Either a port as integer, or a string containing a socket address",
);
assert_de_tokens_error::<ForwardDestination>(
&[Token::U64(12345678)],
"invalid value: integer `12345678`, expected Either a port as integer, or a string containing a socket address",
);
assert_de_tokens_error::<ForwardDestination>(
&[Token::Str("hello world")],
"invalid value: string \"hello world\", expected Either a port as integer, or a string containing a socket address",
);
assert_de_tokens_error::<ForwardDestination>(
&[Token::Str("localhost:80")],
"invalid value: string \"localhost:80\", expected Either a port as integer, or a string containing a socket address",
);
}

109
src/config/mod.rs Normal file
View file

@ -0,0 +1,109 @@
mod destination;
mod name;
mod source;
pub use crate::config::destination::ForwardDestination;
pub use crate::config::name::NamespaceName;
pub use crate::config::source::ForwardSource;
use serde::Deserialize;
use std::collections::HashSet;
use std::fs::read_to_string;
use std::path::{Path, PathBuf};
use thiserror::Error;
use toml::from_str;
#[derive(Debug)]
pub struct Config {
path: PathBuf,
pub namespaces: Vec<NamespaceConfig>,
}
impl Config {
pub fn load<P: AsRef<Path>>(path: P) -> Result<Config, ConfigError> {
let path = path.as_ref();
let raw = read_to_string(path).map_err(|error| ConfigError::Read {
error,
path: path.to_owned(),
})?;
let config: RawConfig = from_str(&raw).map_err(|error| ConfigError::Parse {
error,
path: path.to_owned(),
})?;
Ok(config
.validate(path)
.map_err(|error| ConfigError::Validation {
error,
path: path.to_owned(),
})?)
}
pub fn reload(&self) -> Result<Config, ConfigError> {
Self::load(&self.path)
}
}
#[derive(Deserialize, Debug)]
#[serde(deny_unknown_fields)]
struct RawConfig {
#[serde(default, rename = "namespace")]
pub namespaces: Vec<NamespaceConfig>,
}
impl RawConfig {
fn validate(self, path: &Path) -> Result<Config, ValidationError> {
let mut sources = HashSet::new();
for source in self
.namespaces
.iter()
.flat_map(|namespace| namespace.forward.iter())
.map(|forward| &forward.source)
{
if !sources.insert(source.clone()) {
return Err(ValidationError::DuplicateSource {
forward_source: source.clone(),
});
}
}
Ok(Config {
path: path.into(),
namespaces: self.namespaces,
})
}
}
#[derive(Deserialize, Debug)]
pub struct NamespaceConfig {
pub name: NamespaceName,
pub forward: Vec<ForwardConfig>,
}
#[derive(Deserialize, Debug)]
pub struct ForwardConfig {
pub source: ForwardSource,
pub destination: ForwardDestination,
}
#[derive(Debug, Error)]
pub enum ConfigError {
#[error("Error while reading config from {}: {error:#}", path.display())]
Read {
error: std::io::Error,
path: PathBuf,
},
#[error("Error while parsing config from {}: {error:#}", path.display())]
Parse {
error: toml::de::Error,
path: PathBuf,
},
#[error("Error while validating config from {}: {error:#}", path.display())]
Validation {
error: ValidationError,
path: PathBuf,
},
}
#[derive(Debug, Error)]
pub enum ValidationError {
#[error("Duplicate source in forwards: {forward_source}")]
DuplicateSource { forward_source: ForwardSource },
}

91
src/config/name.rs Normal file
View file

@ -0,0 +1,91 @@
use std::fmt::{Display, Formatter};
use std::path::Path;
use serde::{Deserialize, Deserializer};
use serde::de::{Error, Unexpected, Visitor};
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct NamespaceName(String);
impl Display for NamespaceName {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl AsRef<str> for NamespaceName {
fn as_ref(&self) -> &str {
self.0.as_ref()
}
}
impl AsRef<Path> for NamespaceName {
fn as_ref(&self) -> &Path {
self.0.as_ref()
}
}
impl From<NamespaceName> for String {
fn from(value: NamespaceName) -> Self {
value.0
}
}
impl<'de> Deserialize<'de> for NamespaceName {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct NamespaceNameVisitor;
impl<'de> Visitor<'de> for NamespaceNameVisitor {
type Value = NamespaceName;
fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
formatter
.write_str("A valid namespace name")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: Error,
{
if !validate_name(v) {
return Err(E::invalid_value(Unexpected::Str(v), &self));
}
Ok(NamespaceName(v.into()))
}
fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
where
E: Error
{
if !validate_name(&v) {
return Err(E::invalid_value(Unexpected::Str(&v), &self));
}
Ok(NamespaceName(v))
}
}
deserializer.deserialize_any(NamespaceNameVisitor)
}
}
/// Check if a name follows the portable filename character set
fn validate_name(name: &str) -> bool {
if name.is_empty() {
return false;
}
name.bytes().all(|b| b.is_ascii_alphanumeric() || [b'_', b'.', b'-'].contains(&b))
}
#[test]
fn test_de() {
use serde_test::{Token, assert_de_tokens, assert_de_tokens_error};
assert_de_tokens(&NamespaceName("foo".into()), &[Token::String("foo")]);;
assert_de_tokens_error::<NamespaceName>(
&[Token::String("foo/bar")],
"invalid value: integer `-80`, expected Either a port as integer, or a string containing a socket address",
);
}

143
src/config/source.rs Normal file
View file

@ -0,0 +1,143 @@
use serde::de::{Error, Unexpected, Visitor};
use serde::{Deserialize, Deserializer};
use std::fmt::{Display, Formatter};
use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf;
use std::str::FromStr;
#[derive(Debug, PartialEq, Clone, Hash, Eq)]
pub enum ForwardSource {
Unix(PathBuf),
Ip(SocketAddr),
}
impl Display for ForwardSource {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
ForwardSource::Unix(a) => write!(f, "{}", a.display()),
ForwardSource::Ip(a) => write!(f, "{a}"),
}
}
}
impl<'de> Deserialize<'de> for ForwardSource {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ForwardSourceVisitor;
impl<'de> Visitor<'de> for ForwardSourceVisitor {
type Value = ForwardSource;
fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
formatter.write_str("Either a port as integer, or a string containing a socket address or unix path")
}
fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
where
E: Error,
{
let v = v
.try_into()
.map_err(|_| E::invalid_value(Unexpected::Signed(v), &self))?;
self.visit_u16(v)
}
fn visit_u16<E>(self, v: u16) -> Result<Self::Value, E>
where
E: Error,
{
let ip = IpAddr::from([0, 0, 0, 0]);
Ok(ForwardSource::Ip(SocketAddr::from((ip, v))))
}
fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
where
E: Error,
{
let v = v
.try_into()
.map_err(|_| E::invalid_value(Unexpected::Unsigned(v), &self))?;
self.visit_u16(v)
}
fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
where
E: Error,
{
if v.starts_with('/') {
Ok(ForwardSource::Unix(v.into()))
} else {
self.visit_str(&v)
}
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: Error,
{
if v.starts_with('/') {
Ok(ForwardSource::Unix(v.into()))
} else {
if let Ok(port) = u16::from_str(v) {
return self.visit_u16(port);
}
let addr = v
.parse()
.map_err(|_| E::invalid_value(Unexpected::Str(&v), &self))?;
Ok(ForwardSource::Ip(addr))
}
}
}
deserializer.deserialize_any(ForwardSourceVisitor)
}
}
#[test]
fn test_de() {
use serde_test::{Token, assert_de_tokens, assert_de_tokens_error};
let addr_str = "127.0.0.1:80";
let addr = SocketAddr::from_str("127.0.0.1:80").unwrap();
fn port_addr(port: u16) -> ForwardSource {
ForwardSource::Ip(SocketAddr::new(IpAddr::from([0, 0, 0, 0]), port))
}
assert_de_tokens(
&ForwardSource::Unix("/test/foo".into()),
&[Token::String("/test/foo")],
);
assert_de_tokens(
&ForwardSource::Unix("/test/foo".into()),
&[Token::Str("/test/foo")],
);
assert_de_tokens(&ForwardSource::Ip(addr), &[Token::String(addr_str)]);
assert_de_tokens(&ForwardSource::Ip(addr), &[Token::Str(addr_str)]);
assert_de_tokens(&port_addr(80), &[Token::Str("80")]);
assert_de_tokens(&port_addr(80), &[Token::U8(80)]);
assert_de_tokens(&port_addr(80), &[Token::U16(80)]);
assert_de_tokens(&port_addr(80), &[Token::U64(80)]);
assert_de_tokens(&port_addr(80), &[Token::I8(80)]);
assert_de_tokens(&port_addr(80), &[Token::I16(80)]);
assert_de_tokens(&port_addr(80), &[Token::I64(80)]);
assert_de_tokens_error::<ForwardSource>(
&[Token::I64(-80)],
"invalid value: integer `-80`, expected Either a port as integer, or a string containing a socket address or unix path",
);
assert_de_tokens_error::<ForwardSource>(
&[Token::U64(12345678)],
"invalid value: integer `12345678`, expected Either a port as integer, or a string containing a socket address or unix path",
);
assert_de_tokens_error::<ForwardSource>(
&[Token::Str("hello world")],
"invalid value: string \"hello world\", expected Either a port as integer, or a string containing a socket address or unix path",
);
assert_de_tokens_error::<ForwardSource>(
&[Token::Str("localhost:80")],
"invalid value: string \"localhost:80\", expected Either a port as integer, or a string containing a socket address or unix path",
);
}

55
src/daemon/mod.rs Normal file
View file

@ -0,0 +1,55 @@
mod namespace;
use crate::config::{Config, NamespaceName};
use crate::daemon::namespace::{NamespaceError, NetNs};
use main_error::MainResult;
use sd_notify::{notify, NotifyState};
use std::io::Error as IoError;
use thiserror::Error;
use tokio::runtime::Runtime;
use tokio::signal::ctrl_c;
pub fn daemon(config: Config) -> MainResult {
let rt = Runtime::new()?;
Ok(rt.block_on(daemon_async(config))?)
}
async fn daemon_async(config: Config) -> Result<(), DaemonError> {
for namespace in &config.namespaces {
println!("{}:", namespace.name);
for forward in &namespace.forward {
println!(" {} => {}", forward.source, forward.destination);
}
}
let namespaces = config
.namespaces
.iter()
.map(|ns| ActiveNamespace::new(&ns.name))
.collect::<Result<Vec<_>, _>>()?;
// now the namespaces are setup, we can tell systemd to start any service depending on them
notify(true, &[NotifyState::Ready]).map_err(DaemonError::Notify)?;
let _ = ctrl_c().await;
Ok(())
}
struct ActiveNamespace {
ns: NetNs,
}
impl ActiveNamespace {
pub fn new(name: &NamespaceName) -> Result<Self, DaemonError> {
let ns = NetNs::new(name)?;
Ok(ActiveNamespace { ns })
}
}
#[derive(Debug, Error)]
pub enum DaemonError {
#[error(transparent)]
Namespace(#[from] NamespaceError),
#[error("Error sending notification to systemd: {0:#}")]
Notify(IoError)
}

87
src/daemon/namespace.rs Normal file
View file

@ -0,0 +1,87 @@
use crate::config::NamespaceName;
use nix::errno::Errno;
use nix::mount::{MsFlags, mount, umount};
use nix::sched::{CloneFlags, unshare};
use std::fs::{File, create_dir_all, remove_file};
use std::io::{Error as IoError, ErrorKind};
use std::path::{Path, PathBuf};
use std::thread::{JoinHandle, spawn};
use thiserror::Error;
use tracing::{debug, error};
pub struct NetNs {
path: PathBuf,
}
impl NetNs {
/// Create a new named network namespace that will be removed when dropped
pub fn new(name: &NamespaceName) -> Result<Self, NamespaceError> {
debug!(%name, "creating network namespace");
let parent = Path::new("/var/run/netns");
create_dir_all(parent).map_err(NamespaceError::Parent)?;
let path = parent.join(name);
let mount_path = path.clone();
let _ =
File::create_new(&path).map_err(|error| NamespaceError::from_create(name, error))?;
let handle: JoinHandle<Result<(), NamespaceError>> = spawn(move || {
unshare(CloneFlags::CLONE_NEWNET).map_err(NamespaceError::Unshare)?;
mount(
Some("/proc/self/ns/net"),
&mount_path,
Option::<&str>::None,
MsFlags::MS_BIND,
Option::<&str>::None,
)
.map_err(NamespaceError::from_mount)?;
Ok(())
});
handle.join().unwrap()?;
Ok(NetNs { path })
}
}
impl Drop for NetNs {
fn drop(&mut self) {
let name = self.path.file_name().unwrap().to_str().unwrap();
debug!(name, "deleting network namespace");
if let Err(error) = umount(&self.path) {
error!(%error, path = %self.path.display(), "Failed to unmount network namespace");
}
if let Err(error) = remove_file(&self.path) {
error!(%error, path = %self.path.display(), "Failed to remove namespace file");
}
}
}
#[derive(Debug, Error)]
pub enum NamespaceError {
#[error("Failed to create parent directory for namespaces (/var/run/netns): {0:#}")]
Parent(IoError),
#[error("Network namespace {0} already exists")]
AlreadyExists(NamespaceName),
#[error("Failed to create namespace file {}: {error:#}", path.display())]
Create { path: PathBuf, error: IoError },
#[error("Unexpected error while creating new network namespace: {0:}")]
Unshare(Errno),
#[error("Unexpected error while binding new network namespace: {0:}")]
Bind(Errno),
}
impl NamespaceError {
fn from_mount(errno: Errno) -> Self {
// todo more specific errors?
NamespaceError::Bind(errno)
}
fn from_create(name: &NamespaceName, error: IoError) -> Self {
match error.kind() {
ErrorKind::AlreadyExists => NamespaceError::AlreadyExists(name.clone()),
_ => NamespaceError::Create {
path: Path::new("/var/run/netns").join(name),
error,
},
}
}
}

43
src/main.rs Normal file
View file

@ -0,0 +1,43 @@
use std::path::{PathBuf};
use clap::{Parser, Subcommand};
use main_error::MainResult;
use crate::config::Config;
use crate::daemon::daemon;
mod config;
mod daemon;
#[derive(Parser, Debug)]
pub struct Args {
#[command(subcommand)]
command: Commands,
}
#[derive(Subcommand, Debug)]
enum Commands {
/// Start the netnsd daemon
Daemon {
/// Location of the config file
#[clap(short, long, default_value = "/etc/netnsd/netnsd")]
config: PathBuf,
},
/// Signal a running daemon to reload it's configuration
Reload,
}
fn main() -> MainResult {
let args: Args = Args::parse();
tracing_subscriber::fmt::init();
match args.command {
Commands::Daemon { config } => {
let config = Config::load(config)?;
daemon(config)
}
Commands::Reload => reload()
}
}
fn reload() -> MainResult {
todo!()
}