Skip to content

Commit

Permalink
Ensure HttpClient sends full request when the send function does no…
Browse files Browse the repository at this point in the history
…t change `NettyOutbound` (#3526)

Fixes #3524
  • Loading branch information
violetagg authored Dec 5, 2024
1 parent aadb840 commit c43f7da
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ public interface RequestSender extends ResponseReceiver<RequestSender> {
*
* @param sender a bifunction given the outgoing request and the sending
* {@link NettyOutbound}, returns a publisher that will terminate the request
* body on complete
* body on complete. Return {@link Mono#empty()} in case of a request without body.
*
* @return a new {@link ResponseReceiver}
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,13 @@ Publisher<Void> requestWithBody(HttpClientOperations ch) {
}

ch.redirectRequestConsumer(consumer);
return handler != null ? handler.apply(ch, ch) : ch.send();
if (handler != null) {
Publisher<Void> publisher = handler.apply(ch, ch);
return ch.equals(publisher) ? ch.send() : publisher;
}
else {
return ch.send();
}
}
catch (Throwable t) {
return Mono.error(t);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,8 @@
*/
package reactor.netty.http;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http2.Http2Connection;
import io.netty.handler.codec.http2.Http2DataFrame;
import io.netty.handler.codec.http2.Http2FrameCodec;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
Expand All @@ -31,7 +25,6 @@
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mockito;
import org.reactivestreams.Publisher;
import reactor.core.Disposable;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
Expand All @@ -40,7 +33,6 @@
import reactor.core.scheduler.Schedulers;
import reactor.netty.BaseHttpTest;
import reactor.netty.ByteBufFlux;
import reactor.netty.ByteBufMono;
import reactor.netty.http.client.HttpClient;
import reactor.netty.http.server.HttpServer;
import reactor.netty.internal.shaded.reactor.pool.PoolAcquireTimeoutException;
Expand All @@ -58,7 +50,6 @@
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.function.Predicate;
import java.util.stream.IntStream;
Expand Down Expand Up @@ -368,64 +359,6 @@ private static void doTestHttpClientDefaultSslProvider(HttpClient client) {
assertThat(channel.get()).isTrue();
}

@Test
void testMonoRequestBodySentAsFullRequest_Flux() {
// sends the message and then last http content
doTestMonoRequestBodySentAsFullRequest(ByteBufFlux.fromString(Mono.just("test")), 2);
}

@Test
void testMonoRequestBodySentAsFullRequest_Mono() {
// sends "full" request
doTestMonoRequestBodySentAsFullRequest(ByteBufMono.fromString(Mono.just("test")), 1);
}

@SuppressWarnings("FutureReturnValueIgnored")
private void doTestMonoRequestBodySentAsFullRequest(Publisher<? extends ByteBuf> body, int expectedMsg) {
Http2SslContextSpec serverCtx = Http2SslContextSpec.forServer(ssc.certificate(), ssc.privateKey());
Http2SslContextSpec clientCtx =
Http2SslContextSpec.forClient()
.configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE));

disposableServer =
createServer()
.protocol(HttpProtocol.H2)
.secure(spec -> spec.sslContext(serverCtx))
.handle((req, res) -> req.receive()
.then(res.send()))
.bindNow(Duration.ofSeconds(30));

AtomicInteger counter = new AtomicInteger();
createClient(disposableServer.port())
.protocol(HttpProtocol.H2)
.secure(spec -> spec.sslContext(clientCtx))
.doOnRequest((req, conn) -> {
ChannelPipeline pipeline = conn.channel().parent().pipeline();
ChannelHandlerContext ctx = pipeline.context(Http2FrameCodec.class);
if (ctx != null) {
pipeline.addAfter(ctx.name(), "testMonoRequestBodySentAsFullRequest",
new ChannelOutboundHandlerAdapter() {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
if (msg instanceof Http2DataFrame) {
counter.getAndIncrement();
}
//"FutureReturnValueIgnored" this is deliberate
ctx.write(msg, promise);
}
});
}
})
.post()
.uri("/")
.send(body)
.responseContent()
.aggregate()
.block(Duration.ofSeconds(30));

assertThat(counter.get()).isEqualTo(expectedMsg);
}

@Test
void testIssue1394_SchemeHttpConfiguredH2CNegotiatedH2C() {
// "prior-knowledge" is used and stream id is 3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,15 @@
import ch.qos.logback.classic.spi.ILoggingEvent;
import ch.qos.logback.core.read.ListAppender;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
Expand Down Expand Up @@ -78,13 +82,15 @@
import java.lang.annotation.Target;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.charset.Charset;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;
Expand Down Expand Up @@ -1008,6 +1014,114 @@ void testProtocolVersion(HttpServer server, HttpClient client) {
.verify(Duration.ofSeconds(5));
}

@ParameterizedCompatibleCombinationsTest
void testMonoRequestBodySentAsFullRequest_Flux(HttpServer server, HttpClient client) {
testRequestBody(server, client, sender -> sender.send(ByteBufFlux.fromString(Mono.just("test"))), 2);
}

@ParameterizedCompatibleCombinationsTest
void testMonoRequestBodySentAsFullRequest_Mono(HttpServer server, HttpClient client) {
// sends "full" request
testRequestBody(server, client, sender -> sender.send(ByteBufMono.fromString(Mono.just("test"))), 1);
}

@ParameterizedCompatibleCombinationsTest
void testMonoRequestBodySentAsFullRequest_MonoEmpty(HttpServer server, HttpClient client) {
// sends "full" request
testRequestBody(server, client, sender -> sender.send(Mono.empty()), 1);
}

@ParameterizedCompatibleCombinationsTest
void testIssue3524Flux(HttpServer server, HttpClient client) {
// sends the message and then last http content
testRequestBody(server, client, sender -> sender.send((req, out) -> out.sendString(Flux.just("te", "st"))), 3);
}

@ParameterizedCompatibleCombinationsTest
void testIssue3524Mono(HttpServer server, HttpClient client) {
// sends "full" request
testRequestBody(server, client, sender -> sender.send((req, out) -> out.sendString(Mono.just("test"))), 1);
}

@ParameterizedCompatibleCombinationsTest
void testIssue3524MonoEmpty(HttpServer server, HttpClient client) {
// sends "full" request
testRequestBody(server, client, sender -> sender.send((req, out) -> Mono.empty()), 1);
}

@ParameterizedCompatibleCombinationsTest
void testIssue3524NoBody(HttpServer server, HttpClient client) {
// sends "full" request
testRequestBody(server, client, sender -> sender.send((req, out) -> out), 1);
}

@ParameterizedCompatibleCombinationsTest
void testIssue3524Object(HttpServer server, HttpClient client) {
// sends "full" request
testRequestBody(server, client,
sender -> sender.send((req, out) -> out.sendObject(Unpooled.wrappedBuffer("test".getBytes(Charset.defaultCharset())))), 1);
}

@SuppressWarnings("FutureReturnValueIgnored")
private void testRequestBody(HttpServer server, HttpClient client,
Function<HttpClient.RequestSender, HttpClient.ResponseReceiver<?>> sendFunction, int expectedMsg) {
disposableServer =
server.handle((req, res) -> req.receive()
.then(res.send()))
.bindNow(Duration.ofSeconds(30));

AtomicInteger counter = new AtomicInteger();
sendFunction.apply(
client.port(disposableServer.port())
.doOnRequest((req, conn) -> {
ChannelPipeline pipeline = conn.channel() instanceof Http2StreamChannel ?
conn.channel().parent().pipeline() : conn.channel().pipeline();
ChannelHandlerContext ctx = pipeline.context(NettyPipeline.HttpCodec);
if (ctx == null) {
ctx = pipeline.context(HttpClientCodec.class);
}
if (ctx != null) {
pipeline.addAfter(ctx.name(), "testRequestBody",
new ChannelOutboundHandlerAdapter() {
boolean done;

@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
if (!done) {
if (msg instanceof Http2HeadersFrame && ((Http2HeadersFrame) msg).isEndStream()) {
done = true;
counter.getAndIncrement();
}
else if (msg instanceof Http2DataFrame) {
if (((Http2DataFrame) msg).isEndStream()) {
done = true;
}
counter.getAndIncrement();
}
else if (msg instanceof LastHttpContent) {
done = true;
counter.getAndIncrement();
}
else if (msg instanceof ByteBuf) {
counter.getAndIncrement();
}
}
//"FutureReturnValueIgnored" this is deliberate
ctx.write(msg, promise);
}
});
}
})
.post()
.uri("/"))
.responseContent()
.aggregate()
.asString()
.block(Duration.ofSeconds(30));

assertThat(counter.get()).isEqualTo(expectedMsg);
}

static final class IdleTimeoutTestChannelInboundHandler extends ChannelInboundHandlerAdapter {

final CountDownLatch latch = new CountDownLatch(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import reactor.core.publisher.Signal;
import reactor.netty.BaseHttpTest;
import reactor.netty.Connection;
import reactor.netty.ConnectionObserver;
import reactor.netty.DisposableServer;
import reactor.netty.LogTracker;
import reactor.netty.http.HttpProtocol;
Expand Down Expand Up @@ -538,9 +539,9 @@ static DisposableServer createServer(EventsRecorder recorder, HttpProtocol proto
ch.pipeline().addBefore(HttpTrafficHandler, "eventsRecorderHandler", new EventsRecorderHandler(recorder));
}
})
.doOnConnection(conn -> {
conn.onTerminate().subscribe(null, null, recorder::recordOnTerminateIsReceived);
if (protocol == HttpProtocol.H2C) {
.doOnConnection(conn -> conn.onTerminate().subscribe(null, null, recorder::recordOnTerminateIsReceived))
.childObserve((conn, state) -> {
if (state == ConnectionObserver.State.CONNECTED && protocol == HttpProtocol.H2C) {
conn.channel().pipeline().addBefore(HttpTrafficHandler, "eventsRecorderHandler", new EventsRecorderHandler(recorder));
}
})
Expand Down

0 comments on commit c43f7da

Please sign in to comment.