Skip to content
This repository has been archived by the owner on Jun 21, 2024. It is now read-only.

Commit

Permalink
add PublicIPv4Resolver error type and make sure it is logged, add wor…
Browse files Browse the repository at this point in the history
…ker test
  • Loading branch information
xvello committed May 6, 2024
1 parent f5f0a91 commit e040018
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 36 deletions.
32 changes: 27 additions & 5 deletions hook-worker/src/dns.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
use std::error::Error as StdError;
use std::io;
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
use std::{fmt, io};

use futures::FutureExt;
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use tokio::task::spawn_blocking;

/// Error raised when none of the addresses resolved for a host is a public IPv4 address.
pub struct NoPublicIPError;

impl fmt::Display for NoPublicIPError {
    fn fmt(&self, out: &mut fmt::Formatter) -> fmt::Result {
        out.write_str("No public IPv4 found for specified host")
    }
}

// Debug deliberately renders the same human-readable text as Display.
impl fmt::Debug for NoPublicIPError {
    fn fmt(&self, out: &mut fmt::Formatter) -> fmt::Result {
        fmt::Display::fmt(self, out)
    }
}

impl std::error::Error for NoPublicIPError {}

/// Boxed error trait object that is `Send + Sync`.
///
/// Internal reqwest type, copied here as part of `Resolving`: the resolver needs it
/// to box the errors it returns.
pub(crate) type BoxError = Box<dyn StdError + Send + Sync>;

Expand Down Expand Up @@ -40,10 +54,18 @@ impl Resolve for PublicIPv4Resolver {
// Execute the blocking call in a separate worker thread then process its result asynchronously.
// spawn_blocking returns a JoinHandle that implements Future<Result<(closure result), JoinError>>.
let future_result = spawn_blocking(resolve_host).map(|result| match result {
Ok(Ok(addr)) => {
// Resolution succeeded, pass the IPs in a Box after filtering
let addrs: Addrs = Box::new(addr.filter(is_global_ipv4));
Ok(addrs)
Ok(Ok(all_addrs)) => {
// Resolution succeeded, filter the results
let filtered_addr: Vec<SocketAddr> = all_addrs.filter(is_global_ipv4).collect();
if filtered_addr.is_empty() {
// No public IPs left after filtering, error out with NoPublicIPError
let err: BoxError = Box::new(NoPublicIPError);
Err(err)
} else {
// Pass remaining IPs in a boxed iterator for request to use.
let addrs: Addrs = Box::new(filtered_addr.into_iter());
Ok(addrs)
}
}
Ok(Err(err)) => {
// Resolution failed, pass error through in a Box
Expand Down
19 changes: 18 additions & 1 deletion hook-worker/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::error::Error;
use std::fmt;
use std::time;

use crate::dns::NoPublicIPError;
use hook_common::{pgqueue, webhook::WebhookJobError};
use thiserror::Error;

Expand Down Expand Up @@ -64,7 +66,11 @@ impl fmt::Display for WebhookRequestError {
Some(m) => m.to_string(),
None => "No response from the server".to_string(),
};
writeln!(f, "{}", error)?;
if is_error_source::<NoPublicIPError>(error) {
writeln!(f, "{}: {}", error ,NoPublicIPError)?;
} else {
writeln!(f, "{}", error)?;
}
write!(f, "{}", response_message)?;

Ok(())
Expand Down Expand Up @@ -132,3 +138,14 @@ pub enum WorkerError {
#[error("timed out while waiting for jobs to be available")]
TimeoutError,
}

/// Returns `true` when `err` itself, or any error reachable by following its `source()`
/// chain, is of type `T`; returns `false` once the chain ends without a match.
pub fn is_error_source<T: Error + 'static>(err: &(dyn std::error::Error + 'static)) -> bool {
    // Walk the chain iteratively, starting with the error we were handed.
    let mut current = Some(err);
    while let Some(candidate) = current {
        if candidate.is::<T>() {
            return true;
        }
        current = candidate.source();
    }
    false
}
94 changes: 64 additions & 30 deletions hook-worker/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use hook_common::{
webhook::{HttpMethod, WebhookJobError, WebhookJobMetadata, WebhookJobParameters},
};
use http::StatusCode;
use reqwest::header;
use reqwest::{header, Client};
use tokio::sync;
use tracing::error;

Expand Down Expand Up @@ -75,6 +75,25 @@ pub struct WebhookWorker<'p> {
liveness: HealthHandle,
}

/// Builds the `reqwest` client used to send webhook requests.
///
/// The client sends `Content-Type: application/json` by default, identifies itself as
/// "PostHog Webhook Worker" and applies `request_timeout` to its requests.
/// Unless `allow_internal_ips` is set, DNS resolution goes through `PublicIPv4Resolver`.
///
/// # Errors
///
/// Returns the `reqwest` error reported by the builder when the client cannot be built.
pub fn build_http_client(
    request_timeout: time::Duration,
    allow_internal_ips: bool,
) -> reqwest::Result<Client> {
    let mut default_headers = header::HeaderMap::new();
    let json_content_type = header::HeaderValue::from_static("application/json");
    default_headers.insert(header::CONTENT_TYPE, json_content_type);

    let base_builder = reqwest::Client::builder()
        .default_headers(default_headers)
        .user_agent("PostHog Webhook Worker")
        .timeout(request_timeout);

    // Only install the filtering resolver when internal addresses must be rejected.
    let builder = if allow_internal_ips {
        base_builder
    } else {
        base_builder.dns_resolver(Arc::new(PublicIPv4Resolver {}))
    };
    builder.build()
}

impl<'p> WebhookWorker<'p> {
#[allow(clippy::too_many_arguments)]
pub fn new(
Expand All @@ -88,21 +107,7 @@ impl<'p> WebhookWorker<'p> {
allow_internal_ips: bool,
liveness: HealthHandle,
) -> Self {
let mut headers = header::HeaderMap::new();
headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);

let mut client_builder = reqwest::Client::builder()
.default_headers(headers)
.user_agent("PostHog Webhook Worker")
.timeout(request_timeout);
if !allow_internal_ips {
client_builder = client_builder.dns_resolver(Arc::new(PublicIPv4Resolver {}))
}
let client = client_builder
.build()
let client = build_http_client(request_timeout, allow_internal_ips)
.expect("failed to construct reqwest client for webhook worker");

Self {
Expand Down Expand Up @@ -475,6 +480,7 @@ fn parse_retry_after_header(header_map: &reqwest::header::HeaderMap) -> Option<t

mod tests {
use super::*;
use std::time::Duration;
// Note we are ignoring some warnings in this module.
// This is due to a long-standing cargo bug that reports imports and helper functions as unused.
// See: https://github.com/rust-lang/rust/issues/46379.
Expand All @@ -491,6 +497,12 @@ mod tests {
std::process::id().to_string()
}

/// Test helper: returns a client with a one-second timeout that is allowed to reach
/// internal IPs. Panics if the client cannot be built.
#[allow(dead_code)]
fn localhost_client() -> Client {
    let timeout = Duration::from_secs(1);
    let allow_internal_ips = true;
    let built = build_http_client(timeout, allow_internal_ips);
    built.expect("failed to create client")
}

#[allow(dead_code)]
async fn enqueue_job(
queue: &PgQueue,
Expand Down Expand Up @@ -565,8 +577,8 @@ mod tests {
webhook_job_parameters.clone(),
webhook_job_metadata,
)
.await
.expect("failed to enqueue job");
.await
.expect("failed to enqueue job");
let worker = WebhookWorker::new(
&worker_id,
&queue,
Expand Down Expand Up @@ -601,15 +613,14 @@ mod tests {
assert!(registry.get_status().healthy)
}

#[sqlx::test(migrations = "../migrations")]
async fn test_send_webhook(_pg: PgPool) {
#[tokio::test]
async fn test_send_webhook() {
let method = HttpMethod::POST;
let url = "http://localhost:18081/echo";
let headers = collections::HashMap::new();
let body = "a very relevant request body";
let client = reqwest::Client::new();

let response = send_webhook(client, &method, url, &headers, body.to_owned())
let response = send_webhook(localhost_client(), &method, url, &headers, body.to_owned())
.await
.expect("send_webhook failed");

Expand All @@ -620,15 +631,14 @@ mod tests {
);
}

#[sqlx::test(migrations = "../migrations")]
async fn test_error_message_contains_response_body(_pg: PgPool) {
#[tokio::test]
async fn test_error_message_contains_response_body() {
let method = HttpMethod::POST;
let url = "http://localhost:18081/fail";
let headers = collections::HashMap::new();
let body = "this is an error message";
let client = reqwest::Client::new();

let err = send_webhook(client, &method, url, &headers, body.to_owned())
let err = send_webhook(localhost_client(), &method, url, &headers, body.to_owned())
.await
.err()
.expect("request didn't fail when it should have failed");
Expand All @@ -645,17 +655,16 @@ mod tests {
}
}

#[sqlx::test(migrations = "../migrations")]
async fn test_error_message_contains_up_to_n_bytes_of_response_body(_pg: PgPool) {
#[tokio::test]
async fn test_error_message_contains_up_to_n_bytes_of_response_body() {
let method = HttpMethod::POST;
let url = "http://localhost:18081/fail";
let headers = collections::HashMap::new();
// This is double the current hardcoded amount of bytes.
// TODO: Make this configurable and change it here too.
let body = (0..20 * 1024).map(|_| "a").collect::<Vec<_>>().concat();
let client = reqwest::Client::new();

let err = send_webhook(client, &method, url, &headers, body.to_owned())
let err = send_webhook(localhost_client(), &method, url, &headers, body.to_owned())
.await
.err()
.expect("request didn't fail when it should have failed");
Expand All @@ -673,4 +682,29 @@ mod tests {
));
}
}

/// A client built with `allow_internal_ips = false` must fail to call a `localhost` URL:
/// the request errors out without any HTTP status and the error text reports that no
/// public IPv4 was found for the host.
#[tokio::test]
async fn test_private_ips_denied() {
    let filtering_client =
        build_http_client(Duration::from_secs(1), false).expect("failed to create client");
    let headers = collections::HashMap::new();

    let outcome = send_webhook(
        filtering_client,
        &HttpMethod::POST,
        "http://localhost:18081/echo",
        &headers,
        "a very relevant request body".to_owned(),
    )
    .await;
    let err = outcome
        .err()
        .expect("request didn't fail when it should have failed");

    assert!(matches!(err, WebhookError::Request(..)));
    match err {
        WebhookError::Request(request_error) => {
            // The request never reached a server, so there is no status to report.
            assert_eq!(request_error.status(), None);
            assert!(request_error
                .to_string()
                .contains("No public IPv4 found for specified host"));
        }
        other => panic!("unexpected error type {:?}", other),
    }
}
}

0 comments on commit e040018

Please sign in to comment.