mirror of
https://codeberg.org/icewind/netnsd.git
synced 2026-06-03 09:04:07 +02:00
move proxies to sub-processes
This commit is contained in:
parent
b595282810
commit
e672e11f09
9 changed files with 217 additions and 199 deletions
38
Cargo.lock
generated
38
Cargo.lock
generated
|
|
@ -433,15 +433,6 @@ version = "0.5.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
||||
|
||||
[[package]]
|
||||
name = "humansize"
|
||||
version = "2.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6cb51c9a029ddc91b07a787f1d86b53ccfa49b0e86688c946ebe8d3555685dd7"
|
||||
dependencies = [
|
||||
"libm",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ident_case"
|
||||
version = "1.0.1"
|
||||
|
|
@ -476,12 +467,6 @@ version = "0.2.177"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976"
|
||||
|
||||
[[package]]
|
||||
name = "libm"
|
||||
version = "0.2.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de"
|
||||
|
||||
[[package]]
|
||||
name = "lock_api"
|
||||
version = "0.4.14"
|
||||
|
|
@ -579,14 +564,12 @@ dependencies = [
|
|||
"either",
|
||||
"futures",
|
||||
"futures-concurrency",
|
||||
"humansize",
|
||||
"main_error",
|
||||
"neli",
|
||||
"nix",
|
||||
"sd-notify",
|
||||
"serde",
|
||||
"serde_test",
|
||||
"syscalls",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
|
|
@ -813,17 +796,6 @@ dependencies = [
|
|||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_repr"
|
||||
version = "0.1.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_spanned"
|
||||
version = "1.0.3"
|
||||
|
|
@ -911,16 +883,6 @@ dependencies = [
|
|||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syscalls"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "90db46b5b4962319605d435986c775ea45a0ad2561c09e1d5372b89afeb49cf4"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_repr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "2.0.17"
|
||||
|
|
|
|||
|
|
@ -14,15 +14,13 @@ thiserror = "2.0.17"
|
|||
tracing = "0.1.41"
|
||||
tracing-subscriber = "0.3.20"
|
||||
main_error = "0.1.2"
|
||||
nix = { version = "0.30.1", features = ["mount", "sched"] }
|
||||
nix = { version = "0.30.1", features = ["mount", "sched", "user", "signal"] }
|
||||
sd-notify = "0.4.5"
|
||||
futures = "0.3.31"
|
||||
futures-concurrency = "7.6.3"
|
||||
humansize = { version = "2.1.3", features = ["no_alloc"] }
|
||||
neli = "0.7.1"
|
||||
either = "1.15.0"
|
||||
uzers = "0.12.1"
|
||||
syscalls = "0.7.0"
|
||||
|
||||
[dev-dependencies]
|
||||
serde_test = "1.0.177"
|
||||
|
|
@ -3,6 +3,8 @@ use serde::{Deserialize, Deserializer};
|
|||
use std::ffi::OsString;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::path::Path;
|
||||
use std::str::FromStr;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
|
||||
pub struct NamespaceName(String);
|
||||
|
|
@ -38,6 +40,12 @@ impl AsRef<Path> for NamespaceName {
|
|||
}
|
||||
}
|
||||
|
||||
impl PartialEq<&str> for NamespaceName {
|
||||
fn eq(&self, other: &&str) -> bool {
|
||||
self.0 == *other
|
||||
}
|
||||
}
|
||||
|
||||
impl From<NamespaceName> for String {
|
||||
fn from(value: NamespaceName) -> Self {
|
||||
value.0
|
||||
|
|
@ -83,6 +91,23 @@ impl<'de> Deserialize<'de> for NamespaceName {
|
|||
}
|
||||
}
|
||||
|
||||
impl FromStr for NamespaceName {
|
||||
type Err = InvalidNamespaceNameError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
if !validate_name(s) {
|
||||
return Err(InvalidNamespaceNameError { name: s.into() });
|
||||
}
|
||||
Ok(NamespaceName(s.into()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("invalid name for namespace: '{name}'")]
|
||||
pub struct InvalidNamespaceNameError {
|
||||
name: String,
|
||||
}
|
||||
|
||||
/// Check if a name follows the portable filename character set
|
||||
fn validate_name(name: &str) -> bool {
|
||||
if name.is_empty() {
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ use std::fmt::{Display, Formatter};
|
|||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::path::PathBuf;
|
||||
use std::str::FromStr;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq)]
|
||||
pub enum ForwardSource {
|
||||
|
|
@ -77,17 +78,8 @@ impl<'de> Deserialize<'de> for ForwardSource {
|
|||
where
|
||||
E: Error,
|
||||
{
|
||||
if v.starts_with('/') {
|
||||
Ok(ForwardSource::Unix(v.into()))
|
||||
} else {
|
||||
if let Ok(port) = u16::from_str(v) {
|
||||
return self.visit_u16(port);
|
||||
}
|
||||
let addr = v
|
||||
.parse()
|
||||
.map_err(|_| E::invalid_value(Unexpected::Str(v), &self))?;
|
||||
Ok(ForwardSource::Ip(addr))
|
||||
}
|
||||
v.parse()
|
||||
.map_err(|_| E::invalid_value(Unexpected::Str(v), &self))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -95,6 +87,31 @@ impl<'de> Deserialize<'de> for ForwardSource {
|
|||
}
|
||||
}
|
||||
|
||||
impl FromStr for ForwardSource {
|
||||
type Err = InvalidForwardSource;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
if s.starts_with('/') {
|
||||
Ok(ForwardSource::Unix(s.into()))
|
||||
} else {
|
||||
if let Ok(port) = u16::from_str(s) {
|
||||
let ip = IpAddr::from([0, 0, 0, 0]);
|
||||
return Ok(ForwardSource::Ip(SocketAddr::from((ip, port))));
|
||||
}
|
||||
let addr = s
|
||||
.parse()
|
||||
.map_err(|_| InvalidForwardSource { forward: s.into() })?;
|
||||
Ok(ForwardSource::Ip(addr))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("forward source '{forward}' is not a valid unix path or socket address")]
|
||||
pub struct InvalidForwardSource {
|
||||
forward: String,
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_de() {
|
||||
use serde_test::{assert_de_tokens, assert_de_tokens_error, Token};
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ use serde::{Deserialize, Deserializer};
|
|||
use std::fmt::{Display, Formatter};
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::str::FromStr;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq)]
|
||||
pub struct ForwardTarget {
|
||||
|
|
@ -70,13 +71,8 @@ impl<'de> Deserialize<'de> for ForwardTarget {
|
|||
where
|
||||
E: Error,
|
||||
{
|
||||
if let Ok(port) = u16::from_str(v) {
|
||||
return self.visit_u16(port);
|
||||
}
|
||||
let addr = v
|
||||
.parse()
|
||||
.map_err(|_| E::invalid_value(Unexpected::Str(v), &self))?;
|
||||
Ok(ForwardTarget { addr })
|
||||
v.parse()
|
||||
.map_err(|_| E::invalid_value(Unexpected::Str(v), &self))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -84,6 +80,29 @@ impl<'de> Deserialize<'de> for ForwardTarget {
|
|||
}
|
||||
}
|
||||
|
||||
impl FromStr for ForwardTarget {
|
||||
type Err = InvalidForwardTarget;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
if let Ok(port) = u16::from_str(s) {
|
||||
let ip = IpAddr::from([127, 0, 0, 1]);
|
||||
return Ok(ForwardTarget {
|
||||
addr: SocketAddr::from((ip, port)),
|
||||
});
|
||||
}
|
||||
let addr = s
|
||||
.parse()
|
||||
.map_err(|_| InvalidForwardTarget { forward: s.into() })?;
|
||||
Ok(ForwardTarget { addr })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("forward source '{forward}' is not a valid socket address")]
|
||||
pub struct InvalidForwardTarget {
|
||||
forward: String,
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_de() {
|
||||
use serde_test::{assert_de_tokens, assert_de_tokens_error, Token};
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ use crate::proxy::{ActiveProxy, ProxyError};
|
|||
use futures::FutureExt;
|
||||
use futures::StreamExt;
|
||||
use futures_concurrency::stream::Merge;
|
||||
use humansize::{BINARY, SizeFormatter};
|
||||
use main_error::MainResult;
|
||||
use sd_notify::{NotifyState, notify};
|
||||
use std::io::Error as IoError;
|
||||
|
|
@ -78,13 +77,9 @@ async fn daemon_async(mut config: Config) -> Result<(), DaemonError> {
|
|||
println!("{}:", namespace.name());
|
||||
for proxy in &namespace.proxies {
|
||||
println!(
|
||||
" {} => {} {} connections ({} active), {} sent to namespace, {} received from namespace",
|
||||
" {} => {}",
|
||||
proxy.source,
|
||||
proxy.destination,
|
||||
proxy.stats.total_connections(),
|
||||
proxy.stats.open_connections(),
|
||||
SizeFormatter::new(proxy.stats.written(), BINARY),
|
||||
SizeFormatter::new(proxy.stats.read(), BINARY),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
19
src/main.rs
19
src/main.rs
|
|
@ -1,9 +1,10 @@
|
|||
use std::path::{PathBuf};
|
||||
use clap::{Parser, Subcommand};
|
||||
use main_error::MainResult;
|
||||
use crate::config::Config;
|
||||
use crate::config::{Config, ForwardSource, ForwardTarget, NamespaceName};
|
||||
use crate::daemon::daemon;
|
||||
use crate::down::down;
|
||||
use crate::proxy::proxy;
|
||||
use crate::up::up;
|
||||
|
||||
mod config;
|
||||
|
|
@ -38,6 +39,17 @@ enum Commands {
|
|||
Down,
|
||||
/// Signal a running daemon to reload it's configuration
|
||||
Reload,
|
||||
/// Create the configured namespaces
|
||||
Proxy {
|
||||
/// Namespace to proxy connections from, use "parent" to listen in the namespace this command was spawned into
|
||||
source_namespace: NamespaceName,
|
||||
/// Source address to listen on
|
||||
source: ForwardSource,
|
||||
/// Namespace to proxy connections to
|
||||
target_namespace: NamespaceName,
|
||||
/// Target address to connect to
|
||||
target: ForwardTarget,
|
||||
},
|
||||
}
|
||||
|
||||
fn main() -> MainResult {
|
||||
|
|
@ -56,7 +68,10 @@ fn main() -> MainResult {
|
|||
Commands::Down => {
|
||||
down()
|
||||
}
|
||||
Commands::Reload => reload()
|
||||
Commands::Reload => reload(),
|
||||
Commands::Proxy {source, target, source_namespace, target_namespace} => {
|
||||
proxy(source_namespace, target_namespace, source, target)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
228
src/proxy/mod.rs
228
src/proxy/mod.rs
|
|
@ -3,17 +3,19 @@ mod tcp;
|
|||
use crate::config::{ForwardConfig, ForwardSource, ForwardTarget, NamespaceName};
|
||||
use crate::proxy::tcp::Proxy;
|
||||
use futures::future::AbortHandle;
|
||||
use main_error::MainResult;
|
||||
use nix::sched::{CloneFlags, setns};
|
||||
use nix::sys::signal::{SIGINT, kill};
|
||||
use nix::unistd::{Gid, Pid, Uid, setgid, setuid};
|
||||
use std::fs::{File, remove_file};
|
||||
use std::io::Error as IoError;
|
||||
use std::net::SocketAddr;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::process::{Child, Command};
|
||||
use std::thread::spawn;
|
||||
use syscalls::{Sysno, syscall};
|
||||
use thiserror::Error;
|
||||
use tokio::runtime::Builder;
|
||||
use tokio::signal::ctrl_c;
|
||||
use tracing::error;
|
||||
use uzers::{get_group_by_name, get_user_by_name};
|
||||
|
||||
|
|
@ -33,13 +35,14 @@ pub enum ProxyError {
|
|||
},
|
||||
#[error("Failed to open namespace file {}: {error:#}", path.display())]
|
||||
OpenNamespace { path: PathBuf, error: IoError },
|
||||
#[error("Failed to spawn proxy: {error:#}")]
|
||||
Spawn { error: std::io::Error },
|
||||
}
|
||||
|
||||
pub struct ActiveProxy {
|
||||
pub source: ForwardSource,
|
||||
pub destination: ForwardTarget,
|
||||
abort: AbortHandle,
|
||||
pub stats: ProxyStats,
|
||||
child: Option<Child>,
|
||||
}
|
||||
|
||||
impl ActiveProxy {
|
||||
|
|
@ -47,89 +50,52 @@ impl ActiveProxy {
|
|||
config: &ForwardConfig,
|
||||
namespace: &NamespaceName,
|
||||
) -> Result<ActiveProxy, ProxyError> {
|
||||
let stats = ProxyStats::default();
|
||||
let (abort, abort_reg) = AbortHandle::new_pair();
|
||||
|
||||
let destination = config.target.clone();
|
||||
let run_stats = stats.clone();
|
||||
|
||||
let ns_handle = open_namespace(format!("/var/run/netns/{namespace}"))?;
|
||||
let self_ns_handle = open_namespace("/proc/self/ns/net")?;
|
||||
|
||||
let (listen_namespace, target_namespace) = if config.reverse {
|
||||
(Some(ns_handle), self_ns_handle)
|
||||
(namespace.as_ref(), "parent")
|
||||
} else {
|
||||
(None, ns_handle)
|
||||
("parent", namespace.as_ref())
|
||||
};
|
||||
|
||||
let nobody_uid = get_user_by_name("nobody")
|
||||
.map(|user| user.uid())
|
||||
.unwrap_or(65534);
|
||||
let nobody_gid = get_group_by_name("nobody")
|
||||
.map(|group| group.gid())
|
||||
.unwrap_or(65534);
|
||||
let mut command = Command::new("/proc/self/exe");
|
||||
command
|
||||
.arg("proxy")
|
||||
.arg(listen_namespace)
|
||||
.arg(config.source.to_string())
|
||||
.arg(target_namespace)
|
||||
.arg(config.target.to_string());
|
||||
|
||||
let source = config.source.clone();
|
||||
spawn(move || {
|
||||
let rt = match Builder::new_current_thread().enable_io().build() {
|
||||
Ok(rt) => rt,
|
||||
Err(error) => {
|
||||
error!(%error, "Error setting up tokio runtime");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
rt.block_on(async {
|
||||
if let Some(listen_namespace) = listen_namespace {
|
||||
if let Err(error) = setns(listen_namespace, CloneFlags::CLONE_NEWNET) {
|
||||
error!(%error, "Failed to join listen network namespace for proxy");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let proxy = match Proxy::listen(&source) {
|
||||
Ok(proxy) => proxy,
|
||||
Err(error) => {
|
||||
error!(%error, "Failed to listen to {source}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(error) = setns(target_namespace, CloneFlags::CLONE_NEWNET) {
|
||||
error!(%error, "Failed to join target network namespace for proxy");
|
||||
return;
|
||||
}
|
||||
|
||||
// raw syscall since the libc `setuid`/`setgui` set it for the entire process
|
||||
unsafe {
|
||||
if let Err(error) = syscall!(Sysno::setgid, nobody_gid)
|
||||
.and_then(|_| syscall!(Sysno::setuid, nobody_uid))
|
||||
{
|
||||
error!(%error, "Failed drop privileges for proxy thread");
|
||||
}
|
||||
}
|
||||
|
||||
proxy.run(destination, abort_reg, run_stats).await;
|
||||
});
|
||||
});
|
||||
let child = command
|
||||
.spawn()
|
||||
.map_err(|error| ProxyError::Spawn { error })?;
|
||||
|
||||
Ok(ActiveProxy {
|
||||
source: config.source.clone(),
|
||||
destination: config.target.clone(),
|
||||
abort,
|
||||
stats,
|
||||
child: Some(child),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ActiveProxy {
|
||||
fn drop(&mut self) {
|
||||
if let ForwardSource::Unix(path) = &self.source {
|
||||
let mut child = self.child.take().unwrap();
|
||||
if let Err(error) = kill(Pid::from_raw(child.id() as i32), Some(SIGINT)) {
|
||||
error!(%error, "failed to signal proxy process to stop");
|
||||
}
|
||||
|
||||
let source = self.source.clone();
|
||||
spawn(move || {
|
||||
if let Err(error) = child.wait() {
|
||||
error!(%error, "failed to wait for proxy process");
|
||||
}
|
||||
|
||||
// we do this here, since the proxy process won't have permissions for it anymore
|
||||
if let ForwardSource::Unix(path) = &source {
|
||||
if let Err(error) = remove_file(path) {
|
||||
error!(%error, "failed to remove unix socket");
|
||||
}
|
||||
}
|
||||
self.abort.abort();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -139,53 +105,6 @@ impl PartialEq<ForwardConfig> for ActiveProxy {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
pub struct ProxyStats {
|
||||
open_connections: Arc<AtomicU64>,
|
||||
total_connections: Arc<AtomicU64>,
|
||||
/// Bytes proxied source -> destination
|
||||
bytes_written: Arc<AtomicU64>,
|
||||
/// Bytes proxied destination -> source
|
||||
bytes_read: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
impl ProxyStats {
|
||||
pub fn open_connection(&self) {
|
||||
self.open_connections.fetch_add(1, Ordering::Relaxed);
|
||||
self.total_connections.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn close_connection(&self) {
|
||||
self.open_connections.fetch_sub(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn open_connections(&self) -> u64 {
|
||||
self.open_connections.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn total_connections(&self) -> u64 {
|
||||
self.total_connections.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn add_read(&self, bytes: u64) {
|
||||
self.bytes_read.fetch_add(bytes, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn add_written(&self, bytes: u64) {
|
||||
self.bytes_written.fetch_add(bytes, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Bytes proxied destination -> source
|
||||
pub fn read(&self) -> u64 {
|
||||
self.bytes_read.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Bytes proxied source -> destination
|
||||
pub fn written(&self) -> u64 {
|
||||
self.bytes_written.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
||||
fn open_namespace(path: impl AsRef<Path>) -> Result<File, ProxyError> {
|
||||
let path = path.as_ref();
|
||||
File::open(path).map_err(|error| ProxyError::OpenNamespace {
|
||||
|
|
@ -193,3 +112,80 @@ fn open_namespace(path: impl AsRef<Path>) -> Result<File, ProxyError> {
|
|||
path: path.into(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn proxy(
|
||||
source_namespace: NamespaceName,
|
||||
target_namespace: NamespaceName,
|
||||
source: ForwardSource,
|
||||
target: ForwardTarget,
|
||||
) -> MainResult {
|
||||
let (abort, abort_reg) = AbortHandle::new_pair();
|
||||
|
||||
let target_namespace = if target_namespace == "parent" {
|
||||
open_namespace("/proc/self/ns/net")?
|
||||
} else {
|
||||
open_namespace(format!("/var/run/netns/{target_namespace}"))?
|
||||
};
|
||||
|
||||
let listen_namespace = if source_namespace == "parent" {
|
||||
None
|
||||
} else {
|
||||
Some(open_namespace(format!(
|
||||
"/var/run/netns/{source_namespace}"
|
||||
))?)
|
||||
};
|
||||
|
||||
let nobody_uid = Uid::from(
|
||||
get_user_by_name("nobody")
|
||||
.map(|user| user.uid())
|
||||
.unwrap_or(65534),
|
||||
);
|
||||
let nobody_gid = Gid::from(
|
||||
get_group_by_name("nobody")
|
||||
.map(|group| group.gid())
|
||||
.unwrap_or(65534),
|
||||
);
|
||||
|
||||
let rt = match Builder::new_current_thread().enable_io().build() {
|
||||
Ok(rt) => rt,
|
||||
Err(error) => {
|
||||
error!(%error, "Error setting up tokio runtime");
|
||||
return Err(error.into());
|
||||
}
|
||||
};
|
||||
|
||||
rt.block_on(async {
|
||||
if let Some(listen_namespace) = listen_namespace {
|
||||
if let Err(error) = setns(listen_namespace, CloneFlags::CLONE_NEWNET) {
|
||||
error!(%error, "Failed to join listen network namespace for proxy");
|
||||
return Err(error.into());
|
||||
}
|
||||
}
|
||||
|
||||
let proxy = match Proxy::listen(&source) {
|
||||
Ok(proxy) => proxy,
|
||||
Err(error) => {
|
||||
error!(%error, "Failed to listen to {source}");
|
||||
return Err(error.into());
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(error) = setns(target_namespace, CloneFlags::CLONE_NEWNET) {
|
||||
error!(%error, "Failed to join target network namespace for proxy");
|
||||
return Err(error.into());
|
||||
}
|
||||
|
||||
if let Err(error) = setgid(nobody_gid).and_then(|_| setuid(nobody_uid)) {
|
||||
error!(%error, "Failed to drop privileges");
|
||||
}
|
||||
|
||||
tokio::spawn(async move {
|
||||
let _ = ctrl_c().await;
|
||||
abort.abort();
|
||||
});
|
||||
|
||||
proxy.run(target, abort_reg).await;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
/// Loosely based on https://github.com/fooker/netns-proxy/blob/main/src/tcp.rs
|
||||
use crate::config::{ForwardTarget, ForwardSource};
|
||||
use crate::proxy::{ProxyError, ProxyStats};
|
||||
use crate::proxy::{ProxyError};
|
||||
use futures::TryStreamExt;
|
||||
use futures::stream::{AbortRegistration, Abortable};
|
||||
use std::fs::{remove_file, set_permissions};
|
||||
|
|
@ -61,14 +61,13 @@ impl Proxy {
|
|||
})
|
||||
}
|
||||
|
||||
pub async fn run(self, target: ForwardTarget, abort: AbortRegistration, stats: ProxyStats) {
|
||||
let proxy_stats = stats.clone();
|
||||
pub async fn run(self, target: ForwardTarget, abort: AbortRegistration,) {
|
||||
match self.socket {
|
||||
ProxyListener::Tcp(socket) => {
|
||||
run_tcp(socket, target.addr, abort, proxy_stats).await
|
||||
run_tcp(socket, target.addr, abort).await
|
||||
}
|
||||
ProxyListener::Unix(socket) => {
|
||||
run_unix(socket, target.addr, abort, proxy_stats).await
|
||||
run_unix(socket, target.addr, abort).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -78,16 +77,14 @@ async fn run_tcp(
|
|||
socket: TcpListener,
|
||||
target: SocketAddr,
|
||||
abort: AbortRegistration,
|
||||
stats: ProxyStats,
|
||||
) {
|
||||
let accepts = TcpListenerStream::new(socket).map_err(|error| ProxyError::Accept { error });
|
||||
let mut accepts = pin!(Abortable::new(accepts, abort));
|
||||
while let Some(client) = accepts.next().await {
|
||||
stats.open_connection();
|
||||
let result: Result<(), ProxyError> = async {
|
||||
let client = client?;
|
||||
let remote = client.peer_addr().ok();
|
||||
proxy_stream(client, target, remote, stats.clone()).await
|
||||
proxy_stream(client, target, remote).await
|
||||
}
|
||||
.await;
|
||||
|
||||
|
|
@ -101,15 +98,13 @@ async fn run_unix(
|
|||
socket: UnixListener,
|
||||
target: SocketAddr,
|
||||
abort: AbortRegistration,
|
||||
stats: ProxyStats,
|
||||
) {
|
||||
let accepts = UnixListenerStream::new(socket).map_err(|error| ProxyError::Accept { error });
|
||||
let mut accepts = pin!(Abortable::new(accepts, abort));
|
||||
while let Some(client) = accepts.next().await {
|
||||
stats.open_connection();
|
||||
let result: Result<(), ProxyError> = async {
|
||||
let client = client?;
|
||||
proxy_stream(client, target, None, stats.clone()).await
|
||||
proxy_stream(client, target, None).await
|
||||
}
|
||||
.await;
|
||||
|
||||
|
|
@ -119,12 +114,11 @@ async fn run_unix(
|
|||
}
|
||||
}
|
||||
|
||||
#[instrument(skip(stream, proxy))]
|
||||
#[instrument(skip(stream))]
|
||||
async fn proxy_stream<Stream>(
|
||||
mut stream: Stream,
|
||||
target: SocketAddr,
|
||||
remote: Option<SocketAddr>,
|
||||
proxy: ProxyStats,
|
||||
) -> Result<(), ProxyError>
|
||||
where
|
||||
Stream: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
|
||||
|
|
@ -143,10 +137,7 @@ where
|
|||
|
||||
spawn(async move {
|
||||
match copy_bidirectional(&mut stream, &mut upstream).await {
|
||||
Ok((written, read)) => {
|
||||
proxy.add_written(written);
|
||||
proxy.add_read(read);
|
||||
proxy.close_connection();
|
||||
Ok(_) => {
|
||||
trace!("Upstream connection closed");
|
||||
}
|
||||
Err(err) => {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue