minor cleanup

This commit is contained in:
Robin Appelman 2021-07-29 16:45:38 +02:00
commit 3f6837c87c
2 changed files with 68 additions and 25 deletions

View file

@ -1,4 +1,5 @@
use rfc7239::{parse, Forwarded, NodeIdentifier, NodeName}; use rfc7239::{parse, Forwarded, NodeIdentifier, NodeName};
use std::borrow::Cow;
use std::convert::Infallible; use std::convert::Infallible;
use std::iter::once; use std::iter::once;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
@ -11,7 +12,7 @@ use warp::Filter;
/// This uses the "x-forwarded-for" or "x-real-ip" headers set by reverse proxies. /// This uses the "x-forwarded-for" or "x-real-ip" headers set by reverse proxies.
/// To stop clients from abusing these headers, only headers set by trusted remotes will be accepted. /// To stop clients from abusing these headers, only headers set by trusted remotes will be accepted.
/// ///
/// Note that if multiple forwarded-for addresses are present, wich can be the case when using nested reverse proxies, /// Note that if multiple forwarded-for addresses are present, which can be the case when using nested reverse proxies,
/// all proxies in the chain have to be within the list of trusted proxies. /// all proxies in the chain have to be within the list of trusted proxies.
/// ///
/// ## Example /// ## Example
@ -50,10 +51,10 @@ pub fn real_ip(
pub fn get_forwarded_for() -> impl Filter<Extract = (Vec<IpAddr>,), Error = Infallible> + Clone { pub fn get_forwarded_for() -> impl Filter<Extract = (Vec<IpAddr>,), Error = Infallible> + Clone {
warp::header("x-forwarded-for") warp::header("x-forwarded-for")
.map(|list: CommaSeparated<IpAddr>| list.into_inner()) .map(|list: CommaSeparated<IpAddr>| list.into_inner())
.or(warp::header("x-real-ip").map( .or(warp::header("x-real-ip").map(|ip: String| {
|ip: String| IpAddr::from_str(maybe_bracketed(&maybe_quoted(ip))) IpAddr::from_str(maybe_bracketed(&maybe_quoted(&ip)))
.map_or_else(|_| Vec::<IpAddr>::new(), |x| vec![x]) .map_or_else(|_| Vec::<IpAddr>::new(), |x| vec![x])
)) }))
.unify() .unify()
.or(warp::header("forwarded").map(|header: String| { .or(warp::header("forwarded").map(|header: String| {
parse(&header) parse(&header)
@ -80,7 +81,7 @@ enum CommaSeparatedIteratorState {
Quoted, Quoted,
QuotedPair, QuotedPair,
Token, Token,
PostambleForQuoted, PostAmbleForQuoted,
} }
struct CommaSeparatedIterator<'a> { struct CommaSeparatedIterator<'a> {
@ -97,7 +98,7 @@ struct CommaSeparatedIterator<'a> {
impl<'a> CommaSeparatedIterator<'a> { impl<'a> CommaSeparatedIterator<'a> {
pub fn new(target: &'a str) -> Self { pub fn new(target: &'a str) -> Self {
Self { Self {
target: target, target,
char_indices: target.char_indices(), char_indices: target.char_indices(),
state: CommaSeparatedIteratorState::Default, state: CommaSeparatedIteratorState::Default,
s: 0, s: 0,
@ -130,7 +131,7 @@ impl<'a> Iterator for CommaSeparatedIterator<'a> {
CommaSeparatedIteratorState::Quoted => match c { CommaSeparatedIteratorState::Quoted => match c {
'"' => ( '"' => (
Some(Some(&self.target[self.s..i + 1])), Some(Some(&self.target[self.s..i + 1])),
CommaSeparatedIteratorState::PostambleForQuoted, CommaSeparatedIteratorState::PostAmbleForQuoted,
), ),
'\\' => (None, CommaSeparatedIteratorState::QuotedPair), '\\' => (None, CommaSeparatedIteratorState::QuotedPair),
_ => (None, CommaSeparatedIteratorState::Quoted), _ => (None, CommaSeparatedIteratorState::Quoted),
@ -145,9 +146,9 @@ impl<'a> Iterator for CommaSeparatedIterator<'a> {
), ),
_ => (None, CommaSeparatedIteratorState::Token), _ => (None, CommaSeparatedIteratorState::Token),
}, },
CommaSeparatedIteratorState::PostambleForQuoted => match c { CommaSeparatedIteratorState::PostAmbleForQuoted => match c {
',' => (None, CommaSeparatedIteratorState::Default), ',' => (None, CommaSeparatedIteratorState::Default),
_ => (None, CommaSeparatedIteratorState::PostambleForQuoted), _ => (None, CommaSeparatedIteratorState::PostAmbleForQuoted),
}, },
} { } {
(Some(next), next_state) => { (Some(next), next_state) => {
@ -163,7 +164,7 @@ impl<'a> Iterator for CommaSeparatedIterator<'a> {
} }
return match self.state { return match self.state {
CommaSeparatedIteratorState::Default CommaSeparatedIteratorState::Default
| CommaSeparatedIteratorState::PostambleForQuoted => None, | CommaSeparatedIteratorState::PostAmbleForQuoted => None,
CommaSeparatedIteratorState::Quoted | CommaSeparatedIteratorState::QuotedPair => { CommaSeparatedIteratorState::Quoted | CommaSeparatedIteratorState::QuotedPair => {
self.state = CommaSeparatedIteratorState::Default; self.state = CommaSeparatedIteratorState::Default;
Some(&self.target[self.s..]) Some(&self.target[self.s..])
@ -176,35 +177,39 @@ impl<'a> Iterator for CommaSeparatedIterator<'a> {
} }
} }
pub fn maybe_quoted<T: AsRef<str>>(x: T) -> String { enum EscapeState {
let x = x.as_ref(); Normal,
Escaped,
}
fn maybe_quoted(x: &str) -> Cow<str> {
let mut i = x.chars(); let mut i = x.chars();
if i.next() == Some('"') { if i.next() == Some('"') {
let mut s = String::with_capacity(x.len()); let mut s = String::with_capacity(x.len());
let mut state = 0; let mut state = EscapeState::Normal;
for c in i { for c in i {
state = match state { state = match state {
0 => match c { EscapeState::Normal => match c {
'"' => break, '"' => break,
'\\' => 1, '\\' => EscapeState::Escaped,
_ => { _ => {
s.push(c); s.push(c);
0 EscapeState::Normal
} }
}, },
_ => { EscapeState::Escaped => {
s.push(c); s.push(c);
0 EscapeState::Normal
} }
}; };
} }
s s.into()
} else { } else {
x.to_string() x.into()
} }
} }
pub fn maybe_bracketed<'a>(x: &'a str) -> &'a str { fn maybe_bracketed(x: &str) -> &str {
if x.as_bytes()[0] == ('[' as u8) && x.as_bytes()[x.len() - 1] == (']' as u8) { if x.as_bytes()[0] == ('[' as u8) && x.as_bytes()[x.len() - 1] == (']' as u8) {
&x[1..x.len() - 1] &x[1..x.len() - 1]
} else { } else {
@ -234,12 +239,29 @@ impl<T: FromStr> FromStr for CommaSeparated<T> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{CommaSeparatedIterator, maybe_quoted, maybe_bracketed}; use crate::{maybe_bracketed, maybe_quoted, CommaSeparatedIterator};
#[test] #[test]
fn test_comma_separated_iterator() { fn test_comma_separated_iterator() {
assert_eq!(vec!["abc", "def", "ghi", "jkl ", "mno", "pqr"], CommaSeparatedIterator::new("abc,def, ghi,\tjkl , mno,\tpqr").collect::<Vec<&str>>()); assert_eq!(
assert_eq!(vec!["abc", "\"def\"", "\"ghi\"", "\"jkl\"", "\"mno\"", "pqr"], CommaSeparatedIterator::new("abc,\"def\", \"ghi\",\t\"jkl\" , \"mno\",\tpqr").collect::<Vec<&str>>()); vec!["abc", "def", "ghi", "jkl ", "mno", "pqr"],
CommaSeparatedIterator::new("abc,def, ghi,\tjkl , mno,\tpqr").collect::<Vec<&str>>()
);
assert_eq!(
vec![
"abc",
"\"def\"",
"\"ghi\"",
"\"jkl\"",
"\"mno\"",
"pqr",
"\"abc, def\""
],
CommaSeparatedIterator::new(
"abc,\"def\", \"ghi\",\t\"jkl\" , \"mno\",\tpqr, \"abc, def\""
)
.collect::<Vec<&str>>()
);
} }
#[test] #[test]
@ -256,5 +278,4 @@ mod tests {
assert_eq!("[abc", maybe_bracketed("[abc")); assert_eq!("[abc", maybe_bracketed("[abc"));
assert_eq!("abc]", maybe_bracketed("abc]")); assert_eq!("abc]", maybe_bracketed("abc]"));
} }
} }

View file

@ -83,3 +83,25 @@ async fn test_trusted_forwarded_no_for() {
.await; .await;
assert_eq!(res.body(), "1.2.3.4"); assert_eq!(res.body(), "1.2.3.4");
} }
#[tokio::test]
async fn test_quoted() {
let remote: IpAddr = [1, 2, 3, 4].into();
let res = warp::test::request()
.remote_addr((remote, 80).into())
.header("x-forwarded-for", "\"10.10.10.10\"")
.reply(&serve(vec![remote]))
.await;
assert_eq!(res.body(), "10.10.10.10");
}
#[tokio::test]
async fn test_nested_quoted() {
let remote: IpAddr = [1, 2, 3, 4].into();
let res = warp::test::request()
.remote_addr((remote, 80).into())
.header("x-forwarded-for", "\"10.10.10.10\", 11.11.11.11")
.reply(&serve(vec![remote, [10, 10, 10, 10].into()]))
.await;
assert_eq!(res.body(), "11.11.11.11");
}