diff --git a/.cargo/config b/.cargo/config.toml similarity index 100% rename from .cargo/config rename to .cargo/config.toml diff --git a/README.md b/README.md index 6ae10bcb..c85d279f 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,7 @@ The readiness check port/path and traffic port can be configured using environme | AWS_LWA_INVOKE_MODE | Lambda function invoke mode: "buffered" or "response_stream", default is "buffered" | "buffered" | | AWS_LWA_PASS_THROUGH_PATH | the path for receiving event payloads that are passed through from non-http triggers | "/events" | | AWS_LWA_AUTHORIZATION_SOURCE | a header name to be replaced to `Authorization` | None | +| AWS_LWA_ERROR_STATUS_CODES | comma-separated list of HTTP status codes that will cause Lambda invocations to fail (e.g. "500,502-504,422") | None | > **Note:** > We use "AWS_LWA_" prefix to namespacing all environment variables used by Lambda Web Adapter. The original ones will be supported until we reach version 1.0. @@ -137,6 +138,8 @@ Please check out [FastAPI with Response Streaming](examples/fastapi-response-str **AWS_LWA_AUTHORIZATION_SOURCE** - When set, Lambda Web Adapter replaces the specified header name to `Authorization` before proxying a request. This is useful when you use Lambda function URL with [IAM auth type](https://docs.aws.amazon.com/lambda/latest/dg/urls-auth.html), which reserves Authorization header for IAM authentication, but you want to still use Authorization header for your backend apps. This feature is disabled by default. +**AWS_LWA_ERROR_STATUS_CODES** - A comma-separated list of HTTP status codes that will cause Lambda invocations to fail. Supports individual codes and ranges (e.g. "500,502-504,422"). When the web application returns any of these status codes, the Lambda invocation will fail and trigger error handling behaviors like retries or DLQ processing. This is useful for treating certain HTTP errors as Lambda execution failures. This feature is disabled by default. + ## Request Context **Request Context** is metadata API Gateway sends to Lambda for a request. It usually contains requestId, requestTime, apiId, identity, and authorizer. Identity and authorizer are useful to get client identity for authorization. API Gateway Developer Guide contains more details [here](https://docs.aws.amazon.com/apigateway/latest/developerguide/set-up-lambda-proxy-integrations.html#api-gateway-simple-proxy-for-lambda-input-format). diff --git a/src/lib.rs b/src/lib.rs index dd64118c..733c9703 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -78,6 +78,7 @@ pub struct AdapterOptions { pub compression: bool, pub invoke_mode: LambdaInvokeMode, pub authorization_source: Option, + pub error_status_codes: Option>, } impl Default for AdapterOptions { @@ -116,10 +117,42 @@ impl Default for AdapterOptions { .as_str() .into(), authorization_source: env::var("AWS_LWA_AUTHORIZATION_SOURCE").ok(), + error_status_codes: env::var("AWS_LWA_ERROR_STATUS_CODES") + .ok() + .map(|codes| parse_status_codes(&codes)), } } } +fn parse_status_codes(input: &str) -> Vec { + input + .split(',') + .flat_map(|part| { + let part = part.trim(); + if part.contains('-') { + let range: Vec<&str> = part.split('-').collect(); + if range.len() == 2 { + if let (Ok(start), Ok(end)) = (range[0].parse::(), range[1].parse::()) { + return (start..=end).collect::>(); + } + } + tracing::warn!("Failed to parse status code range: {}", part); + vec![] + } else { + part.parse::().map_or_else( + |_| { + if !part.is_empty() { + tracing::warn!("Failed to parse status code: {}", part); + } + vec![] + }, + |code| vec![code], + ) + } + }) + .collect() +} + #[derive(Clone)] pub struct Adapter { client: Arc>, @@ -134,6 +167,7 @@ pub struct Adapter { compression: bool, invoke_mode: LambdaInvokeMode, authorization_source: Option, + error_status_codes: Option>, } impl Adapter { @@ -171,6 +205,7 @@ impl Adapter { compression: options.compression, invoke_mode: options.invoke_mode, authorization_source: options.authorization_source.clone(), + error_status_codes: options.error_status_codes.clone(), } } } @@ -341,6 +376,17 @@ impl Adapter { let mut app_response = self.client.request(request).await?; + // Check if status code should trigger an error + if let Some(error_codes) = &self.error_status_codes { + let status = app_response.status().as_u16(); + if error_codes.contains(&status) { + return Err(Error::from(format!( + "Request failed with configured error status code: {}", + status + ))); + } + } + // remove "transfer-encoding" from the response to support "sam local start-api" app_response.headers_mut().remove("transfer-encoding"); @@ -373,6 +419,20 @@ mod tests { use super::*; use httpmock::{Method::GET, MockServer}; + #[test] + fn test_parse_status_codes() { + assert_eq!(parse_status_codes("500,502-504,422"), vec![500, 502, 503, 504, 422]); + assert_eq!( + parse_status_codes("500, 502-504, 422"), // with spaces + vec![500, 502, 503, 504, 422] + ); + assert_eq!(parse_status_codes("500"), vec![500]); + assert_eq!(parse_status_codes("500-502"), vec![500, 501, 502]); + assert_eq!(parse_status_codes("invalid"), Vec::::new()); + assert_eq!(parse_status_codes("500-invalid"), Vec::::new()); + assert_eq!(parse_status_codes(""), Vec::::new()); + } + #[tokio::test] async fn test_status_200_is_ok() { // Start app server diff --git a/tests/integ_tests/main.rs b/tests/integ_tests/main.rs index 46fd7c68..7ba449ba 100644 --- a/tests/integ_tests/main.rs +++ b/tests/integ_tests/main.rs @@ -611,6 +611,38 @@ async fn test_http_content_encoding_suffix() { assert_eq!(json_data.to_owned(), body_to_string(response).await); } +#[tokio::test] +async fn test_http_error_status_codes() { + // Start app server + let app_server = MockServer::start(); + let error_endpoint = app_server.mock(|when, then| { + when.method(GET).path("/error"); + then.status(502).body("Bad Gateway"); + }); + + // Initialize adapter with error status codes + let mut adapter = Adapter::new(&AdapterOptions { + host: app_server.host(), + port: app_server.port().to_string(), + readiness_check_port: app_server.port().to_string(), + readiness_check_path: "/healthcheck".to_string(), + error_status_codes: Some(vec![500, 502, 503, 504]), + ..Default::default() + }); + + // Call the adapter service with request that should trigger error + let req = LambdaEventBuilder::new().with_path("/error").build(); + let mut request = Request::from(req); + add_lambda_context_to_request(&mut request); + + let result = adapter.call(request).await; + assert!(result.is_err(), "Expected error response for status code 502"); + assert!(result.unwrap_err().to_string().contains("502")); + + // Assert endpoint was called + error_endpoint.assert(); +} + #[tokio::test] async fn test_http_authorization_source() { // Start app server