mirror of
https://codeberg.org/icewind/haze.git
synced 2026-06-03 17:14:08 +02:00
switch to axum based proxy
This commit is contained in:
parent
716074f343
commit
2bc9e571d2
4 changed files with 179 additions and 391 deletions
118
src/proxy.rs
118
src/proxy.rs
|
|
@ -1,28 +1,30 @@
|
|||
use crate::service::ServiceTrait;
|
||||
use crate::Result;
|
||||
use crate::{Cloud, HazeConfig};
|
||||
use axum::http::header::HOST;
|
||||
use axum::http::HeaderValue;
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Request, State},
|
||||
response::{IntoResponse, Response},
|
||||
Router,
|
||||
};
|
||||
use bollard::Docker;
|
||||
use hyper::StatusCode;
|
||||
use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
|
||||
use miette::{miette, IntoDiagnostic};
|
||||
use std::collections::HashMap;
|
||||
use std::convert::Infallible;
|
||||
use std::fs::{create_dir_all, remove_file, set_permissions};
|
||||
use std::fs::{create_dir_all, set_permissions};
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Duration;
|
||||
use tokio::net::{UnixListener, UnixStream};
|
||||
use tokio::net::UnixListener;
|
||||
use tokio::signal::ctrl_c;
|
||||
use tokio::spawn;
|
||||
use tokio::time::sleep;
|
||||
use tokio_stream::wrappers::UnixListenerStream;
|
||||
use tracing::info;
|
||||
use warp::http::header::HOST;
|
||||
use warp::http::HeaderValue;
|
||||
use warp::hyper::server::accept::from_stream;
|
||||
use warp::hyper::server::conn::AddrStream;
|
||||
use warp::hyper::service::{make_service_fn, service_fn};
|
||||
use warp::hyper::{Body, Request, Response, Server, StatusCode};
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
struct ActiveInstances {
|
||||
known: Mutex<HashMap<String, SocketAddr>>,
|
||||
|
|
@ -101,11 +103,22 @@ pub async fn proxy(docker: Docker, config: HazeConfig) -> Result<()> {
|
|||
serve(instances, listen, base_address).await
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AppState {
|
||||
instances: Arc<ActiveInstances>,
|
||||
base_address: Arc<String>,
|
||||
proxy_client: Arc<Client>,
|
||||
}
|
||||
|
||||
async fn serve(instances: ActiveInstances, listen: String, base_address: String) -> Result<()> {
|
||||
let instances = Arc::new(instances);
|
||||
let base_address = Arc::new(base_address);
|
||||
let last_instances = instances.clone();
|
||||
|
||||
let proxy_client: Client =
|
||||
hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
|
||||
.build(HttpConnector::new());
|
||||
|
||||
spawn(async move {
|
||||
loop {
|
||||
sleep(Duration::from_secs(1)).await;
|
||||
|
|
@ -117,27 +130,21 @@ async fn serve(instances: ActiveInstances, listen: String, base_address: String)
|
|||
ctrl_c().await.ok();
|
||||
};
|
||||
|
||||
let handler = move |remote_addr| {
|
||||
let instances = instances.clone();
|
||||
let base_address = base_address.clone();
|
||||
async move {
|
||||
Ok::<_, Infallible>(service_fn(move |req| {
|
||||
handle(remote_addr, req, instances.clone(), base_address.clone())
|
||||
}))
|
||||
}
|
||||
};
|
||||
let app = Router::new().fallback(handler).with_state(AppState {
|
||||
instances: instances.clone(),
|
||||
base_address: base_address.clone(),
|
||||
proxy_client: Arc::new(proxy_client),
|
||||
});
|
||||
|
||||
if !listen.starts_with('/') {
|
||||
let make_svc = make_service_fn(|conn: &AddrStream| handler(conn.remote_addr().ip()));
|
||||
let addr: SocketAddr = listen.parse().into_diagnostic()?;
|
||||
Server::bind(&addr)
|
||||
.serve(make_svc)
|
||||
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
|
||||
println!("listening on {}", listener.local_addr().unwrap());
|
||||
axum::serve(listener, app)
|
||||
.with_graceful_shutdown(cancel)
|
||||
.await
|
||||
.into_diagnostic()?;
|
||||
.unwrap();
|
||||
} else {
|
||||
let make_svc =
|
||||
make_service_fn(move |_conn: &UnixStream| handler(Ipv4Addr::UNSPECIFIED.into()));
|
||||
let listen: PathBuf = listen.into();
|
||||
if let Some(parent) = listen.parent() {
|
||||
if !parent.exists() {
|
||||
|
|
@ -145,17 +152,15 @@ async fn serve(instances: ActiveInstances, listen: String, base_address: String)
|
|||
set_permissions(parent, PermissionsExt::from_mode(0o755)).into_diagnostic()?;
|
||||
}
|
||||
}
|
||||
remove_file(&listen).ok();
|
||||
let _ = tokio::fs::remove_file(&listen).await;
|
||||
|
||||
let listener = UnixListener::bind(&listen).into_diagnostic()?;
|
||||
let uds = UnixListener::bind(&listen).unwrap();
|
||||
set_permissions(&listen, PermissionsExt::from_mode(0o666)).into_diagnostic()?;
|
||||
let stream = UnixListenerStream::new(listener);
|
||||
let acceptor = from_stream(stream);
|
||||
Server::builder(acceptor)
|
||||
.serve(make_svc)
|
||||
|
||||
axum::serve(uds, app)
|
||||
.with_graceful_shutdown(cancel)
|
||||
.await
|
||||
.into_diagnostic()?;
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
@ -190,30 +195,41 @@ async fn get_remote(
|
|||
}
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
client_ip: IpAddr,
|
||||
req: Request<Body>,
|
||||
instances: Arc<ActiveInstances>,
|
||||
base_address: Arc<String>,
|
||||
) -> Result<Response<Body>, Infallible> {
|
||||
let host = req.headers().get(HOST);
|
||||
let remote = match get_remote(host, &instances, &base_address).await {
|
||||
type Client = hyper_util::client::legacy::Client<HttpConnector, Body>;
|
||||
|
||||
async fn handler(State(state): State<AppState>, mut req: Request) -> Result<Response, StatusCode> {
|
||||
let host = req.headers().get(HOST).cloned();
|
||||
let remote = match get_remote(host.as_ref(), &state.instances, &state.base_address).await {
|
||||
Ok(remote) => remote,
|
||||
Err(e) => {
|
||||
return Ok(Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
return Ok(hyper::Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body(e.into())
|
||||
.unwrap())
|
||||
}
|
||||
};
|
||||
|
||||
let forward = format!("http://{}", remote);
|
||||
let client = hyper::Client::builder().build_http();
|
||||
match hyper_reverse_proxy::call(client_ip, &forward, req, &client).await {
|
||||
Ok(response) => Ok(response),
|
||||
Err(_error) => Ok(Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(Body::empty())
|
||||
.unwrap()),
|
||||
let uri = format!("http://{remote}");
|
||||
debug!(target = uri, "proxying request");
|
||||
|
||||
// fix weird duplicate host header
|
||||
req.headers_mut().remove(HOST);
|
||||
if let Some(host) = host {
|
||||
req.headers_mut().insert(HOST, host.clone());
|
||||
}
|
||||
|
||||
match hyper_reverse_proxy::call(
|
||||
IpAddr::V4(Ipv4Addr::UNSPECIFIED),
|
||||
&uri,
|
||||
req,
|
||||
&state.proxy_client,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(response) => Ok(response.map(|body| Body::new(body))),
|
||||
Err(error) => {
|
||||
error!(%error, "error while proxying request");
|
||||
Ok(StatusCode::BAD_REQUEST.into_response())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue