Skip to content

Commit

Permalink
Merge pull request #1880 from ballerina-platform/host-header-fix
Browse files Browse the repository at this point in the history
[Master] Fix overwriting the Host header
  • Loading branch information
TharmiganK authored Mar 5, 2024
2 parents 60ebb95 + abec3eb commit e47a9ab
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
// Copyright (c) 2024 WSO2 LLC. (http://www.wso2.org).
//
// WSO2 LLC. licenses this file to you under the Apache License,
// Version 2.0 (the "License"); you may not use this file except
// in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

import ballerina/http;
import ballerina/test;

final http:Client http2ClientHost1 = check new("localhost:" + http2ClientHostHeaderTestPort.toString());
final http:Client http2ClientHost2 = check new("localhost:" + httpClientHostHeaderTestPort.toString());
final http:Client httpClientHost1 = check new("localhost:" + httpClientHostHeaderTestPort.toString(), httpVersion = http:HTTP_1_1);
final http:Client httpClientHost2 = check new("localhost:" + http2ClientHostHeaderTestPort.toString(), httpVersion = http:HTTP_1_1);

service / on new http:Listener(http2ClientHostHeaderTestPort) {

resource function 'default host(http:Request req) returns string|error {
return req.getHeader("Host");
}
}

service / on new http:Listener(httpClientHostHeaderTestPort, httpVersion = http:HTTP_1_1) {

resource function 'default host(http:Request req) returns string|error {
return req.getHeader("Host");
}
}

@test:Config {}
function testHttpClientHostHeader1() returns error? {
string host = check httpClientHost1->/host;
test:assertEquals(host, "localhost:" + httpClientHostHeaderTestPort.toString());

host = check httpClientHost1->get("/host");
test:assertEquals(host, "localhost:" + httpClientHostHeaderTestPort.toString());

host = check httpClientHost2->/host;
test:assertEquals(host, "localhost:" + http2ClientHostHeaderTestPort.toString());

host = check httpClientHost2->get("/host");
test:assertEquals(host, "localhost:" + http2ClientHostHeaderTestPort.toString());
}

@test:Config {}
function testHttpClientHostHeader2() returns error? {
string host = check httpClientHost1->/host.get({"Host": "mock.com"});
test:assertEquals(host, "mock.com");

host = check httpClientHost1->get("/host", {"Host": "mock.com"});
test:assertEquals(host, "mock.com");

host = check httpClientHost2->/host.get({"Host": "mock.com"});
test:assertEquals(host, "mock.com");

host = check httpClientHost2->get("/host", {"Host": "mock.com"});
test:assertEquals(host, "mock.com");
}

@test:Config {}
function testHttpClientHostHeader3() returns error? {
http:Request req = new;
req.setHeader("Host", "mock.com");
string host = check httpClientHost1->/host.post(req);
test:assertEquals(host, "mock.com");

host = check httpClientHost1->post("/host", req, {"Host": "mock2.com"});
test:assertEquals(host, "mock2.com");

host = check httpClientHost2->/host.post(req);
test:assertEquals(host, "mock2.com");

host = check httpClientHost2->post("/host", req, {"Host": "mock3.com"});
test:assertEquals(host, "mock3.com");
}

@test:Config {}
function testHttp2ClientHostHeader1() returns error? {
string host = check http2ClientHost1->/host;
test:assertEquals(host, "localhost:" + http2ClientHostHeaderTestPort.toString());

host = check http2ClientHost1->get("/host");
test:assertEquals(host, "localhost:" + http2ClientHostHeaderTestPort.toString());

host = check http2ClientHost2->/host;
test:assertEquals(host, "localhost:" + httpClientHostHeaderTestPort.toString());

host = check http2ClientHost2->get("/host");
test:assertEquals(host, "localhost:" + httpClientHostHeaderTestPort.toString());
}

@test:Config {}
function testHttp2ClientHostHeader2() returns error? {
string host = check http2ClientHost1->/host.get({"Host": "mock.com"});
test:assertEquals(host, "mock.com");

host = check http2ClientHost1->get("/host", {"Host": "mock.com"});
test:assertEquals(host, "mock.com");

host = check http2ClientHost2->/host.get({"Host": "mock.com"});
test:assertEquals(host, "mock.com");

host = check http2ClientHost2->get("/host", {"Host": "mock.com"});
test:assertEquals(host, "mock.com");
}

@test:Config {}
function testHttp2ClientHostHeader3() returns error? {
http:Request req = new;
req.setHeader("Host", "mock.com");
string host = check http2ClientHost1->/host.post(req);
test:assertEquals(host, "mock.com");

host = check http2ClientHost1->post("/host", req, {"Host": "mock2.com"});
test:assertEquals(host, "mock2.com");

host = check http2ClientHost2->/host.post(req);
test:assertEquals(host, "mock2.com");

host = check http2ClientHost2->post("/host", req, {"Host": "mock3.com"});
test:assertEquals(host, "mock3.com");
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,6 @@ const int clientSchemeTestHttpsListenerTestPort = 9624;

const int clientResourceMethodsTestPort = 9631;
const int clientFormUrlEncodedTestPort = 9604;

const int http2ClientHostHeaderTestPort = 9605;
const int httpClientHostHeaderTestPort = 9606;
6 changes: 3 additions & 3 deletions ballerina/Ballerina.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
org = "ballerina"
name = "http"
version = "2.10.7"
version = "2.10.8"
authors = ["Ballerina"]
keywords = ["http", "network", "service", "listener", "client"]
repository = "https://github.com/ballerina-platform/module-ballerina-http"
Expand All @@ -16,8 +16,8 @@ graalvmCompatible = true
[[platform.java17.dependency]]
groupId = "io.ballerina.stdlib"
artifactId = "http-native"
version = "2.10.7"
path = "../native/build/libs/http-native-2.10.7.jar"
version = "2.10.8"
path = "../native/build/libs/http-native-2.10.8-SNAPSHOT.jar"

[[platform.java17.dependency]]
groupId = "io.ballerina.stdlib"
Expand Down
2 changes: 1 addition & 1 deletion ballerina/CompilerPlugin.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ id = "http-compiler-plugin"
class = "io.ballerina.stdlib.http.compiler.HttpCompilerPlugin"

[[dependency]]
path = "../compiler-plugin/build/libs/http-compiler-plugin-2.10.7.jar"
path = "../compiler-plugin/build/libs/http-compiler-plugin-2.10.8-SNAPSHOT.jar"
6 changes: 3 additions & 3 deletions ballerina/Dependencies.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ modules = [
[[package]]
org = "ballerina"
name = "cache"
version = "3.7.0"
version = "3.7.1"
dependencies = [
{org = "ballerina", name = "constraint"},
{org = "ballerina", name = "jballerina.java"},
Expand Down Expand Up @@ -76,7 +76,7 @@ modules = [
[[package]]
org = "ballerina"
name = "http"
version = "2.10.7"
version = "2.10.8"
dependencies = [
{org = "ballerina", name = "auth"},
{org = "ballerina", name = "cache"},
Expand Down Expand Up @@ -283,7 +283,7 @@ modules = [
[[package]]
org = "ballerina"
name = "observe"
version = "1.2.0"
version = "1.2.2"
dependencies = [
{org = "ballerina", name = "jballerina.java"}
]
Expand Down
6 changes: 6 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ This file contains all the notable changes done to the Ballerina HTTP package th
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to
[Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

- [Make the `Host` header overridable](https://github.com/ballerina-platform/ballerina-library/issues/6133)

## [2.10.7] - 2024-02-14

### Fixed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,9 @@ private static int getStartTime(BObject timestamp) {
}

private static void setHostHeader(String host, int port, HttpHeaders headers) {
if (headers.contains(HttpHeaderNames.HOST)) {
return;
}
if (port == 80 || port == 443) {
headers.set(HttpHeaderNames.HOST, host);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import org.slf4j.Logger;
Expand Down Expand Up @@ -100,7 +101,7 @@ private FullHttpResponse send(FullHttpRequest httpRequest) {
this.responseHandler.setLatch(latch);
this.responseHandler.setWaitForConnectionClosureLatch(this.waitForConnectionClosureLatch);

httpRequest.headers().set(HttpHeaderNames.HOST, host + ":" + port);
addHostHeader(httpRequest);
this.connectedChannel.writeAndFlush(httpRequest);
try {
latch.await();
Expand All @@ -117,7 +118,7 @@ public List<FullHttpResponse> sendExpectContinueRequest(DefaultHttpRequest httpR
this.responseHandler.setLatch(latch);
this.responseHandler.setWaitForConnectionClosureLatch(this.waitForConnectionClosureLatch);

httpRequest.headers().set(HttpHeaderNames.HOST, host + ":" + port);
addHostHeader(httpRequest);
httpRequest.headers().set(HttpHeaderNames.EXPECT, HttpHeaderValues.CONTINUE);
this.connectedChannel.writeAndFlush(httpRequest);

Expand Down Expand Up @@ -149,7 +150,7 @@ public LinkedList<FullHttpResponse> sendTwoInPipeline(FullHttpRequest httpReques
this.responseHandler.setLatch(latch);
this.responseHandler.setWaitForConnectionClosureLatch(this.waitForConnectionClosureLatch);

httpRequest.headers().set(HttpHeaderNames.HOST, host + ":" + port);
addHostHeader(httpRequest);
this.connectedChannel.writeAndFlush(httpRequest.copy());

this.connectedChannel.writeAndFlush(httpRequest);
Expand All @@ -161,6 +162,12 @@ public LinkedList<FullHttpResponse> sendTwoInPipeline(FullHttpRequest httpReques
return this.responseHandler.getHttpFullResponses();
}

private void addHostHeader(HttpRequest httpRequest) {
if (!httpRequest.headers().contains(HttpHeaderNames.HOST)) {
httpRequest.headers().set(HttpHeaderNames.HOST, host + ":" + port);
}
}

/**
* Send pipelined requests to the server.
*
Expand Down Expand Up @@ -236,7 +243,7 @@ public String sendMultiplePipelinedRequests(String path) {
*/
private FullHttpRequest getFullHttpRequest(String path, String requestId) {
FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path);
request.headers().set(HttpHeaderNames.HOST, host + ":" + port);
addHostHeader(request);
request.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE);
request.headers().set("message-id", requestId);
return request;
Expand Down

0 comments on commit e47a9ab

Please sign in to comment.