Skip to content

Commit 50318a6

Browse files
committed
add server::congestion_control config setting
1 parent 76a3a78 commit 50318a6

File tree

5 files changed

+73
-5
lines changed

5 files changed

+73
-5
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ regex = "1.5.4"
6464
rustls-pemfile = "1.0.0"
6565
serde = { version = "1.0.125", features = ["derive"] }
6666
serde_json = "1.0.64"
67-
socket2 = "0.4.0"
67+
socket2 = { version = "0.5.5", features = ["all"] }
6868
time = "0.1.42"
6969
tls-listener = { version = "0.5.1", features = [ "hyper-h1", "hyper-h2", "rustls" ] }
7070
tokio = { version = "1.5.0", features = ["full"] }

src/config.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ pub struct Server {
4545
pub identification: Option<String>,
4646
#[serde(default)]
4747
pub cors: bool,
48+
#[serde(default)]
49+
pub congestion_control: Option<String>,
4850
}
4951

5052
#[derive(Deserialize, Debug, Clone, Default)]

src/main.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ mod rootfs;
2121
#[doc(hidden)]
2222
pub mod router;
2323
mod suid;
24+
mod tcp_cong;
2425
mod tls;
2526
mod unixuser;
2627
mod userfs;
@@ -458,7 +459,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
458459

459460
// Plaintext servers.
460461
for sockaddr in addrs {
461-
let listener = match make_listener(sockaddr) {
462+
let listener = match make_listener(&config.server, sockaddr) {
462463
Ok(l) => l,
463464
Err(e) => {
464465
eprintln!("{}: listener on {:?}: {}", PROGNAME, &sockaddr, e);
@@ -495,7 +496,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
495496

496497
for sockaddr in tls_addrs {
497498
let tls_acceptor = tls_acceptor.clone();
498-
let listener = make_listener(sockaddr).unwrap_or_else(|e| {
499+
let listener = make_listener(&config.server, sockaddr).unwrap_or_else(|e| {
499500
eprintln!("{}: listener on {:?}: {}", PROGNAME, &sockaddr, e);
500501
exit(1);
501502
});
@@ -624,9 +625,17 @@ fn expand_directory(dir: &str, pwd: Option<&Arc<unixuser::User>>) -> Result<Stri
624625

625626
// Make a new TcpListener, and if it's a V6 listener, set the
626627
// V6_V6ONLY socket option on it.
627-
fn make_listener(addr: SocketAddr) -> io::Result<tokio::net::TcpListener> {
628+
fn make_listener(srv: &config::Server, addr: SocketAddr) -> io::Result<tokio::net::TcpListener> {
628629
use socket2::{Domain, SockAddr, Socket, Type, Protocol};
629630
let s = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?;
631+
if let Some(cong) = srv.congestion_control.as_ref() {
632+
tcp_cong::set_congestion_control(&s, cong).map_err(|e| {
633+
io::Error::new(
634+
io::ErrorKind::InvalidData,
635+
format!("congestion control {}: {}", cong, e),
636+
)
637+
})?;
638+
}
630639
if addr.is_ipv6() {
631640
s.set_only_v6(true)?;
632641
}

src/tcp_cong.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
use std::io;
2+
use socket2::Socket;
3+
4+
#[cfg(not(any(target_os = "freebsd", target_os = "linux")))]
5+
pub fn set_congestion_control(_sock: &Socket, _algo: &str) -> io::Result<()> {
6+
return Err(io::Error::new(io::ErrorKind::InvalidInput, "not implemented"));
7+
}
8+
9+
#[cfg(target_os = "linux")]
10+
pub fn set_congestion_control(sock: &Socket, algo: &str) -> io::Result<()> {
11+
sock.set_tcp_congestion(algo.as_bytes())
12+
}
13+
14+
#[cfg(target_os = "freebsd")]
15+
pub fn set_congestion_control(sock: &Socket, algo: &str) -> io::Result<()> {
16+
match algo {
17+
"bbr" => set_tcp_functions(sock, algo)?,
18+
"rack" => set_tcp_functions(sock, algo)?,
19+
_ => {
20+
set_tcp_functions(sock, "freebsd")?;
21+
sock.set_tcp_congestion(algo.as_bytes())?;
22+
},
23+
}
24+
Ok(())
25+
}
26+
27+
#[cfg(target_os = "freebsd")]
28+
fn set_tcp_functions(sock: &Socket, funcs: &str) -> io::Result<()> {
29+
const TCP_FUNCTION_BLK: libc::c_int = 8192;
30+
let slen = funcs.len();
31+
if slen >= libc::TCP_FUNCTION_NAME_LEN_MAX as usize {
32+
return Err(io::Error::new(io::ErrorKind::InvalidInput, "name too long"));
33+
}
34+
let mut function_set_name = [0 as libc::c_char; libc::TCP_FUNCTION_NAME_LEN_MAX as usize];
35+
let bytes = funcs.as_bytes();
36+
for idx in 0..slen {
37+
function_set_name[idx] = bytes[idx] as libc::c_char;
38+
}
39+
let fsn = libc::tcp_function_set {
40+
function_set_name,
41+
pcbcnt: 0,
42+
};
43+
use std::os::fd::AsRawFd;
44+
let res = unsafe {
45+
libc::setsockopt(
46+
sock.as_raw_fd(),
47+
libc::IPPROTO_TCP,
48+
TCP_FUNCTION_BLK,
49+
&fsn as *const libc::tcp_function_set as *const libc::c_void,
50+
std::mem::size_of_val(&fsn) as libc::socklen_t,
51+
)
52+
};
53+
if res == 0 {
54+
return Ok(());
55+
}
56+
Err(io::Error::last_os_error())
57+
}

0 commit comments

Comments
 (0)