1
0
Fork 0
mirror of https://codeberg.org/icewind/haze.git synced 2026-06-03 09:04:12 +02:00

improve websocket proxying

This commit is contained in:
Robin Appelman 2026-05-08 20:24:17 +02:00
commit ad999702aa
6 changed files with 164 additions and 43 deletions

View file

@ -5,26 +5,34 @@ use axum::http::header::HOST;
use axum::http::HeaderValue;
use axum::{
body::Body,
extract::{Request, State},
extract::Request,
response::{IntoResponse, Response},
Router,
};
use bollard::Docker;
use futures_util::StreamExt;
use hyper::body::Incoming;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::StatusCode;
use hyper_util::rt::TokioIo;
use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
use miette::{miette, IntoDiagnostic};
use std::collections::HashMap;
use std::convert::Infallible;
use std::fs::{create_dir_all, set_permissions};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::os::unix::fs::PermissionsExt;
use std::path::PathBuf;
use std::pin::pin;
use std::str::FromStr;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::UnixListener;
use tokio::signal::ctrl_c;
use tokio::spawn;
use tokio::time::sleep;
use tokio_stream::wrappers::{TcpListenerStream, UnixListenerStream};
use tracing::{debug, error, info};
struct ActiveInstances {
@ -163,20 +171,26 @@ async fn serve(instances: ActiveInstances, listen: String, base_address: String)
ctrl_c().await.ok();
};
let app = Router::new().fallback(handler).with_state(AppState {
let state = AppState {
instances: instances.clone(),
base_address: base_address.clone(),
proxy_client: Arc::new(proxy_client),
});
};
if !listen.starts_with('/') {
let addr: SocketAddr = listen.parse().into_diagnostic()?;
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
println!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app)
.with_graceful_shutdown(cancel)
.await
.unwrap();
let mut connections = pin!(TcpListenerStream::new(listener).take_until(cancel));
while let Some(stream) = connections.next().await {
match stream {
Ok(stream) => handle_connection(state.clone(), stream),
Err(error) => {
error!(%error, "connection failed");
}
}
}
} else {
let listen: PathBuf = listen.into();
if let Some(parent) = listen.parent() {
@ -187,18 +201,42 @@ async fn serve(instances: ActiveInstances, listen: String, base_address: String)
}
let _ = tokio::fs::remove_file(&listen).await;
let uds = UnixListener::bind(&listen).unwrap();
let listener = UnixListener::bind(&listen).unwrap();
println!("listening on {}", listen.display());
set_permissions(&listen, PermissionsExt::from_mode(0o666)).into_diagnostic()?;
axum::serve(uds, app)
.with_graceful_shutdown(cancel)
.await
.unwrap();
let mut connections = pin!(UnixListenerStream::new(listener).take_until(cancel));
while let Some(stream) = connections.next().await {
match stream {
Ok(stream) => handle_connection(state.clone(), stream),
Err(error) => {
error!(%error, "connection failed");
}
}
}
}
Ok(())
}
fn handle_connection<I: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
state: AppState,
stream: I,
) {
let io = TokioIo::new(stream);
// Spawn a tokio task to serve multiple connections concurrently
tokio::task::spawn(async move {
if let Err(err) = http1::Builder::new()
.serve_connection(io, service_fn(move |req| handler(state.clone(), req)))
.with_upgrades()
.await
{
eprintln!("Error serving connection: {:?}", err);
}
});
}
async fn get_remote(
host: Option<&HeaderValue>,
instances: &ActiveInstances,
@ -232,9 +270,9 @@ async fn get_remote(
}
}
type Client = hyper_util::client::legacy::Client<HttpConnector, Body>;
type Client = hyper_util::client::legacy::Client<HttpConnector, Incoming>;
async fn handler(State(state): State<AppState>, mut req: Request) -> Result<Response, StatusCode> {
async fn handler(state: AppState, mut req: Request<Incoming>) -> Result<Response, Infallible> {
let host = req.headers().get(HOST).cloned();
let remote = match get_remote(host.as_ref(), &state.instances, &state.base_address).await {
Ok(remote) => remote,
@ -259,13 +297,13 @@ async fn handler(State(state): State<AppState>, mut req: Request) -> Result<Resp
IpAddr::V4(Ipv4Addr::UNSPECIFIED),
&uri,
req,
&state.proxy_client,
state.proxy_client.as_ref(),
)
.await
{
Ok(response) => Ok(response.map(Body::new)),
Err(error) => {
error!(%error, "error while proxying request");
error!(?error, "error while proxying request");
Ok(StatusCode::BAD_REQUEST.into_response())
}
}

View file

@ -14,6 +14,7 @@ mod sftp;
mod redis;
mod sharded;
mod smb;
mod webhook;
use crate::cloud::CloudOptions;
use crate::config::{HazeConfig, Preset, ProxyConfig};
@ -32,6 +33,7 @@ use crate::service::redis::Redis;
use crate::service::sftp::Sftp;
use crate::service::sharded::{Sharding, ShardingMigrate, ShardingMigrateUnset, SingleShard};
use crate::service::smb::Smb;
use crate::service::webhook::Webhook;
use bollard::models::ContainerState;
use bollard::Docker;
use enum_dispatch::enum_dispatch;
@ -296,6 +298,8 @@ pub enum ServiceType {
RedisTls,
/// Use FrankenPHP instead of PHP-FPM
FrankenPhp,
/// Webhook test listener
Webhook,
}
#[enum_dispatch]
@ -326,6 +330,7 @@ pub enum Service {
Redis(Redis),
RedisTls(RedisTls),
FrankenPhp(FrankenPhp),
Webhook(Webhook),
Preset(PresetService),
}
@ -369,6 +374,7 @@ impl Service {
ServiceType::Redis => Some(vec![Service::Redis(Redis)]),
ServiceType::RedisTls => Some(vec![Service::RedisTls(RedisTls)]),
ServiceType::FrankenPhp => Some(vec![Service::FrankenPhp(FrankenPhp)]),
ServiceType::Webhook => Some(vec![Service::Webhook(Webhook)]),
}
} else {
presets

71
src/service/webhook.rs Normal file
View file

@ -0,0 +1,71 @@
use crate::cloud::CloudOptions;
use crate::config::HazeConfig;
use crate::image::pull_image;
use crate::service::ServiceTrait;
use crate::Result;
use bollard::models::{ContainerCreateBody, EndpointSettings, HostConfig, NetworkingConfig};
use bollard::query_parameters::CreateContainerOptions;
use bollard::Docker;
use maplit::hashmap;
use miette::IntoDiagnostic;
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Webhook;
#[async_trait::async_trait]
impl ServiceTrait for Webhook {
fn name(&self) -> &str {
"webhook"
}
async fn spawn(
&self,
docker: &Docker,
cloud_id: &str,
network: &str,
_config: &HazeConfig,
_options: &CloudOptions,
) -> Result<Vec<String>> {
let image = "ghcr.io/tarampampam/webhook-tester";
pull_image(docker, image).await?;
let options = Some(CreateContainerOptions {
name: self.container_name(cloud_id),
..CreateContainerOptions::default()
});
let config = ContainerCreateBody {
image: Some(image.into()),
host_config: Some(HostConfig {
network_mode: Some(network.to_string()),
..Default::default()
}),
labels: Some(hashmap! {
"haze-type".into() => self.name().into(),
"haze-cloud-id".into() => cloud_id.into(),
}),
networking_config: Some(NetworkingConfig {
endpoints_config: Some(hashmap! {
network.into() => EndpointSettings {
aliases: Some(vec![self.name().to_string()]),
..Default::default()
}
}),
}),
..Default::default()
};
let id = docker
.create_container(options, config)
.await
.into_diagnostic()?
.id;
docker.start_container(&id, None).await.into_diagnostic()?;
Ok(vec![id])
}
fn container_name(&self, cloud_id: &str) -> Option<String> {
Some(format!("{}-webhook", cloud_id))
}
fn proxy_port(&self) -> u16 {
8080
}
}