Skip to content

Commit

Permalink
Introduce tcp_ca example
Browse files Browse the repository at this point in the history
Introduce the tcp_ca example, showing how to use BPF struct_ops to
inject a dummy TCP congestion algorithm into the kernel and use it
during a loopback based data exchange.

Signed-off-by: Daniel Müller <[email protected]>
  • Loading branch information
d-e-s-o authored and danielocfb committed Mar 15, 2024
1 parent 2b10652 commit b2c4c02
Show file tree
Hide file tree
Showing 10 changed files with 229 additions and 1 deletion.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
members = [
"libbpf-cargo",
"libbpf-rs",
"examples/runqslower",
"examples/bpf_query",
"examples/capable",
"examples/runqslower",
"examples/tc_port_whitelist",
"examples/tcp_ca",
"examples/tproxy",
]
resolver = "2"
14 changes: 14 additions & 0 deletions examples/tcp_ca/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "tcp_ca"
version = "0.0.0"
authors = ["Daniel Müller <[email protected]>"]
license = "LGPL-2.1-only OR BSD-2-Clause"
edition = "2021"

[build-dependencies]
libbpf-cargo = { path = "../../libbpf-cargo" }

[dependencies]
clap = { version = "4.0.32", default-features = false, features = ["std", "derive", "help", "usage"] }
libbpf-rs = { path = "../../libbpf-rs" }
libc = "0.2"
1 change: 1 addition & 0 deletions examples/tcp_ca/LICENSE
1 change: 1 addition & 0 deletions examples/tcp_ca/LICENSE.BSD-2-Clause
1 change: 1 addition & 0 deletions examples/tcp_ca/LICENSE.LGPL-2.1
2 changes: 2 additions & 0 deletions examples/tcp_ca/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
A simple example using the BPF `struct_ops` facility to create and use a dummy
TCP congestion algorithm.
27 changes: 27 additions & 0 deletions examples/tcp_ca/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use std::env;
use std::ffi::OsStr;
use std::path::Path;
use std::path::PathBuf;

use libbpf_cargo::SkeletonBuilder;

const SRC: &str = "src/bpf/tcp_ca.bpf.c";

fn main() {
let mut out =
PathBuf::from(env::var_os("OUT_DIR").expect("OUT_DIR must be set in build script"));
out.push("tcp_ca.skel.rs");

let arch = env::var("CARGO_CFG_TARGET_ARCH")
.expect("CARGO_CFG_TARGET_ARCH must be set in build script");

SkeletonBuilder::new()
.source(SRC)
.clang_args([
OsStr::new("-I"),
Path::new("../vmlinux").join(arch).as_os_str(),
])
.build_and_generate(&out)
.unwrap();
println!("cargo:rerun-if-changed={SRC}");
}
48 changes: 48 additions & 0 deletions examples/tcp_ca/src/bpf/tcp_ca.bpf.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// SPDX-License-Identifier: GPL-2.0

#include "vmlinux.h"

#include <bpf/bpf_helpers.h>
#include <bpf/bpf_tracing.h>

char _license[] SEC("license") = "GPL";

int ca_cnt = 0;

static inline struct tcp_sock *tcp_sk(const struct sock *sk)
{
return (struct tcp_sock *)sk;
}

SEC("struct_ops/ca_update_init")
void BPF_PROG(ca_update_init, struct sock *sk)
{
ca_cnt++;
}

SEC("struct_ops/ca_update_cong_control")
void BPF_PROG(ca_update_cong_control, struct sock *sk,
const struct rate_sample *rs)
{
}

SEC("struct_ops/ca_update_ssthresh")
__u32 BPF_PROG(ca_update_ssthresh, struct sock *sk)
{
return tcp_sk(sk)->snd_ssthresh;
}

SEC("struct_ops/ca_update_undo_cwnd")
__u32 BPF_PROG(ca_update_undo_cwnd, struct sock *sk)
{
return tcp_sk(sk)->snd_cwnd;
}

SEC(".struct_ops")
struct tcp_congestion_ops ca_update = {
.init = (void *)ca_update_init,
.cong_control = (void *)ca_update_cong_control,
.ssthresh = (void *)ca_update_ssthresh,
.undo_cwnd = (void *)ca_update_undo_cwnd,
.name = "tcp_ca_update",
};
123 changes: 123 additions & 0 deletions examples/tcp_ca/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// SPDX-License-Identifier: (LGPL-2.1 OR BSD-2-Clause)

#![allow(clippy::let_unit_value)]

use std::ffi::c_int;
use std::ffi::c_void;
use std::io;
use std::io::Read as _;
use std::io::Write as _;
use std::net::TcpListener;
use std::net::TcpStream;
use std::os::fd::AsFd as _;
use std::os::fd::AsRawFd as _;
use std::os::fd::BorrowedFd;
use std::thread;

use clap::Parser;

use libc::setsockopt;
use libc::IPPROTO_TCP;
use libc::TCP_CONGESTION;

use libbpf_rs::skel::OpenSkel;
use libbpf_rs::skel::SkelBuilder;
use libbpf_rs::ErrorExt as _;
use libbpf_rs::Result;

use crate::tcp_ca::TcpCaSkelBuilder;

mod tcp_ca {
include!(concat!(env!("OUT_DIR"), "/tcp_ca.skel.rs"));
}

const TCP_CA_UPDATE: &[u8] = b"tcp_ca_update\0";

/// An example program adding a TCP congestion algorithm.
#[derive(Debug, Parser)]
struct Args {
/// Verbose debug output
#[arg(short, long)]
verbose: bool,
}

fn set_sock_opt(
fd: BorrowedFd<'_>,
level: c_int,
name: c_int,
value: *const c_void,
opt_len: usize,
) -> Result<()> {
let rc = unsafe { setsockopt(fd.as_raw_fd(), level, name, value, opt_len as _) };
if rc == 0 {
Ok(())
} else {
Err(io::Error::last_os_error().into())
}
}

/// Set the `tcp_ca_update` congestion algorithm on the socket represented by
/// the provided file descriptor.
fn set_tcp_ca(fd: BorrowedFd<'_>) -> Result<()> {
let () = set_sock_opt(
fd,
IPPROTO_TCP,
TCP_CONGESTION,
TCP_CA_UPDATE.as_ptr().cast(),
(TCP_CA_UPDATE.len() - 1) as _,
)
.context("failed to set TCP_CONGESTION")?;
Ok(())
}

/// Send and receive a bunch of data over TCP sockets using the `tcp_ca_update`
/// congestion algorithm.
fn send_recv() -> Result<()> {
let num_bytes = 8 * 1024 * 1024;
let listener = TcpListener::bind("[::1]:0")?;
let () = set_tcp_ca(listener.as_fd())?;
let addr = listener.local_addr()?;

let send_handle = thread::spawn(move || {
let (mut stream, _addr) = listener.accept().unwrap();
let to_send = (0..num_bytes).map(|_| b'x').collect::<Vec<u8>>();
let () = stream.write_all(&to_send).unwrap();
});

let mut received = Vec::new();
let mut stream = TcpStream::connect(addr)?;
let () = set_tcp_ca(stream.as_fd())?;
let _count = stream.read_to_end(&mut received)?;
let () = send_handle.join().unwrap();

assert_eq!(received.len(), num_bytes);
Ok(())
}

fn main() -> Result<()> {
let args = Args::parse();

let mut skel_builder = TcpCaSkelBuilder::default();
if args.verbose {
skel_builder.obj_builder.debug(true);
}

let open_skel = skel_builder.open()?;
let mut skel = open_skel.load()?;
let mut maps = skel.maps_mut();
let map = maps.ca_update();
let _link = map.attach_struct_ops()?;

println!("Registered `tcp_ca_update` congestion algorithm; using it for loopback based data exchange...");

assert_eq!(skel.bss().ca_cnt, 0);

// Use our registered TCP congestion algorithm while sending a bunch of data
// over the loopback device.
let () = send_recv()?;
println!("Done.");

let saved_ca1_cnt = skel.bss().ca_cnt;
assert_ne!(saved_ca1_cnt, 0);
Ok(())
}

0 comments on commit b2c4c02

Please sign in to comment.