diff --git a/.github/workflows/docker-capture.yml b/.github/workflows/docker-build.yml similarity index 77% rename from .github/workflows/docker-capture.yml rename to .github/workflows/docker-build.yml index d0efac3..3132a61 100644 --- a/.github/workflows/docker-capture.yml +++ b/.github/workflows/docker-build.yml @@ -1,4 +1,4 @@ -name: Build capture docker image +name: Build container images on: workflow_dispatch: @@ -6,12 +6,16 @@ on: branches: - "main" -permissions: - packages: write - jobs: build: - name: build and publish capture image + name: Build and publish container image + strategy: + matrix: + image: + - capture + - hook-api + - hook-janitor + - hook-worker runs-on: depot-ubuntu-22.04-4 permissions: id-token: write # allow issuing OIDC tokens for this workflow run @@ -45,7 +49,7 @@ jobs: id: meta uses: docker/metadata-action@v4 with: - images: ghcr.io/posthog/hog-rs/capture + images: ghcr.io/posthog/hog-rs/${{ matrix.image }} tags: | type=ref,event=pr type=ref,event=branch @@ -57,8 +61,8 @@ jobs: id: buildx uses: docker/setup-buildx-action@v2 - - name: Build and push capture - id: docker_build_capture + - name: Build and push image + id: docker_build uses: depot/build-push-action@v1 with: context: ./ @@ -69,7 +73,7 @@ jobs: platforms: linux/arm64 cache-from: type=gha cache-to: type=gha,mode=max - build-args: BIN=capture-server + build-args: BIN=${{ matrix.image }} - - name: Capture image digest - run: echo ${{ steps.docker_build_capture.outputs.digest }} + - name: Container image digest + run: echo ${{ steps.docker_build.outputs.digest }} diff --git a/.github/workflows/docker-hook-api.yml b/.github/workflows/docker-hook-api.yml deleted file mode 100644 index e6f83f1..0000000 --- a/.github/workflows/docker-hook-api.yml +++ /dev/null @@ -1,75 +0,0 @@ -name: Build hook-api docker image - -on: - workflow_dispatch: - push: - branches: - - "main" - -permissions: - packages: write - -jobs: - build: - name: build and publish hook-api image - runs-on: depot-ubuntu-22.04-4 - permissions: - id-token: write # allow issuing OIDC tokens for this workflow run - contents: read # allow reading the repo contents - packages: write # allow push to ghcr.io - - steps: - - name: Check Out Repo - uses: actions/checkout@v3 - - - name: Set up Depot CLI - uses: depot/setup-action@v1 - - - name: Login to DockerHub - uses: docker/login-action@v2 - with: - username: posthog - password: ${{ secrets.DOCKERHUB_TOKEN }} - - - name: Login to ghcr.io - uses: docker/login-action@v2 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - - name: Set up QEMU - uses: docker/setup-qemu-action@v2 - - - name: Docker meta - id: meta - uses: docker/metadata-action@v4 - with: - images: ghcr.io/posthog/hog-rs/hook-api - tags: | - type=ref,event=pr - type=ref,event=branch - type=semver,pattern={{version}} - type=semver,pattern={{major}}.{{minor}} - type=sha - - - name: Set up Docker Buildx - id: buildx - uses: docker/setup-buildx-action@v2 - - - name: Build and push api - id: docker_build_hook_api - uses: depot/build-push-action@v1 - with: - context: ./ - file: ./Dockerfile - push: true - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} - platforms: linux/arm64 - cache-from: type=gha - cache-to: type=gha,mode=max - build-args: BIN=hook-api - - - name: Hook-api image digest - run: echo ${{ steps.docker_build_hook_api.outputs.digest }} diff --git a/.github/workflows/docker-hook-janitor.yml b/.github/workflows/docker-hook-janitor.yml deleted file 
mode 100644 index 33706d0..0000000 --- a/.github/workflows/docker-hook-janitor.yml +++ /dev/null @@ -1,75 +0,0 @@ -name: Build hook-janitor docker image - -on: - workflow_dispatch: - push: - branches: - - "main" - -permissions: - packages: write - -jobs: - build: - name: build and publish hook-janitor image - runs-on: depot-ubuntu-22.04-4 - permissions: - id-token: write # allow issuing OIDC tokens for this workflow run - contents: read # allow reading the repo contents - packages: write # allow push to ghcr.io - - steps: - - name: Check Out Repo - uses: actions/checkout@v3 - - - name: Set up Depot CLI - uses: depot/setup-action@v1 - - - name: Login to DockerHub - uses: docker/login-action@v2 - with: - username: posthog - password: ${{ secrets.DOCKERHUB_TOKEN }} - - - name: Login to ghcr.io - uses: docker/login-action@v2 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - - name: Set up QEMU - uses: docker/setup-qemu-action@v2 - - - name: Docker meta - id: meta - uses: docker/metadata-action@v4 - with: - images: ghcr.io/posthog/hog-rs/hook-janitor - tags: | - type=ref,event=pr - type=ref,event=branch - type=semver,pattern={{version}} - type=semver,pattern={{major}}.{{minor}} - type=sha - - - name: Set up Docker Buildx - id: buildx - uses: docker/setup-buildx-action@v2 - - - name: Build and push janitor - id: docker_build_hook_janitor - uses: depot/build-push-action@v1 - with: - context: ./ - file: ./Dockerfile - push: true - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} - platforms: linux/arm64 - cache-from: type=gha - cache-to: type=gha,mode=max - build-args: BIN=hook-janitor - - - name: Hook-janitor image digest - run: echo ${{ steps.docker_build_hook_janitor.outputs.digest }} diff --git a/.github/workflows/docker-hook-worker.yml b/.github/workflows/docker-hook-worker.yml deleted file mode 100644 index dc5ca53..0000000 --- a/.github/workflows/docker-hook-worker.yml +++ /dev/null @@ -1,75 +0,0 @@ -name: Build hook-worker docker image - -on: - workflow_dispatch: - push: - branches: - - "main" - -permissions: - packages: write - -jobs: - build: - name: build and publish hook-worker image - runs-on: depot-ubuntu-22.04-4 - permissions: - id-token: write # allow issuing OIDC tokens for this workflow run - contents: read # allow reading the repo contents - packages: write # allow push to ghcr.io - - steps: - - name: Check Out Repo - uses: actions/checkout@v3 - - - name: Set up Depot CLI - uses: depot/setup-action@v1 - - - name: Login to DockerHub - uses: docker/login-action@v2 - with: - username: posthog - password: ${{ secrets.DOCKERHUB_TOKEN }} - - - name: Login to ghcr.io - uses: docker/login-action@v2 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - - name: Set up QEMU - uses: docker/setup-qemu-action@v2 - - - name: Docker meta - id: meta - uses: docker/metadata-action@v4 - with: - images: ghcr.io/posthog/hog-rs/hook-worker - tags: | - type=ref,event=pr - type=ref,event=branch - type=semver,pattern={{version}} - type=semver,pattern={{major}}.{{minor}} - type=sha - - - name: Set up Docker Buildx - id: buildx - uses: docker/setup-buildx-action@v2 - - - name: Build and push worker - id: docker_build_hook_worker - uses: depot/build-push-action@v1 - with: - context: ./ - file: ./Dockerfile - push: true - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} - platforms: linux/arm64 - cache-from: type=gha - cache-to: 
type=gha,mode=max - build-args: BIN=hook-worker - - - name: Hook-worker image digest - run: echo ${{ steps.docker_build_hook_worker.outputs.digest }} diff --git a/Cargo.lock b/Cargo.lock index 600a9e3..0f475fa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -370,13 +370,19 @@ dependencies = [ "bytes", "envconfig", "flate2", + "futures", "governor", "health", "metrics", "metrics-exporter-prometheus", + "once_cell", + "opentelemetry", + "opentelemetry-otlp", + "opentelemetry_sdk", "rand", "rdkafka", "redis", + "reqwest 0.12.3", "serde", "serde_json", "serde_urlencoded", @@ -385,30 +391,9 @@ dependencies = [ "tokio", "tower-http", "tracing", - "uuid", -] - -[[package]] -name = "capture-server" -version = "0.1.0" -dependencies = [ - "anyhow", - "assert-json-diff", - "capture", - "envconfig", - "futures", - "once_cell", - "opentelemetry", - "opentelemetry-otlp", - "opentelemetry_sdk", - "rand", - "rdkafka", - "reqwest 0.12.3", - "serde_json", - "tokio", - "tracing", "tracing-opentelemetry", "tracing-subscriber", + "uuid", ] [[package]] @@ -1099,6 +1084,7 @@ dependencies = [ "sqlx", "tokio", "tower", + "tower-http", "tracing", "tracing-subscriber", "url", diff --git a/Cargo.toml b/Cargo.toml index f1a4ff9..ea5d041 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,6 @@ resolver = "2" members = [ "capture", - "capture-server", "common/health", "feature-flags", "hook-api", @@ -14,7 +13,7 @@ members = [ [workspace.lints.rust] # See https://doc.rust-lang.org/stable/rustc/lints/listing/allowed-by-default.html -unsafe_code = "forbid" # forbid cannot be ignored with an annotation +unsafe_code = "forbid" # forbid cannot be ignored with an annotation unstable_features = "forbid" macro_use_extern_crate = "forbid" let_underscore_drop = "deny" @@ -45,6 +44,10 @@ http = { version = "1.1.0" } http-body-util = "0.1.0" metrics = "0.22.0" metrics-exporter-prometheus = "0.14.0" +once_cell = "1.18.0" +opentelemetry = { version = "0.22.0", features = ["trace"]} +opentelemetry-otlp = "0.15.0" +opentelemetry_sdk = { version = "0.22.1", features = ["trace", "rt-tokio"] } rand = "0.8.5" rdkafka = { version = "0.36.0", features = ["cmake-build", "ssl", "tracing"] } reqwest = { version = "0.12.3", features = ["json", "stream"] } @@ -70,8 +73,9 @@ time = { version = "0.3.20", features = [ thiserror = { version = "1.0" } tokio = { version = "1.34.0", features = ["full"] } tower = "0.4.13" -tower-http = { version = "0.5.2", features = ["cors", "trace"] } +tower-http = { version = "0.5.2", features = ["cors", "limit", "trace"] } tracing = "0.1.40" -tracing-subscriber = "0.3.18" +tracing-opentelemetry = "0.23.0" +tracing-subscriber = { version="0.3.18", features = ["env-filter"] } url = { version = "2.5.0 " } uuid = { version = "1.6.1", features = ["v7", "serde"] } diff --git a/capture-server/Cargo.toml b/capture-server/Cargo.toml deleted file mode 100644 index 39ee742..0000000 --- a/capture-server/Cargo.toml +++ /dev/null @@ -1,28 +0,0 @@ -[package] -name = "capture-server" -version = "0.1.0" -edition = "2021" - -[lints] -workspace = true - -[dependencies] -capture = { path = "../capture" } -envconfig = { workspace = true } -opentelemetry = { version = "0.22.0", features = ["trace"]} -opentelemetry-otlp = "0.15.0" -opentelemetry_sdk = { version = "0.22.1", features = ["trace", "rt-tokio"] } -tokio = { workspace = true } -tracing = { workspace = true } -tracing-opentelemetry = "0.23.0" -tracing-subscriber = { workspace = true, features = ["env-filter"] } - -[dev-dependencies] -anyhow = { workspace = true, features = [] } 
-assert-json-diff = { workspace = true } -futures = "0.3.29" -once_cell = "1.18.0" -rand = { workspace = true } -rdkafka = { workspace = true } -reqwest = { workspace = true } -serde_json = { workspace = true } diff --git a/capture/Cargo.toml b/capture/Cargo.toml index ae5ad9a..97d310f 100644 --- a/capture/Cargo.toml +++ b/capture/Cargo.toml @@ -19,6 +19,9 @@ governor = { workspace = true } health = { path = "../common/health" } metrics = { workspace = true } metrics-exporter-prometheus = { workspace = true } +opentelemetry = { workspace = true } +opentelemetry-otlp = { workspace = true } +opentelemetry_sdk = { workspace = true } rand = { workspace = true } rdkafka = { workspace = true } redis = { version = "0.23.3", features = [ @@ -34,8 +37,17 @@ time = { workspace = true } tokio = { workspace = true } tower-http = { workspace = true } tracing = { workspace = true } +tracing-opentelemetry = { workspace = true } +tracing-subscriber = { workspace = true } uuid = { workspace = true } [dev-dependencies] assert-json-diff = { workspace = true } axum-test-helper = { git = "https://github.com/posthog/axum-test-helper.git" } # TODO: remove, directly use reqwest like capture-server tests +anyhow = { workspace = true } +futures = { workspace = true } +once_cell = { workspace = true } +rand = { workspace = true } +rdkafka = { workspace = true } +reqwest = { workspace = true } +serde_json = { workspace = true } diff --git a/capture-server/src/main.rs b/capture/src/main.rs similarity index 100% rename from capture-server/src/main.rs rename to capture/src/main.rs diff --git a/capture-server/tests/common.rs b/capture/tests/common.rs similarity index 100% rename from capture-server/tests/common.rs rename to capture/tests/common.rs diff --git a/capture-server/tests/events.rs b/capture/tests/events.rs similarity index 100% rename from capture-server/tests/events.rs rename to capture/tests/events.rs diff --git a/hook-api/Cargo.toml b/hook-api/Cargo.toml index c3528d2..eb82438 100644 --- a/hook-api/Cargo.toml +++ b/hook-api/Cargo.toml @@ -19,6 +19,7 @@ serde_json = { workspace = true } sqlx = { workspace = true } tokio = { workspace = true } tower = { workspace = true } +tower-http = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } url = { workspace = true } diff --git a/hook-api/src/config.rs b/hook-api/src/config.rs index 55fa404..e15f0d3 100644 --- a/hook-api/src/config.rs +++ b/hook-api/src/config.rs @@ -16,6 +16,9 @@ pub struct Config { #[envconfig(default = "100")] pub max_pg_connections: u32, + + #[envconfig(default = "5000000")] + pub max_body_size: usize, } impl Config { diff --git a/hook-api/src/handlers/app.rs b/hook-api/src/handlers/app.rs index f8d4b24..7cbbc44 100644 --- a/hook-api/src/handlers/app.rs +++ b/hook-api/src/handlers/app.rs @@ -1,15 +1,21 @@ use axum::{routing, Router}; +use tower_http::limit::RequestBodyLimitLayer; use hook_common::pgqueue::PgQueue; use super::webhook; -pub fn add_routes(router: Router, pg_pool: PgQueue) -> Router { +pub fn add_routes(router: Router, pg_pool: PgQueue, max_body_size: usize) -> Router { router .route("/", routing::get(index)) .route("/_readiness", routing::get(index)) .route("/_liveness", routing::get(index)) // No async loop for now, just check axum health - .route("/webhook", routing::post(webhook::post).with_state(pg_pool)) + .route( + "/webhook", + routing::post(webhook::post) + .with_state(pg_pool) + .layer(RequestBodyLimitLayer::new(max_body_size)), + ) } pub async fn index() -> &'static str { @@ 
-32,7 +38,7 @@ mod tests { async fn index(db: PgPool) { let pg_queue = PgQueue::new_from_pool("test_index", db).await; - let app = add_routes(Router::new(), pg_queue); + let app = add_routes(Router::new(), pg_queue, 1_000_000); let response = app .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) diff --git a/hook-api/src/handlers/webhook.rs b/hook-api/src/handlers/webhook.rs index e50b8b0..808c948 100644 --- a/hook-api/src/handlers/webhook.rs +++ b/hook-api/src/handlers/webhook.rs @@ -9,8 +9,6 @@ use hook_common::pgqueue::{NewJob, PgQueue}; use serde::Serialize; use tracing::{debug, error}; -const MAX_BODY_SIZE: usize = 1_000_000; - #[derive(Serialize, Deserialize)] pub struct WebhookPostResponse { #[serde(skip_serializing_if = "Option::is_none")] @@ -37,15 +35,6 @@ pub async fn post( ) -> Result<Json<WebhookPostResponse>, (StatusCode, Json<WebhookPostResponse>)> { debug!("received payload: {:?}", payload); - if payload.parameters.body.len() > MAX_BODY_SIZE { - return Err(( - StatusCode::BAD_REQUEST, - Json(WebhookPostResponse { - error: Some("body too large".to_owned()), - }), - )); - } - let url_hostname = get_hostname(&payload.parameters.url)?; // We could cast to i32, but this ensures we are not wrapping. let max_attempts = i32::try_from(payload.max_attempts).map_err(|_| { @@ -125,11 +114,13 @@ mod tests { use crate::handlers::app::add_routes; + const MAX_BODY_SIZE: usize = 1_000_000; + #[sqlx::test(migrations = "../migrations")] async fn webhook_success(db: PgPool) { let pg_queue = PgQueue::new_from_pool("test_index", db).await; - let app = add_routes(Router::new(), pg_queue); + let app = add_routes(Router::new(), pg_queue, MAX_BODY_SIZE); let mut headers = collections::HashMap::new(); headers.insert("Content-Type".to_owned(), "application/json".to_owned()); @@ -171,7 +162,7 @@ mod tests { async fn webhook_bad_url(db: PgPool) { let pg_queue = PgQueue::new_from_pool("test_index", db).await; - let app = add_routes(Router::new(), pg_queue); + let app = add_routes(Router::new(), pg_queue, MAX_BODY_SIZE); let response = app .oneshot( @@ -208,7 +199,7 @@ mod tests { async fn webhook_payload_missing_fields(db: PgPool) { let pg_queue = PgQueue::new_from_pool("test_index", db).await; - let app = add_routes(Router::new(), pg_queue); + let app = add_routes(Router::new(), pg_queue, MAX_BODY_SIZE); let response = app .oneshot( @@ -229,7 +220,7 @@ mod tests { async fn webhook_payload_not_json(db: PgPool) { let pg_queue = PgQueue::new_from_pool("test_index", db).await; - let app = add_routes(Router::new(), pg_queue); + let app = add_routes(Router::new(), pg_queue, MAX_BODY_SIZE); let response = app .oneshot( @@ -250,9 +241,9 @@ mod tests { async fn webhook_payload_body_too_large(db: PgPool) { let pg_queue = PgQueue::new_from_pool("test_index", db).await; - let app = add_routes(Router::new(), pg_queue); + let app = add_routes(Router::new(), pg_queue, MAX_BODY_SIZE); - let bytes: Vec<u8> = vec![b'a'; 1_000_000 * 2]; + let bytes: Vec<u8> = vec![b'a'; MAX_BODY_SIZE + 1]; let long_string = String::from_utf8_lossy(&bytes); let response = app @@ -283,6 +274,6 @@ mod tests { .await .unwrap(); - assert_eq!(response.status(), StatusCode::BAD_REQUEST); + assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE); } }
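Review note: the manual `MAX_BODY_SIZE` check removed above is replaced by tower-http's `RequestBodyLimitLayer`, which answers oversized bodies with `413 Payload Too Large` before the handler ever runs, which is why the final assertion now expects `StatusCode::PAYLOAD_TOO_LARGE`. A minimal sketch of the layer in isolation (standalone router; the `/echo` route and 1 KiB limit are made up for illustration, this is not the hog-rs wiring):

```rust
use axum::{extract::DefaultBodyLimit, routing::post, Router};
use tower_http::limit::RequestBodyLimitLayer;

// Standalone demo router: bodies over 1 KiB are rejected with
// 413 Payload Too Large before the handler executes.
fn demo_app() -> Router {
    Router::new()
        .route("/echo", post(|body: String| async move { body }))
        // Disable axum's built-in limit so the tower-http layer is the only one in play.
        .layer(DefaultBodyLimit::disable())
        .layer(RequestBodyLimitLayer::new(1024))
}
```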
diff --git a/hook-api/src/main.rs b/hook-api/src/main.rs index 9a9a9fd..ad05ede 100644 --- a/hook-api/src/main.rs +++ b/hook-api/src/main.rs @@ -34,7 +34,7 @@ async fn main() { .await .expect("failed to initialize queue"); - let app = handlers::add_routes(Router::new(), pg_queue); + let app = handlers::add_routes(Router::new(), pg_queue, config.max_body_size); let app = setup_metrics_routes(app); match listen(app, config.bind()).await { diff --git a/hook-common/src/kafka_messages/plugin_logs.rs b/hook-common/src/kafka_messages/plugin_logs.rs index 5a852e6..039788a 100644 --- a/hook-common/src/kafka_messages/plugin_logs.rs +++ b/hook-common/src/kafka_messages/plugin_logs.rs @@ -4,7 +4,6 @@ use uuid::Uuid; use super::serialize_datetime; -#[allow(dead_code)] #[derive(Serialize)] pub enum PluginLogEntrySource { System, @@ -12,7 +11,6 @@ pub enum PluginLogEntrySource { Console, } -#[allow(dead_code)] #[derive(Serialize)] pub enum PluginLogEntryType { Debug, diff --git a/hook-worker/src/config.rs b/hook-worker/src/config.rs index ceb690f..51b23b7 100644 --- a/hook-worker/src/config.rs +++ b/hook-worker/src/config.rs @@ -37,6 +37,9 @@ pub struct Config { #[envconfig(default = "1")] pub dequeue_batch_size: u32, + + #[envconfig(default = "false")] + pub allow_internal_ips: bool, } impl Config {
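For readers unfamiliar with envconfig: the derive maps each field to the upper-cased environment variable of the same name and falls back to the `default` string when the variable is unset, so the two settings added in this PR are driven by `MAX_BODY_SIZE` and `ALLOW_INTERNAL_IPS`. A sketch under that assumption (hypothetical `DemoConfig`, using envconfig's `init_from_env`):

```rust
use envconfig::Envconfig;

// Hypothetical mirror of the new fields; envconfig reads the
// upper-cased field name from the environment.
#[derive(Envconfig)]
struct DemoConfig {
    #[envconfig(default = "5000000")]
    pub max_body_size: usize,
    #[envconfig(default = "false")]
    pub allow_internal_ips: bool,
}

fn main() {
    std::env::set_var("ALLOW_INTERNAL_IPS", "true");
    let config = DemoConfig::init_from_env().unwrap();
    assert_eq!(config.max_body_size, 5_000_000); // default kicked in
    assert!(config.allow_internal_ips); // overridden via the environment
}
```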
diff --git a/hook-worker/src/dns.rs b/hook-worker/src/dns.rs index e69de29..36fd7a0 100644 --- a/hook-worker/src/dns.rs +++ b/hook-worker/src/dns.rs @@ -0,0 +1,140 @@ +use std::error::Error as StdError; +use std::net::{IpAddr, SocketAddr, ToSocketAddrs}; +use std::{fmt, io}; + +use futures::FutureExt; +use reqwest::dns::{Addrs, Name, Resolve, Resolving}; +use tokio::task::spawn_blocking; + +pub struct NoPublicIPv4Error; + +impl std::error::Error for NoPublicIPv4Error {} +impl fmt::Display for NoPublicIPv4Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "No public IPv4 found for specified host") + } +} +impl fmt::Debug for NoPublicIPv4Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "No public IPv4 found for specified host") + } +} + +/// Internal reqwest type, copied here as part of Resolving +pub(crate) type BoxError = Box<dyn StdError + Send + Sync>; + +/// Returns [`true`] if the address appears to be a globally reachable IPv4. +/// +/// Trimmed down version of the unstable IpAddr::is_global, move to it when it's stable. +fn is_global_ipv4(addr: &SocketAddr) -> bool { + match addr.ip() { + IpAddr::V4(ip) => { + !(ip.octets()[0] == 0 // "This network" + || ip.is_private() + || ip.is_loopback() + || ip.is_link_local() + || ip.is_broadcast()) + } + IpAddr::V6(_) => false, // Our network does not currently support ipv6, let's ignore for now + } +} + +/// DNS resolver using the stdlib resolver, but filtering results to only pass public IPv4 results. +/// +/// Private and broadcast addresses are filtered out, so are IPv6 results for now (as our infra +/// does not currently support IPv6 routing anyway). +/// This is adapted from the GaiResolver in hyper and reqwest. +pub struct PublicIPv4Resolver {} + +impl Resolve for PublicIPv4Resolver { + fn resolve(&self, name: Name) -> Resolving { + // Closure to call the system's resolver (blocking call) through the ToSocketAddrs trait. + let resolve_host = move || (name.as_str(), 0).to_socket_addrs(); + + // Execute the blocking call in a separate worker thread then process its result asynchronously. + // spawn_blocking returns a JoinHandle that implements Future<Output = Result<T, JoinError>>. + let future_result = spawn_blocking(resolve_host).map(|result| match result { + Ok(Ok(all_addrs)) => { + // Resolution succeeded, filter the results + let filtered_addr: Vec<SocketAddr> = all_addrs.filter(is_global_ipv4).collect(); + if filtered_addr.is_empty() { + // No public IPs found, error out with PermissionDenied + let err: BoxError = Box::new(NoPublicIPv4Error); + Err(err) + } else { + // Pass remaining IPs in a boxed iterator for request to use. + let addrs: Addrs = Box::new(filtered_addr.into_iter()); + Ok(addrs) + } + } + Ok(Err(err)) => { + // Resolution failed, pass error through in a Box + let err: BoxError = Box::new(err); + Err(err) + } + Err(join_err) => { + // The tokio task failed, pass as io::Error in a Box + let err: BoxError = Box::new(io::Error::from(join_err)); + Err(err) + } + }); + + // Box the Future to satisfy the Resolving interface. + Box::pin(future_result) + } +} + +#[cfg(test)] +mod tests { + use crate::dns::{NoPublicIPv4Error, PublicIPv4Resolver}; + use reqwest::dns::{Name, Resolve}; + use std::str::FromStr; + + #[tokio::test] + async fn it_resolves_google_com() { + let resolver: PublicIPv4Resolver = PublicIPv4Resolver {}; + let addrs = resolver + .resolve(Name::from_str("google.com").unwrap()) + .await + .expect("lookup has failed"); + assert!(addrs.count() > 0, "empty address list") + } + + #[tokio::test] + async fn it_denies_ipv6_google_com() { + let resolver: PublicIPv4Resolver = PublicIPv4Resolver {}; + match resolver + .resolve(Name::from_str("ipv6.google.com").unwrap()) + .await + { + Ok(_) => panic!("should have failed"), + Err(err) => assert!(err.is::<NoPublicIPv4Error>()), + } + } + + #[tokio::test] + async fn it_denies_localhost() { + let resolver: PublicIPv4Resolver = PublicIPv4Resolver {}; + match resolver.resolve(Name::from_str("localhost").unwrap()).await { + Ok(_) => panic!("should have failed"), + Err(err) => assert!(err.is::<NoPublicIPv4Error>()), + } + } + + #[tokio::test] + async fn it_bubbles_up_resolution_error() { + let resolver: PublicIPv4Resolver = PublicIPv4Resolver {}; + match resolver + .resolve(Name::from_str("invalid.domain.unknown").unwrap()) + .await + { + Ok(_) => panic!("should have failed"), + Err(err) => { + assert!(!err.is::<NoPublicIPv4Error>()); + assert!(err + .to_string() + .contains("failed to lookup address information")) + } + } + } +}
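To make the filter's behavior concrete, a few table-style assertions over `is_global_ipv4` (a hypothetical extra test that would sit inside dns.rs next to the function; it is not part of the diff):

```rust
#[cfg(test)]
mod filter_assumptions {
    use super::is_global_ipv4; // assumes placement in dns.rs
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    #[test]
    fn it_keeps_only_public_ipv4() {
        let sock = |ip: Ipv4Addr| SocketAddr::new(IpAddr::V4(ip), 443);
        assert!(is_global_ipv4(&sock(Ipv4Addr::new(8, 8, 8, 8)))); // public, kept
        assert!(!is_global_ipv4(&sock(Ipv4Addr::new(10, 0, 0, 1)))); // RFC 1918 private
        assert!(!is_global_ipv4(&sock(Ipv4Addr::new(127, 0, 0, 1)))); // loopback
        assert!(!is_global_ipv4(&sock(Ipv4Addr::new(169, 254, 0, 1)))); // link-local
        assert!(!is_global_ipv4(&sock(Ipv4Addr::new(0, 1, 2, 3)))); // "this network"
    }
}
```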
diff --git a/hook-worker/src/error.rs b/hook-worker/src/error.rs index 48468bc..764e8d9 100644 --- a/hook-worker/src/error.rs +++ b/hook-worker/src/error.rs @@ -1,6 +1,8 @@ +use std::error::Error; use std::fmt; use std::time; +use crate::dns::NoPublicIPv4Error; use hook_common::{pgqueue, webhook::WebhookJobError}; use thiserror::Error; @@ -64,7 +66,11 @@ impl fmt::Display for WebhookRequestError { Some(m) => m.to_string(), None => "No response from the server".to_string(), }; - writeln!(f, "{}", error)?; + if is_error_source::<NoPublicIPv4Error>(error) { + writeln!(f, "{}: {}", error, NoPublicIPv4Error)?; + } else { + writeln!(f, "{}", error)?; + } write!(f, "{}", response_message)?; Ok(()) @@ -132,3 +138,15 @@ pub enum WorkerError { #[error("timed out while waiting for jobs to be available")] TimeoutError, } + +/// Check the error and its sources (recursively) to return true if an error of the given type is found. +/// TODO: use Error::sources() when stable +pub fn is_error_source<T: Error + 'static>(err: &(dyn std::error::Error + 'static)) -> bool { + if err.is::<T>() { + return true; + } + match err.source() { + None => false, + Some(source) => is_error_source::<T>(source), + } +} diff --git a/hook-worker/src/lib.rs b/hook-worker/src/lib.rs index 8488d15..94a0758 100644 --- a/hook-worker/src/lib.rs +++ b/hook-worker/src/lib.rs @@ -1,4 +1,5 @@ pub mod config; +pub mod dns; pub mod error; pub mod util; pub mod worker; diff --git a/hook-worker/src/main.rs b/hook-worker/src/main.rs index 8a6eeb3..050e2b9 100644 --- a/hook-worker/src/main.rs +++ b/hook-worker/src/main.rs @@ -52,6 +52,7 @@ async fn main() -> Result<(), WorkerError> { config.request_timeout.0, config.max_concurrent_jobs, retry_policy_builder.provide(), + config.allow_internal_ips, worker_liveness, ); diff --git a/hook-worker/src/worker.rs b/hook-worker/src/worker.rs index 824f1e2..9dcc4a2 100644 --- a/hook-worker/src/worker.rs +++ b/hook-worker/src/worker.rs @@ -7,18 +7,19 @@ use futures::future::join_all; use health::HealthHandle; use hook_common::pgqueue::PgTransactionBatch; use hook_common::{ - pgqueue::{ - DatabaseError, Job, PgQueue, PgQueueJob, PgTransactionJob, RetryError, RetryInvalidError, - }, + pgqueue::{Job, PgQueue, PgQueueJob, PgTransactionJob, RetryError, RetryInvalidError}, retry::RetryPolicy, webhook::{HttpMethod, WebhookJobError, WebhookJobMetadata, WebhookJobParameters}, }; use http::StatusCode; -use reqwest::header; +use reqwest::{header, Client}; use tokio::sync; use tracing::error; -use crate::error::{WebhookError, WebhookParseError, WebhookRequestError, WorkerError}; +use crate::dns::{NoPublicIPv4Error, PublicIPv4Resolver}; +use crate::error::{ + is_error_source, WebhookError, WebhookParseError, WebhookRequestError, WorkerError, +}; use crate::util::first_n_bytes_of_response; /// A WebhookJob is any `PgQueueJob` with `WebhookJobParameters` and `WebhookJobMetadata`.
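`is_error_source` walks the `source()` chain because reqwest wraps the resolver's failure several layers deep. A self-contained sketch of the traversal semantics, using a made-up `Outer` wrapper and assuming `is_error_source` is importable from the crate:

```rust
use std::error::Error;
use std::fmt;

use hook_worker::error::is_error_source; // assumed import path

// Made-up wrapper whose source() exposes an inner io::Error,
// mimicking how reqwest nests the DNS failure.
#[derive(Debug)]
struct Outer(std::io::Error);

impl fmt::Display for Outer {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "outer: {}", self.0)
    }
}

impl Error for Outer {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        Some(&self.0)
    }
}

fn main() {
    let err = Outer(std::io::Error::new(std::io::ErrorKind::Other, "boom"));
    assert!(is_error_source::<std::io::Error>(&err)); // found one level down
    assert!(!is_error_source::<fmt::Error>(&err)); // never appears in the chain
}
```

Once `Error::sources()` stabilizes, as the TODO notes, the recursion should collapse to roughly `err.sources().any(|s| s.is::<T>())`, since that iterator starts at the error itself.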
@@ -74,6 +75,25 @@ pub struct WebhookWorker<'p> { liveness: HealthHandle, } +pub fn build_http_client( + request_timeout: time::Duration, + allow_internal_ips: bool, +) -> reqwest::Result<Client> { + let mut headers = header::HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ); + let mut client_builder = reqwest::Client::builder() + .default_headers(headers) + .user_agent("PostHog Webhook Worker") + .timeout(request_timeout); + if !allow_internal_ips { + client_builder = client_builder.dns_resolver(Arc::new(PublicIPv4Resolver {})) + } + client_builder.build() +} + impl<'p> WebhookWorker<'p> { #[allow(clippy::too_many_arguments)] pub fn new( @@ -84,19 +104,10 @@ impl<'p> WebhookWorker<'p> { request_timeout: time::Duration, max_concurrent_jobs: usize, retry_policy: RetryPolicy, + allow_internal_ips: bool, liveness: HealthHandle, ) -> Self { - let mut headers = header::HeaderMap::new(); - headers.insert( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json"), - ); - - let client = reqwest::Client::builder() - .default_headers(headers) - .user_agent("PostHog Webhook Worker") - .timeout(request_timeout) - .build() + let client = build_http_client(request_timeout, allow_internal_ips) .expect("failed to construct reqwest client for webhook worker"); Self { @@ -390,10 +401,19 @@ async fn send_webhook( .body(body) .send() .await - .map_err(|e| WebhookRequestError::RetryableRequestError { - error: e, - response: None, - retry_after: None, + .map_err(|e| { + if is_error_source::<NoPublicIPv4Error>(&e) { + WebhookRequestError::NonRetryableRetryableRequestError { + error: e, + response: None, + } + } else { + WebhookRequestError::RetryableRequestError { + error: e, + response: None, + retry_after: None, + } + } })?; let retry_after = parse_retry_after_header(response.headers()); @@ -467,25 +487,27 @@ fn parse_retry_after_header(header_map: &reqwest::header::HeaderMap) -> Option<time::Duration> { fn worker_id() -> String { std::process::id().to_string() } - #[allow(dead_code)] + /// Get a request client or panic + fn localhost_client() -> Client { + build_http_client(Duration::from_secs(1), true).expect("failed to create client") + } + async fn enqueue_job( queue: &PgQueue, max_attempts: i32, @@ -569,6 +591,7 @@ mod tests { time::Duration::from_millis(5000), 10, RetryPolicy::default(), + false, liveness, );
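A sketch of the two client modes from a caller's perspective (the import path follows `pub mod worker` in lib.rs; whether `build_http_client` remains public API is an assumption):

```rust
use std::time::Duration;

use hook_worker::worker::build_http_client; // path assumed from lib.rs above

fn main() -> reqwest::Result<()> {
    // Production default (ALLOW_INTERNAL_IPS=false): PublicIPv4Resolver is
    // installed, so hosts resolving only to private or loopback addresses
    // fail at DNS time, before any connection is attempted.
    let _strict = build_http_client(Duration::from_secs(10), false)?;

    // Tests and local development: keep the system resolver's answers as-is.
    let _permissive = build_http_client(Duration::from_secs(10), true)?;
    Ok(())
}
```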
@@ -594,15 +617,14 @@ assert!(registry.get_status().healthy) } - #[sqlx::test(migrations = "../migrations")] - async fn test_send_webhook(_pg: PgPool) { + #[tokio::test] + async fn test_send_webhook() { let method = HttpMethod::POST; let url = "http://localhost:18081/echo"; let headers = collections::HashMap::new(); let body = "a very relevant request body"; - let client = reqwest::Client::new(); - let response = send_webhook(client, &method, url, &headers, body.to_owned()) + let response = send_webhook(localhost_client(), &method, url, &headers, body.to_owned()) .await .expect("send_webhook failed"); assert_eq!(response.status(), StatusCode::OK); assert_eq!( response.text().await.expect("failed to read response"), "a very relevant request body".to_owned(), ); } - #[sqlx::test(migrations = "../migrations")] - async fn test_error_message_contains_response_body(_pg: PgPool) { + #[tokio::test] + async fn test_error_message_contains_response_body() { let method = HttpMethod::POST; let url = "http://localhost:18081/fail"; let headers = collections::HashMap::new(); let body = "this is an error message"; - let client = reqwest::Client::new(); - let err = send_webhook(client, &method, url, &headers, body.to_owned()) + let err = send_webhook(localhost_client(), &method, url, &headers, body.to_owned()) .await .err() .expect("request didn't fail when it should have failed"); @@ -638,17 +659,16 @@ } } - #[sqlx::test(migrations = "../migrations")] - async fn test_error_message_contains_up_to_n_bytes_of_response_body(_pg: PgPool) { + #[tokio::test] + async fn test_error_message_contains_up_to_n_bytes_of_response_body() { let method = HttpMethod::POST; let url = "http://localhost:18081/fail"; let headers = collections::HashMap::new(); // This is double the current hardcoded amount of bytes. // TODO: Make this configurable and change it here too. let body = (0..20 * 1024).map(|_| "a").collect::<Vec<_>>().concat(); - let client = reqwest::Client::new(); - let err = send_webhook(client, &method, url, &headers, body.to_owned()) + let err = send_webhook(localhost_client(), &method, url, &headers, body.to_owned()) .await .err() .expect("request didn't fail when it should have failed"); @@ -666,4 +686,32 @@ mod tests { )); } } + + #[tokio::test] + async fn test_private_ips_denied() { + let method = HttpMethod::POST; + let url = "http://localhost:18081/echo"; + let headers = collections::HashMap::new(); + let body = "a very relevant request body"; + let filtering_client = + build_http_client(Duration::from_secs(1), false).expect("failed to create client"); + + let err = send_webhook(filtering_client, &method, url, &headers, body.to_owned()) + .await + .err() + .expect("request didn't fail when it should have failed"); + + assert!(matches!(err, WebhookError::Request(..))); + if let WebhookError::Request(request_error) = err { + assert_eq!(request_error.status(), None); + assert!(request_error + .to_string() + .contains("No public IPv4 found for specified host")); + if let WebhookRequestError::RetryableRequestError { .. } = request_error { + panic!("error should not be retryable") + } + } else { + panic!("unexpected error type {:?}", err) + } + } }
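The closing assertion ("error should not be retryable") is the contract the `send_webhook` mapping establishes: a `NoPublicIPv4Error` anywhere in the source chain yields the non-retryable variant, since retrying a blocked destination can never succeed. A hypothetical helper showing how a consumer might branch on that contract (not the actual worker loop):

```rust
use hook_worker::error::WebhookRequestError;

// Only transport failures that could succeed later deserve a retry;
// DNS-level rejections of internal hosts fail identically every time.
fn should_retry(err: &WebhookRequestError) -> bool {
    matches!(err, WebhookRequestError::RetryableRequestError { .. })
}
```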