Skip to content

Commit

Permalink
Merge branch 'awslabs:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
Muthukamalan authored Jan 5, 2025
2 parents dda467b + 1e8c2cc commit d2aca7f
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 3 deletions.
File renamed without changes.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ The readiness check port/path and traffic port can be configured using environme
| AWS_LWA_INVOKE_MODE | Lambda function invoke mode: "buffered" or "response_stream", default is "buffered" | "buffered" |
| AWS_LWA_PASS_THROUGH_PATH | the path for receiving event payloads that are passed through from non-http triggers | "/events" |
| AWS_LWA_AUTHORIZATION_SOURCE | a header name to be replaced to `Authorization` | None |
| AWS_LWA_ERROR_STATUS_CODES | comma-separated list of HTTP status codes that will cause Lambda invocations to fail (e.g. "500,502-504,422") | None |

> **Note:**
> We use "AWS_LWA_" prefix to namespacing all environment variables used by Lambda Web Adapter. The original ones will be supported until we reach version 1.0.
Expand Down Expand Up @@ -137,6 +138,8 @@ Please check out [FastAPI with Response Streaming](examples/fastapi-response-str

**AWS_LWA_AUTHORIZATION_SOURCE** - When set, Lambda Web Adapter replaces the specified header name to `Authorization` before proxying a request. This is useful when you use Lambda function URL with [IAM auth type](https://docs.aws.amazon.com/lambda/latest/dg/urls-auth.html), which reserves Authorization header for IAM authentication, but you want to still use Authorization header for your backend apps. This feature is disabled by default.

**AWS_LWA_ERROR_STATUS_CODES** - A comma-separated list of HTTP status codes that will cause Lambda invocations to fail. Supports individual codes and ranges (e.g. "500,502-504,422"). When the web application returns any of these status codes, the Lambda invocation will fail and trigger error handling behaviors like retries or DLQ processing. This is useful for treating certain HTTP errors as Lambda execution failures. This feature is disabled by default.

## Request Context

**Request Context** is metadata API Gateway sends to Lambda for a request. It usually contains requestId, requestTime, apiId, identity, and authorizer. Identity and authorizer are useful to get client identity for authorization. API Gateway Developer Guide contains more details [here](https://docs.aws.amazon.com/apigateway/latest/developerguide/set-up-lambda-proxy-integrations.html#api-gateway-simple-proxy-for-lambda-input-format).
Expand Down
74 changes: 71 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

mod readiness;

use http::{
header::{HeaderName, HeaderValue},
Method, StatusCode,
Expand All @@ -13,6 +15,7 @@ use lambda_http::request::RequestContext;
use lambda_http::Body;
pub use lambda_http::Error;
use lambda_http::{Request, RequestExt, Response};
use readiness::Checkpoint;
use std::fmt::Debug;
use std::{
env,
Expand All @@ -24,8 +27,7 @@ use std::{
},
time::Duration,
};
use tokio::net::TcpStream;
use tokio::time::timeout;
use tokio::{net::TcpStream, time::timeout};
use tokio_retry::{strategy::FixedInterval, Retry};
use tower::{Service, ServiceBuilder};
use tower_http::compression::CompressionLayer;
Expand Down Expand Up @@ -78,6 +80,7 @@ pub struct AdapterOptions {
pub compression: bool,
pub invoke_mode: LambdaInvokeMode,
pub authorization_source: Option<String>,
pub error_status_codes: Option<Vec<u16>>,
}

impl Default for AdapterOptions {
Expand Down Expand Up @@ -116,10 +119,42 @@ impl Default for AdapterOptions {
.as_str()
.into(),
authorization_source: env::var("AWS_LWA_AUTHORIZATION_SOURCE").ok(),
error_status_codes: env::var("AWS_LWA_ERROR_STATUS_CODES")
.ok()
.map(|codes| parse_status_codes(&codes)),
}
}
}

fn parse_status_codes(input: &str) -> Vec<u16> {
input
.split(',')
.flat_map(|part| {
let part = part.trim();
if part.contains('-') {
let range: Vec<&str> = part.split('-').collect();
if range.len() == 2 {
if let (Ok(start), Ok(end)) = (range[0].parse::<u16>(), range[1].parse::<u16>()) {
return (start..=end).collect::<Vec<_>>();
}
}
tracing::warn!("Failed to parse status code range: {}", part);
vec![]
} else {
part.parse::<u16>().map_or_else(
|_| {
if !part.is_empty() {
tracing::warn!("Failed to parse status code: {}", part);
}
vec![]
},
|code| vec![code],
)
}
})
.collect()
}

#[derive(Clone)]
pub struct Adapter<C, B> {
client: Arc<Client<C, B>>,
Expand All @@ -134,6 +169,7 @@ pub struct Adapter<C, B> {
compression: bool,
invoke_mode: LambdaInvokeMode,
authorization_source: Option<String>,
error_status_codes: Option<Vec<u16>>,
}

impl Adapter<HttpConnector, Body> {
Expand Down Expand Up @@ -171,6 +207,7 @@ impl Adapter<HttpConnector, Body> {
compression: options.compression,
invoke_mode: options.invoke_mode,
authorization_source: options.authorization_source.clone(),
error_status_codes: options.error_status_codes.clone(),
}
}
}
Expand Down Expand Up @@ -231,7 +268,12 @@ impl Adapter<HttpConnector, Body> {
}

async fn is_web_ready(&self, url: &Url, protocol: &Protocol) -> bool {
let mut checkpoint = Checkpoint::new();
Retry::spawn(FixedInterval::from_millis(10), || {
if checkpoint.lapsed() {
tracing::info!(url = %url.to_string(), "app is not ready after {}ms", checkpoint.next_ms());
checkpoint.increment();
}
self.check_web_readiness(url, protocol)
})
.await
Expand All @@ -247,10 +289,11 @@ impl Adapter<HttpConnector, Body> {
&& response.status().as_u16() >= 100
} =>
{
tracing::debug!("app is ready");
Ok(())
}
_ => {
tracing::debug!("app is not ready");
tracing::trace!("app is not ready");
Err(-1)
}
},
Expand Down Expand Up @@ -341,6 +384,17 @@ impl Adapter<HttpConnector, Body> {

let mut app_response = self.client.request(request).await?;

// Check if status code should trigger an error
if let Some(error_codes) = &self.error_status_codes {
let status = app_response.status().as_u16();
if error_codes.contains(&status) {
return Err(Error::from(format!(
"Request failed with configured error status code: {}",
status
)));
}
}

// remove "transfer-encoding" from the response to support "sam local start-api"
app_response.headers_mut().remove("transfer-encoding");

Expand Down Expand Up @@ -373,6 +427,20 @@ mod tests {
use super::*;
use httpmock::{Method::GET, MockServer};

#[test]
fn test_parse_status_codes() {
assert_eq!(parse_status_codes("500,502-504,422"), vec![500, 502, 503, 504, 422]);
assert_eq!(
parse_status_codes("500, 502-504, 422"), // with spaces
vec![500, 502, 503, 504, 422]
);
assert_eq!(parse_status_codes("500"), vec![500]);
assert_eq!(parse_status_codes("500-502"), vec![500, 501, 502]);
assert_eq!(parse_status_codes("invalid"), Vec::<u16>::new());
assert_eq!(parse_status_codes("500-invalid"), Vec::<u16>::new());
assert_eq!(parse_status_codes(""), Vec::<u16>::new());
}

#[tokio::test]
async fn test_status_200_is_ok() {
// Start app server
Expand Down
63 changes: 63 additions & 0 deletions src/readiness.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use std::time::Instant;

pub(crate) struct Checkpoint {
start: Instant,
interval_ms: u128,
next_ms: u128,
}

impl Checkpoint {
pub fn new() -> Checkpoint {
// The default function timeout is 3 seconds. This will alert the users. See #520
let interval_ms = 2000;

let start = Instant::now();
Checkpoint {
start,
interval_ms,
next_ms: start.elapsed().as_millis() + interval_ms,
}
}

pub const fn next_ms(&self) -> u128 {
self.next_ms
}

pub const fn increment(&mut self) {
self.next_ms += self.interval_ms;
}

pub fn lapsed(&self) -> bool {
self.start.elapsed().as_millis() >= self.next_ms
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_checkpoint_new() {
let checkpoint = Checkpoint::new();
assert_eq!(checkpoint.next_ms(), 2000);
assert!(!checkpoint.lapsed());
}

#[test]
fn test_checkpoint_increment() {
let mut checkpoint = Checkpoint::new();
checkpoint.increment();
assert_eq!(checkpoint.next_ms(), 4000);
assert!(!checkpoint.lapsed());
}

#[test]
fn test_checkpoint_lapsed() {
let checkpoint = Checkpoint {
start: Instant::now(),
interval_ms: 0,
next_ms: 0,
};
assert!(checkpoint.lapsed());
}
}
32 changes: 32 additions & 0 deletions tests/integ_tests/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,38 @@ async fn test_http_content_encoding_suffix() {
assert_eq!(json_data.to_owned(), body_to_string(response).await);
}

#[tokio::test]
async fn test_http_error_status_codes() {
// Start app server
let app_server = MockServer::start();
let error_endpoint = app_server.mock(|when, then| {
when.method(GET).path("/error");
then.status(502).body("Bad Gateway");
});

// Initialize adapter with error status codes
let mut adapter = Adapter::new(&AdapterOptions {
host: app_server.host(),
port: app_server.port().to_string(),
readiness_check_port: app_server.port().to_string(),
readiness_check_path: "/healthcheck".to_string(),
error_status_codes: Some(vec![500, 502, 503, 504]),
..Default::default()
});

// Call the adapter service with request that should trigger error
let req = LambdaEventBuilder::new().with_path("/error").build();
let mut request = Request::from(req);
add_lambda_context_to_request(&mut request);

let result = adapter.call(request).await;
assert!(result.is_err(), "Expected error response for status code 502");
assert!(result.unwrap_err().to_string().contains("502"));

// Assert endpoint was called
error_endpoint.assert();
}

#[tokio::test]
async fn test_http_authorization_source() {
// Start app server
Expand Down

0 comments on commit d2aca7f

Please sign in to comment.