move proxies to sub-processes

This commit is contained in:
Robin Appelman 2025-11-12 21:01:46 +01:00
commit e672e11f09
9 changed files with 217 additions and 199 deletions

38
Cargo.lock generated
View file

@ -433,15 +433,6 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "humansize"
version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6cb51c9a029ddc91b07a787f1d86b53ccfa49b0e86688c946ebe8d3555685dd7"
dependencies = [
"libm",
]
[[package]] [[package]]
name = "ident_case" name = "ident_case"
version = "1.0.1" version = "1.0.1"
@ -476,12 +467,6 @@ version = "0.2.177"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976"
[[package]]
name = "libm"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de"
[[package]] [[package]]
name = "lock_api" name = "lock_api"
version = "0.4.14" version = "0.4.14"
@ -579,14 +564,12 @@ dependencies = [
"either", "either",
"futures", "futures",
"futures-concurrency", "futures-concurrency",
"humansize",
"main_error", "main_error",
"neli", "neli",
"nix", "nix",
"sd-notify", "sd-notify",
"serde", "serde",
"serde_test", "serde_test",
"syscalls",
"thiserror", "thiserror",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
@ -813,17 +796,6 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "serde_repr"
version = "0.1.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "serde_spanned" name = "serde_spanned"
version = "1.0.3" version = "1.0.3"
@ -911,16 +883,6 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "syscalls"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90db46b5b4962319605d435986c775ea45a0ad2561c09e1d5372b89afeb49cf4"
dependencies = [
"serde",
"serde_repr",
]
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "2.0.17" version = "2.0.17"

View file

@ -14,15 +14,13 @@ thiserror = "2.0.17"
tracing = "0.1.41" tracing = "0.1.41"
tracing-subscriber = "0.3.20" tracing-subscriber = "0.3.20"
main_error = "0.1.2" main_error = "0.1.2"
nix = { version = "0.30.1", features = ["mount", "sched"] } nix = { version = "0.30.1", features = ["mount", "sched", "user", "signal"] }
sd-notify = "0.4.5" sd-notify = "0.4.5"
futures = "0.3.31" futures = "0.3.31"
futures-concurrency = "7.6.3" futures-concurrency = "7.6.3"
humansize = { version = "2.1.3", features = ["no_alloc"] }
neli = "0.7.1" neli = "0.7.1"
either = "1.15.0" either = "1.15.0"
uzers = "0.12.1" uzers = "0.12.1"
syscalls = "0.7.0"
[dev-dependencies] [dev-dependencies]
serde_test = "1.0.177" serde_test = "1.0.177"

View file

@ -3,6 +3,8 @@ use serde::{Deserialize, Deserializer};
use std::ffi::OsString; use std::ffi::OsString;
use std::fmt::{Display, Formatter}; use std::fmt::{Display, Formatter};
use std::path::Path; use std::path::Path;
use std::str::FromStr;
use thiserror::Error;
#[derive(Debug, Clone, Eq, PartialEq, Hash)] #[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct NamespaceName(String); pub struct NamespaceName(String);
@ -38,6 +40,12 @@ impl AsRef<Path> for NamespaceName {
} }
} }
impl PartialEq<&str> for NamespaceName {
fn eq(&self, other: &&str) -> bool {
self.0 == *other
}
}
impl From<NamespaceName> for String { impl From<NamespaceName> for String {
fn from(value: NamespaceName) -> Self { fn from(value: NamespaceName) -> Self {
value.0 value.0
@ -83,6 +91,23 @@ impl<'de> Deserialize<'de> for NamespaceName {
} }
} }
impl FromStr for NamespaceName {
type Err = InvalidNamespaceNameError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if !validate_name(s) {
return Err(InvalidNamespaceNameError { name: s.into() });
}
Ok(NamespaceName(s.into()))
}
}
#[derive(Debug, Error)]
#[error("invalid name for namespace: '{name}'")]
pub struct InvalidNamespaceNameError {
name: String,
}
/// Check if a name follows the portable filename character set /// Check if a name follows the portable filename character set
fn validate_name(name: &str) -> bool { fn validate_name(name: &str) -> bool {
if name.is_empty() { if name.is_empty() {

View file

@ -4,6 +4,7 @@ use std::fmt::{Display, Formatter};
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf; use std::path::PathBuf;
use std::str::FromStr; use std::str::FromStr;
use thiserror::Error;
#[derive(Debug, PartialEq, Clone, Hash, Eq)] #[derive(Debug, PartialEq, Clone, Hash, Eq)]
pub enum ForwardSource { pub enum ForwardSource {
@ -77,17 +78,8 @@ impl<'de> Deserialize<'de> for ForwardSource {
where where
E: Error, E: Error,
{ {
if v.starts_with('/') { v.parse()
Ok(ForwardSource::Unix(v.into())) .map_err(|_| E::invalid_value(Unexpected::Str(v), &self))
} else {
if let Ok(port) = u16::from_str(v) {
return self.visit_u16(port);
}
let addr = v
.parse()
.map_err(|_| E::invalid_value(Unexpected::Str(v), &self))?;
Ok(ForwardSource::Ip(addr))
}
} }
} }
@ -95,6 +87,31 @@ impl<'de> Deserialize<'de> for ForwardSource {
} }
} }
impl FromStr for ForwardSource {
type Err = InvalidForwardSource;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.starts_with('/') {
Ok(ForwardSource::Unix(s.into()))
} else {
if let Ok(port) = u16::from_str(s) {
let ip = IpAddr::from([0, 0, 0, 0]);
return Ok(ForwardSource::Ip(SocketAddr::from((ip, port))));
}
let addr = s
.parse()
.map_err(|_| InvalidForwardSource { forward: s.into() })?;
Ok(ForwardSource::Ip(addr))
}
}
}
#[derive(Debug, Error)]
#[error("forward source '{forward}' is not a valid unix path or socket address")]
pub struct InvalidForwardSource {
forward: String,
}
#[test] #[test]
fn test_de() { fn test_de() {
use serde_test::{assert_de_tokens, assert_de_tokens_error, Token}; use serde_test::{assert_de_tokens, assert_de_tokens_error, Token};

View file

@ -3,6 +3,7 @@ use serde::{Deserialize, Deserializer};
use std::fmt::{Display, Formatter}; use std::fmt::{Display, Formatter};
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::str::FromStr; use std::str::FromStr;
use thiserror::Error;
#[derive(Debug, PartialEq, Clone, Hash, Eq)] #[derive(Debug, PartialEq, Clone, Hash, Eq)]
pub struct ForwardTarget { pub struct ForwardTarget {
@ -70,13 +71,8 @@ impl<'de> Deserialize<'de> for ForwardTarget {
where where
E: Error, E: Error,
{ {
if let Ok(port) = u16::from_str(v) { v.parse()
return self.visit_u16(port); .map_err(|_| E::invalid_value(Unexpected::Str(v), &self))
}
let addr = v
.parse()
.map_err(|_| E::invalid_value(Unexpected::Str(v), &self))?;
Ok(ForwardTarget { addr })
} }
} }
@ -84,6 +80,29 @@ impl<'de> Deserialize<'de> for ForwardTarget {
} }
} }
impl FromStr for ForwardTarget {
type Err = InvalidForwardTarget;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if let Ok(port) = u16::from_str(s) {
let ip = IpAddr::from([127, 0, 0, 1]);
return Ok(ForwardTarget {
addr: SocketAddr::from((ip, port)),
});
}
let addr = s
.parse()
.map_err(|_| InvalidForwardTarget { forward: s.into() })?;
Ok(ForwardTarget { addr })
}
}
#[derive(Debug, Error)]
#[error("forward source '{forward}' is not a valid socket address")]
pub struct InvalidForwardTarget {
forward: String,
}
#[test] #[test]
fn test_de() { fn test_de() {
use serde_test::{assert_de_tokens, assert_de_tokens_error, Token}; use serde_test::{assert_de_tokens, assert_de_tokens_error, Token};

View file

@ -4,7 +4,6 @@ use crate::proxy::{ActiveProxy, ProxyError};
use futures::FutureExt; use futures::FutureExt;
use futures::StreamExt; use futures::StreamExt;
use futures_concurrency::stream::Merge; use futures_concurrency::stream::Merge;
use humansize::{BINARY, SizeFormatter};
use main_error::MainResult; use main_error::MainResult;
use sd_notify::{NotifyState, notify}; use sd_notify::{NotifyState, notify};
use std::io::Error as IoError; use std::io::Error as IoError;
@ -78,13 +77,9 @@ async fn daemon_async(mut config: Config) -> Result<(), DaemonError> {
println!("{}:", namespace.name()); println!("{}:", namespace.name());
for proxy in &namespace.proxies { for proxy in &namespace.proxies {
println!( println!(
" {} => {} {} connections ({} active), {} sent to namespace, {} received from namespace", " {} => {}",
proxy.source, proxy.source,
proxy.destination, proxy.destination,
proxy.stats.total_connections(),
proxy.stats.open_connections(),
SizeFormatter::new(proxy.stats.written(), BINARY),
SizeFormatter::new(proxy.stats.read(), BINARY),
); );
} }
} }

View file

@ -1,9 +1,10 @@
use std::path::{PathBuf}; use std::path::{PathBuf};
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use main_error::MainResult; use main_error::MainResult;
use crate::config::Config; use crate::config::{Config, ForwardSource, ForwardTarget, NamespaceName};
use crate::daemon::daemon; use crate::daemon::daemon;
use crate::down::down; use crate::down::down;
use crate::proxy::proxy;
use crate::up::up; use crate::up::up;
mod config; mod config;
@ -38,6 +39,17 @@ enum Commands {
Down, Down,
/// Signal a running daemon to reload it's configuration /// Signal a running daemon to reload it's configuration
Reload, Reload,
/// Create the configured namespaces
Proxy {
/// Namespace to proxy connections from, use "parent" to listen in the namespace this command was spawned into
source_namespace: NamespaceName,
/// Source address to listen on
source: ForwardSource,
/// Namespace to proxy connections to
target_namespace: NamespaceName,
/// Target address to connect to
target: ForwardTarget,
},
} }
fn main() -> MainResult { fn main() -> MainResult {
@ -56,7 +68,10 @@ fn main() -> MainResult {
Commands::Down => { Commands::Down => {
down() down()
} }
Commands::Reload => reload() Commands::Reload => reload(),
Commands::Proxy {source, target, source_namespace, target_namespace} => {
proxy(source_namespace, target_namespace, source, target)
},
} }
} }

View file

@ -3,17 +3,19 @@ mod tcp;
use crate::config::{ForwardConfig, ForwardSource, ForwardTarget, NamespaceName}; use crate::config::{ForwardConfig, ForwardSource, ForwardTarget, NamespaceName};
use crate::proxy::tcp::Proxy; use crate::proxy::tcp::Proxy;
use futures::future::AbortHandle; use futures::future::AbortHandle;
use main_error::MainResult;
use nix::sched::{CloneFlags, setns}; use nix::sched::{CloneFlags, setns};
use nix::sys::signal::{SIGINT, kill};
use nix::unistd::{Gid, Pid, Uid, setgid, setuid};
use std::fs::{File, remove_file}; use std::fs::{File, remove_file};
use std::io::Error as IoError; use std::io::Error as IoError;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::Arc; use std::process::{Child, Command};
use std::sync::atomic::{AtomicU64, Ordering};
use std::thread::spawn; use std::thread::spawn;
use syscalls::{Sysno, syscall};
use thiserror::Error; use thiserror::Error;
use tokio::runtime::Builder; use tokio::runtime::Builder;
use tokio::signal::ctrl_c;
use tracing::error; use tracing::error;
use uzers::{get_group_by_name, get_user_by_name}; use uzers::{get_group_by_name, get_user_by_name};
@ -33,13 +35,14 @@ pub enum ProxyError {
}, },
#[error("Failed to open namespace file {}: {error:#}", path.display())] #[error("Failed to open namespace file {}: {error:#}", path.display())]
OpenNamespace { path: PathBuf, error: IoError }, OpenNamespace { path: PathBuf, error: IoError },
#[error("Failed to spawn proxy: {error:#}")]
Spawn { error: std::io::Error },
} }
pub struct ActiveProxy { pub struct ActiveProxy {
pub source: ForwardSource, pub source: ForwardSource,
pub destination: ForwardTarget, pub destination: ForwardTarget,
abort: AbortHandle, child: Option<Child>,
pub stats: ProxyStats,
} }
impl ActiveProxy { impl ActiveProxy {
@ -47,89 +50,52 @@ impl ActiveProxy {
config: &ForwardConfig, config: &ForwardConfig,
namespace: &NamespaceName, namespace: &NamespaceName,
) -> Result<ActiveProxy, ProxyError> { ) -> Result<ActiveProxy, ProxyError> {
let stats = ProxyStats::default();
let (abort, abort_reg) = AbortHandle::new_pair();
let destination = config.target.clone();
let run_stats = stats.clone();
let ns_handle = open_namespace(format!("/var/run/netns/{namespace}"))?;
let self_ns_handle = open_namespace("/proc/self/ns/net")?;
let (listen_namespace, target_namespace) = if config.reverse { let (listen_namespace, target_namespace) = if config.reverse {
(Some(ns_handle), self_ns_handle) (namespace.as_ref(), "parent")
} else { } else {
(None, ns_handle) ("parent", namespace.as_ref())
}; };
let nobody_uid = get_user_by_name("nobody") let mut command = Command::new("/proc/self/exe");
.map(|user| user.uid()) command
.unwrap_or(65534); .arg("proxy")
let nobody_gid = get_group_by_name("nobody") .arg(listen_namespace)
.map(|group| group.gid()) .arg(config.source.to_string())
.unwrap_or(65534); .arg(target_namespace)
.arg(config.target.to_string());
let source = config.source.clone(); let child = command
spawn(move || { .spawn()
let rt = match Builder::new_current_thread().enable_io().build() { .map_err(|error| ProxyError::Spawn { error })?;
Ok(rt) => rt,
Err(error) => {
error!(%error, "Error setting up tokio runtime");
return;
}
};
rt.block_on(async {
if let Some(listen_namespace) = listen_namespace {
if let Err(error) = setns(listen_namespace, CloneFlags::CLONE_NEWNET) {
error!(%error, "Failed to join listen network namespace for proxy");
return;
}
}
let proxy = match Proxy::listen(&source) {
Ok(proxy) => proxy,
Err(error) => {
error!(%error, "Failed to listen to {source}");
return;
}
};
if let Err(error) = setns(target_namespace, CloneFlags::CLONE_NEWNET) {
error!(%error, "Failed to join target network namespace for proxy");
return;
}
// raw syscall since the libc `setuid`/`setgui` set it for the entire process
unsafe {
if let Err(error) = syscall!(Sysno::setgid, nobody_gid)
.and_then(|_| syscall!(Sysno::setuid, nobody_uid))
{
error!(%error, "Failed drop privileges for proxy thread");
}
}
proxy.run(destination, abort_reg, run_stats).await;
});
});
Ok(ActiveProxy { Ok(ActiveProxy {
source: config.source.clone(), source: config.source.clone(),
destination: config.target.clone(), destination: config.target.clone(),
abort, child: Some(child),
stats,
}) })
} }
} }
impl Drop for ActiveProxy { impl Drop for ActiveProxy {
fn drop(&mut self) { fn drop(&mut self) {
if let ForwardSource::Unix(path) = &self.source { let mut child = self.child.take().unwrap();
if let Err(error) = kill(Pid::from_raw(child.id() as i32), Some(SIGINT)) {
error!(%error, "failed to signal proxy process to stop");
}
let source = self.source.clone();
spawn(move || {
if let Err(error) = child.wait() {
error!(%error, "failed to wait for proxy process");
}
// we do this here, since the proxy process won't have permissions for it anymore
if let ForwardSource::Unix(path) = &source {
if let Err(error) = remove_file(path) { if let Err(error) = remove_file(path) {
error!(%error, "failed to remove unix socket"); error!(%error, "failed to remove unix socket");
} }
} }
self.abort.abort(); });
} }
} }
@ -139,53 +105,6 @@ impl PartialEq<ForwardConfig> for ActiveProxy {
} }
} }
#[derive(Default, Clone)]
pub struct ProxyStats {
open_connections: Arc<AtomicU64>,
total_connections: Arc<AtomicU64>,
/// Bytes proxied source -> destination
bytes_written: Arc<AtomicU64>,
/// Bytes proxied destination -> source
bytes_read: Arc<AtomicU64>,
}
impl ProxyStats {
pub fn open_connection(&self) {
self.open_connections.fetch_add(1, Ordering::Relaxed);
self.total_connections.fetch_add(1, Ordering::Relaxed);
}
pub fn close_connection(&self) {
self.open_connections.fetch_sub(1, Ordering::Relaxed);
}
pub fn open_connections(&self) -> u64 {
self.open_connections.load(Ordering::Relaxed)
}
pub fn total_connections(&self) -> u64 {
self.total_connections.load(Ordering::Relaxed)
}
pub fn add_read(&self, bytes: u64) {
self.bytes_read.fetch_add(bytes, Ordering::Relaxed);
}
pub fn add_written(&self, bytes: u64) {
self.bytes_written.fetch_add(bytes, Ordering::Relaxed);
}
/// Bytes proxied destination -> source
pub fn read(&self) -> u64 {
self.bytes_read.load(Ordering::Relaxed)
}
/// Bytes proxied source -> destination
pub fn written(&self) -> u64 {
self.bytes_written.load(Ordering::Relaxed)
}
}
fn open_namespace(path: impl AsRef<Path>) -> Result<File, ProxyError> { fn open_namespace(path: impl AsRef<Path>) -> Result<File, ProxyError> {
let path = path.as_ref(); let path = path.as_ref();
File::open(path).map_err(|error| ProxyError::OpenNamespace { File::open(path).map_err(|error| ProxyError::OpenNamespace {
@ -193,3 +112,80 @@ fn open_namespace(path: impl AsRef<Path>) -> Result<File, ProxyError> {
path: path.into(), path: path.into(),
}) })
} }
pub fn proxy(
source_namespace: NamespaceName,
target_namespace: NamespaceName,
source: ForwardSource,
target: ForwardTarget,
) -> MainResult {
let (abort, abort_reg) = AbortHandle::new_pair();
let target_namespace = if target_namespace == "parent" {
open_namespace("/proc/self/ns/net")?
} else {
open_namespace(format!("/var/run/netns/{target_namespace}"))?
};
let listen_namespace = if source_namespace == "parent" {
None
} else {
Some(open_namespace(format!(
"/var/run/netns/{source_namespace}"
))?)
};
let nobody_uid = Uid::from(
get_user_by_name("nobody")
.map(|user| user.uid())
.unwrap_or(65534),
);
let nobody_gid = Gid::from(
get_group_by_name("nobody")
.map(|group| group.gid())
.unwrap_or(65534),
);
let rt = match Builder::new_current_thread().enable_io().build() {
Ok(rt) => rt,
Err(error) => {
error!(%error, "Error setting up tokio runtime");
return Err(error.into());
}
};
rt.block_on(async {
if let Some(listen_namespace) = listen_namespace {
if let Err(error) = setns(listen_namespace, CloneFlags::CLONE_NEWNET) {
error!(%error, "Failed to join listen network namespace for proxy");
return Err(error.into());
}
}
let proxy = match Proxy::listen(&source) {
Ok(proxy) => proxy,
Err(error) => {
error!(%error, "Failed to listen to {source}");
return Err(error.into());
}
};
if let Err(error) = setns(target_namespace, CloneFlags::CLONE_NEWNET) {
error!(%error, "Failed to join target network namespace for proxy");
return Err(error.into());
}
if let Err(error) = setgid(nobody_gid).and_then(|_| setuid(nobody_uid)) {
error!(%error, "Failed to drop privileges");
}
tokio::spawn(async move {
let _ = ctrl_c().await;
abort.abort();
});
proxy.run(target, abort_reg).await;
Ok(())
})
}

View file

@ -1,6 +1,6 @@
/// Loosely based on https://github.com/fooker/netns-proxy/blob/main/src/tcp.rs /// Loosely based on https://github.com/fooker/netns-proxy/blob/main/src/tcp.rs
use crate::config::{ForwardTarget, ForwardSource}; use crate::config::{ForwardTarget, ForwardSource};
use crate::proxy::{ProxyError, ProxyStats}; use crate::proxy::{ProxyError};
use futures::TryStreamExt; use futures::TryStreamExt;
use futures::stream::{AbortRegistration, Abortable}; use futures::stream::{AbortRegistration, Abortable};
use std::fs::{remove_file, set_permissions}; use std::fs::{remove_file, set_permissions};
@ -61,14 +61,13 @@ impl Proxy {
}) })
} }
pub async fn run(self, target: ForwardTarget, abort: AbortRegistration, stats: ProxyStats) { pub async fn run(self, target: ForwardTarget, abort: AbortRegistration,) {
let proxy_stats = stats.clone();
match self.socket { match self.socket {
ProxyListener::Tcp(socket) => { ProxyListener::Tcp(socket) => {
run_tcp(socket, target.addr, abort, proxy_stats).await run_tcp(socket, target.addr, abort).await
} }
ProxyListener::Unix(socket) => { ProxyListener::Unix(socket) => {
run_unix(socket, target.addr, abort, proxy_stats).await run_unix(socket, target.addr, abort).await
} }
} }
} }
@ -78,16 +77,14 @@ async fn run_tcp(
socket: TcpListener, socket: TcpListener,
target: SocketAddr, target: SocketAddr,
abort: AbortRegistration, abort: AbortRegistration,
stats: ProxyStats,
) { ) {
let accepts = TcpListenerStream::new(socket).map_err(|error| ProxyError::Accept { error }); let accepts = TcpListenerStream::new(socket).map_err(|error| ProxyError::Accept { error });
let mut accepts = pin!(Abortable::new(accepts, abort)); let mut accepts = pin!(Abortable::new(accepts, abort));
while let Some(client) = accepts.next().await { while let Some(client) = accepts.next().await {
stats.open_connection();
let result: Result<(), ProxyError> = async { let result: Result<(), ProxyError> = async {
let client = client?; let client = client?;
let remote = client.peer_addr().ok(); let remote = client.peer_addr().ok();
proxy_stream(client, target, remote, stats.clone()).await proxy_stream(client, target, remote).await
} }
.await; .await;
@ -101,15 +98,13 @@ async fn run_unix(
socket: UnixListener, socket: UnixListener,
target: SocketAddr, target: SocketAddr,
abort: AbortRegistration, abort: AbortRegistration,
stats: ProxyStats,
) { ) {
let accepts = UnixListenerStream::new(socket).map_err(|error| ProxyError::Accept { error }); let accepts = UnixListenerStream::new(socket).map_err(|error| ProxyError::Accept { error });
let mut accepts = pin!(Abortable::new(accepts, abort)); let mut accepts = pin!(Abortable::new(accepts, abort));
while let Some(client) = accepts.next().await { while let Some(client) = accepts.next().await {
stats.open_connection();
let result: Result<(), ProxyError> = async { let result: Result<(), ProxyError> = async {
let client = client?; let client = client?;
proxy_stream(client, target, None, stats.clone()).await proxy_stream(client, target, None).await
} }
.await; .await;
@ -119,12 +114,11 @@ async fn run_unix(
} }
} }
#[instrument(skip(stream, proxy))] #[instrument(skip(stream))]
async fn proxy_stream<Stream>( async fn proxy_stream<Stream>(
mut stream: Stream, mut stream: Stream,
target: SocketAddr, target: SocketAddr,
remote: Option<SocketAddr>, remote: Option<SocketAddr>,
proxy: ProxyStats,
) -> Result<(), ProxyError> ) -> Result<(), ProxyError>
where where
Stream: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, Stream: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
@ -143,10 +137,7 @@ where
spawn(async move { spawn(async move {
match copy_bidirectional(&mut stream, &mut upstream).await { match copy_bidirectional(&mut stream, &mut upstream).await {
Ok((written, read)) => { Ok(_) => {
proxy.add_written(written);
proxy.add_read(read);
proxy.close_connection();
trace!("Upstream connection closed"); trace!("Upstream connection closed");
} }
Err(err) => { Err(err) => {