From c82a93951035dca2c1830f16e1e910d01f9ffa6d Mon Sep 17 00:00:00 2001 From: Wojciech Lukowicz Date: Tue, 5 Nov 2024 14:19:32 +0000 Subject: [PATCH] fix fragmented message length calculation for padding (fixes #527) The code to calculate required length was using the wrong HEADER_LENGTH constant and undercounting it. In unlucky cases it might have looked like a message would fit in the remaining space at the end of the term, while actually it wouldn't. So instead of writing padding and then the fragmented message, the first fragment would get written, then padding, and then the remaining fragments in the next term. On the subscription side, a fragment assembler would ignore such messages. Fixes: 75b9fdea1 ("fix term padding at the end of term buffers") --- .../artio/protocol/GatewayPublication.java | 10 +- .../protocol/GatewayPublicationTest.java | 199 ++++++++++++++++++ 2 files changed, 201 insertions(+), 8 deletions(-) create mode 100644 artio-core/src/test/java/uk/co/real_logic/artio/protocol/GatewayPublicationTest.java diff --git a/artio-core/src/main/java/uk/co/real_logic/artio/protocol/GatewayPublication.java b/artio-core/src/main/java/uk/co/real_logic/artio/protocol/GatewayPublication.java index 79d3393619..f2e868f165 100644 --- a/artio-core/src/main/java/uk/co/real_logic/artio/protocol/GatewayPublication.java +++ b/artio-core/src/main/java/uk/co/real_logic/artio/protocol/GatewayPublication.java @@ -38,12 +38,11 @@ import java.util.List; -import static io.aeron.logbuffer.FrameDescriptor.FRAME_ALIGNMENT; +import static io.aeron.logbuffer.LogBufferDescriptor.computeFragmentedFrameLength; import static io.aeron.protocol.DataHeaderFlyweight.BEGIN_FLAG; import static io.aeron.protocol.DataHeaderFlyweight.END_FLAG; import static java.nio.ByteOrder.LITTLE_ENDIAN; import static java.nio.charset.StandardCharsets.UTF_8; -import static org.agrona.BitUtil.align; import static uk.co.real_logic.artio.DebugLogger.*; import static uk.co.real_logic.artio.LogTag.*; import static uk.co.real_logic.artio.messages.ErrorDecoder.messageHeaderLength; @@ -311,12 +310,7 @@ public long saveMessage( if (fragmented) { // Add a padding message at the end of the term buffer if needed. - final int length = framedLength; - final int numMaxPayloads = length / maxPayloadLength; - final int remainingPayload = length % maxPayloadLength; - final int lastFrameLength = remainingPayload > 0 ? - align(remainingPayload + HEADER_LENGTH, FRAME_ALIGNMENT) : 0; - final int requiredLength = (numMaxPayloads * (maxPayloadLength + HEADER_LENGTH)) + lastFrameLength; + final int requiredLength = computeFragmentedFrameLength(framedLength, maxPayloadLength); final int termLength = dataPublication.termBufferLength(); final int termOffset = dataPublication.termOffset(); final int resultingOffset = termOffset + requiredLength; diff --git a/artio-core/src/test/java/uk/co/real_logic/artio/protocol/GatewayPublicationTest.java b/artio-core/src/test/java/uk/co/real_logic/artio/protocol/GatewayPublicationTest.java new file mode 100644 index 0000000000..1082e3cd49 --- /dev/null +++ b/artio-core/src/test/java/uk/co/real_logic/artio/protocol/GatewayPublicationTest.java @@ -0,0 +1,199 @@ +package uk.co.real_logic.artio.protocol; + +import io.aeron.*; +import io.aeron.driver.MediaDriver; +import io.aeron.logbuffer.ControlledFragmentHandler; +import io.aeron.logbuffer.FragmentHandler; +import io.aeron.logbuffer.Header; +import io.aeron.protocol.DataHeaderFlyweight; +import org.agrona.DirectBuffer; +import org.agrona.collections.MutableLong; +import org.agrona.concurrent.NoOpIdleStrategy; +import org.agrona.concurrent.SystemEpochNanoClock; +import org.agrona.concurrent.UnsafeBuffer; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import uk.co.real_logic.artio.messages.DisconnectReason; +import uk.co.real_logic.artio.messages.MessageStatus; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.IntStream; + +import static io.aeron.logbuffer.FrameDescriptor.FRAME_ALIGNMENT; +import static org.junit.jupiter.api.Assertions.*; +import static uk.co.real_logic.artio.TestFixtures.mediaDriverContext; +import static uk.co.real_logic.artio.protocol.GatewayPublication.FRAMED_MESSAGE_SIZE; + +class GatewayPublicationTest +{ + private static final int MAX_UNFRAGMENTED_BODY_LENGTH = 1305; + + static IntStream bodyLengthRange() + { + return IntStream.rangeClosed( + MAX_UNFRAGMENTED_BODY_LENGTH, + MAX_UNFRAGMENTED_BODY_LENGTH + DataHeaderFlyweight.HEADER_LENGTH + ); + } + + @ParameterizedTest + @MethodSource("bodyLengthRange") + void testSavingMessagesOverTermBoundary(final int bodyLength) + { + final int termBufferLength = 64 * 1024; + try ( + MediaDriver driver = MediaDriver.launch(mediaDriverContext(termBufferLength, true)); + Aeron aeron = Aeron.connect(new Aeron.Context().aeronDirectoryName(driver.aeronDirectoryName()))) + { + final String channel = CommonContext.IPC_CHANNEL; + final int streamId = 1000; + + final Subscription subscription = aeron.addSubscription(channel, streamId); + final ExclusivePublication publication = aeron.addExclusivePublication(channel, streamId); + final Counter fails = aeron.addCounter(1001, "fails"); + + // first one won't be fragmented, but subsequent ones will + assertEquals(MAX_UNFRAGMENTED_BODY_LENGTH, publication.maxPayloadLength() - FRAMED_MESSAGE_SIZE); + + final GatewayPublication gatewayPublication = new GatewayPublication( + publication, + fails, + NoOpIdleStrategy.INSTANCE, + new SystemEpochNanoClock(), + 5 + ); + + // leave enough space in the term for just over one full frame + final int startingPosition = (termBufferLength - driver.context().ipcMtuLength() - 1) & + -DataHeaderFlyweight.HEADER_LENGTH; // align down to min frame length + advanceToPosition(startingPosition, publication, subscription); + + final byte[] body = new byte[bodyLength]; + ThreadLocalRandom.current().nextBytes(body); + final DirectBuffer srcBuffer = new UnsafeBuffer(body); + + int attempt = 1; + do + { + final long result = gatewayPublication.saveMessage( + srcBuffer, + 0, + body.length, + 5000, + 68, + 1, + 0, + 1234, + MessageStatus.OK, + 42 + ); + if (result > 0) + { + break; + } + if (attempt >= 2) + { + fail("failed to save message: " + result); + } + attempt++; + } + while (true); + + final MessageCapturingProtocolHandler protocolHandler = new MessageCapturingProtocolHandler(); + final ProtocolSubscription protocolSubscription = ProtocolSubscription.of(protocolHandler); + final ControlledFragmentHandler fragmentHandler = new ControlledFragmentAssembler(protocolSubscription); + + subscription.controlledPoll(fragmentHandler, 5); + subscription.controlledPoll(fragmentHandler, 5); + + final CapturedMessage capturedMessage = protocolHandler.capturedMessages.get(0); + assertArrayEquals(body, capturedMessage.body()); + assertEquals(68, capturedMessage.messageType()); + assertEquals(42, capturedMessage.sequenceNumber()); + } + } + + private void advanceToPosition( + final long position, + final ExclusivePublication publication, + final Subscription subscription) + { + if (position % FRAME_ALIGNMENT != 0) + { + fail("position is not frame aligned: " + position); + } + + long lastPubPos = 0; + final MutableLong lastSubPos = new MutableLong(); + final FragmentHandler fragmentHandler = (buffer1, offset, length, header) -> lastSubPos.set(header.position()); + final DirectBuffer buffer = new UnsafeBuffer(); + + while (lastPubPos < position || lastSubPos.get() < position) + { + if (lastPubPos < position) + { + final long result = publication.offer(buffer); + if (result > 0) + { + lastPubPos = result; + } + } + + subscription.poll(fragmentHandler, 5); + } + } + + private record CapturedMessage(byte[] body, long messageType, int sequenceNumber) + { + } + + private static final class MessageCapturingProtocolHandler implements ProtocolHandler + { + private final List capturedMessages = new ArrayList<>(); + + public ControlledFragmentHandler.Action onMessage( + final DirectBuffer buffer, + final int offset, + final int length, + final int libraryId, + final long connectionId, + final long sessionId, + final int sequenceIndex, + final long messageType, + final long timestamp, + final MessageStatus status, + final int sequenceNumber, + final Header header, + final int metaDataLength) + { + final byte[] body = new byte[length]; + buffer.getBytes(offset, body); + + capturedMessages.add(new CapturedMessage( + body, + messageType, + sequenceNumber + )); + + return ControlledFragmentHandler.Action.CONTINUE; + } + + public ControlledFragmentHandler.Action onDisconnect( + final int libraryId, + final long connectionId, + final DisconnectReason reason) + { + throw new IllegalStateException(); + } + + public ControlledFragmentHandler.Action onFixPMessage( + final long connectionId, + final DirectBuffer buffer, + final int offset) + { + throw new IllegalStateException(); + } + } +}