Skip to content

Commit

Permalink
Properly respond with sec-websocket-protocol under http/2 (#3141)
Browse files Browse the repository at this point in the history
  • Loading branch information
coolreader18 authored Feb 1, 2025
1 parent 0e6e96f commit 6c9cabf
Showing 1 changed file with 32 additions and 18 deletions.
50 changes: 32 additions & 18 deletions axum/src/extract/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,34 +338,38 @@ impl<F> WebSocketUpgrade<F> {
callback(socket).await;
});

if let Some(sec_websocket_key) = &self.sec_websocket_key {
let mut response = if let Some(sec_websocket_key) = &self.sec_websocket_key {
// If `sec_websocket_key` was `Some`, we are using HTTP/1.1.

#[allow(clippy::declare_interior_mutable_const)]
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
#[allow(clippy::declare_interior_mutable_const)]
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");

let mut builder = Response::builder()
Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(header::CONNECTION, UPGRADE)
.header(header::UPGRADE, WEBSOCKET)
.header(
header::SEC_WEBSOCKET_ACCEPT,
sign(sec_websocket_key.as_bytes()),
);

if let Some(protocol) = self.protocol {
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
}

builder.body(Body::empty()).unwrap()
)
.body(Body::empty())
.unwrap()
} else {
// Otherwise, we are HTTP/2+. As established in RFC 9113 section 8.5, we just respond
// with a 2XX with an empty body:
// <https://datatracker.ietf.org/doc/html/rfc9113#name-the-connect-method>.
Response::new(Body::empty())
};

if let Some(protocol) = self.protocol {
response
.headers_mut()
.insert(header::SEC_WEBSOCKET_PROTOCOL, protocol);
}

response
}
}

Expand Down Expand Up @@ -1092,10 +1096,11 @@ mod tests {
#[crate::test]
async fn integration_test() {
let addr = spawn_service(echo_app());
let (socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo"))
.await
.unwrap();
test_echo_app(socket).await;
let uri = format!("ws://{addr}/echo").try_into().unwrap();
let req = tungstenite::client::ClientRequestBuilder::new(uri)
.with_sub_protocol(TEST_ECHO_APP_REQ_SUBPROTO);
let (socket, response) = tokio_tungstenite::connect_async(req).await.unwrap();
test_echo_app(socket, response.headers()).await;
}

#[crate::test]
Expand Down Expand Up @@ -1123,21 +1128,22 @@ mod tests {
.extension(hyper::ext::Protocol::from_static("websocket"))
.uri("/echo")
.header("sec-websocket-version", "13")
.header("sec-websocket-protocol", TEST_ECHO_APP_REQ_SUBPROTO)
.header("Host", "server.example.com")
.body(Body::empty())
.unwrap();

let response = send_request.send_request(req).await.unwrap();
let mut response = send_request.send_request(req).await.unwrap();
let status = response.status();
if status != 200 {
let body = response.into_body().collect().await.unwrap().to_bytes();
let body = std::str::from_utf8(&body).unwrap();
panic!("response status was {status}: {body}");
}
let upgraded = hyper::upgrade::on(response).await.unwrap();
let upgraded = hyper::upgrade::on(&mut response).await.unwrap();
let upgraded = TokioIo::new(upgraded);
let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Client, None).await;
test_echo_app(socket).await;
test_echo_app(socket, response.headers()).await;
}

fn echo_app() -> Router {
Expand All @@ -1158,11 +1164,19 @@ mod tests {

Router::new().route(
"/echo",
any(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))),
any(|ws: WebSocketUpgrade| {
ready(ws.protocols(["echo2", "echo"]).on_upgrade(handle_socket))
}),
)
}

async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(mut socket: WebSocketStream<S>) {
const TEST_ECHO_APP_REQ_SUBPROTO: &str = "echo3, echo";
async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(
mut socket: WebSocketStream<S>,
headers: &http::HeaderMap,
) {
assert_eq!(headers[http::header::SEC_WEBSOCKET_PROTOCOL], "echo");

let input = tungstenite::Message::Text(tungstenite::Utf8Bytes::from_static("foobar"));
socket.send(input.clone()).await.unwrap();
let output = socket.next().await.unwrap().unwrap();
Expand Down

0 comments on commit 6c9cabf

Please sign in to comment.