Skip to content

Commit

Permalink
fix fragmented message length calculation for padding (fixes #527)
Browse files Browse the repository at this point in the history
The code to calculate required length was using the wrong HEADER_LENGTH
constant and undercounting it. In unlucky cases it might have looked like
a message would fit in the remaining space at the end of the term, while
actually it wouldn't. So instead of writing padding and then the fragmented
message, the first fragment would get written, then padding, and then the
remaining fragments in the next term. On the subscription side, a fragment
assembler would ignore such messages.

Fixes: 75b9fde ("fix term padding at the end of term buffers")
  • Loading branch information
wojciech-adaptive committed Nov 5, 2024
1 parent 7c8ee4e commit c82a939
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,11 @@

import java.util.List;

import static io.aeron.logbuffer.FrameDescriptor.FRAME_ALIGNMENT;
import static io.aeron.logbuffer.LogBufferDescriptor.computeFragmentedFrameLength;
import static io.aeron.protocol.DataHeaderFlyweight.BEGIN_FLAG;
import static io.aeron.protocol.DataHeaderFlyweight.END_FLAG;
import static java.nio.ByteOrder.LITTLE_ENDIAN;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.agrona.BitUtil.align;
import static uk.co.real_logic.artio.DebugLogger.*;
import static uk.co.real_logic.artio.LogTag.*;
import static uk.co.real_logic.artio.messages.ErrorDecoder.messageHeaderLength;
Expand Down Expand Up @@ -311,12 +310,7 @@ public long saveMessage(
if (fragmented)
{
// Add a padding message at the end of the term buffer if needed.
final int length = framedLength;
final int numMaxPayloads = length / maxPayloadLength;
final int remainingPayload = length % maxPayloadLength;
final int lastFrameLength = remainingPayload > 0 ?
align(remainingPayload + HEADER_LENGTH, FRAME_ALIGNMENT) : 0;
final int requiredLength = (numMaxPayloads * (maxPayloadLength + HEADER_LENGTH)) + lastFrameLength;
final int requiredLength = computeFragmentedFrameLength(framedLength, maxPayloadLength);
final int termLength = dataPublication.termBufferLength();
final int termOffset = dataPublication.termOffset();
final int resultingOffset = termOffset + requiredLength;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
package uk.co.real_logic.artio.protocol;

import io.aeron.*;
import io.aeron.driver.MediaDriver;
import io.aeron.logbuffer.ControlledFragmentHandler;
import io.aeron.logbuffer.FragmentHandler;
import io.aeron.logbuffer.Header;
import io.aeron.protocol.DataHeaderFlyweight;
import org.agrona.DirectBuffer;
import org.agrona.collections.MutableLong;
import org.agrona.concurrent.NoOpIdleStrategy;
import org.agrona.concurrent.SystemEpochNanoClock;
import org.agrona.concurrent.UnsafeBuffer;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import uk.co.real_logic.artio.messages.DisconnectReason;
import uk.co.real_logic.artio.messages.MessageStatus;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.IntStream;

import static io.aeron.logbuffer.FrameDescriptor.FRAME_ALIGNMENT;
import static org.junit.jupiter.api.Assertions.*;
import static uk.co.real_logic.artio.TestFixtures.mediaDriverContext;
import static uk.co.real_logic.artio.protocol.GatewayPublication.FRAMED_MESSAGE_SIZE;

class GatewayPublicationTest
{
private static final int MAX_UNFRAGMENTED_BODY_LENGTH = 1305;

static IntStream bodyLengthRange()
{
return IntStream.rangeClosed(
MAX_UNFRAGMENTED_BODY_LENGTH,
MAX_UNFRAGMENTED_BODY_LENGTH + DataHeaderFlyweight.HEADER_LENGTH
);
}

@ParameterizedTest
@MethodSource("bodyLengthRange")
void testSavingMessagesOverTermBoundary(final int bodyLength)
{
final int termBufferLength = 64 * 1024;
try (
MediaDriver driver = MediaDriver.launch(mediaDriverContext(termBufferLength, true));
Aeron aeron = Aeron.connect(new Aeron.Context().aeronDirectoryName(driver.aeronDirectoryName())))
{
final String channel = CommonContext.IPC_CHANNEL;
final int streamId = 1000;

final Subscription subscription = aeron.addSubscription(channel, streamId);
final ExclusivePublication publication = aeron.addExclusivePublication(channel, streamId);
final Counter fails = aeron.addCounter(1001, "fails");

// first one won't be fragmented, but subsequent ones will
assertEquals(MAX_UNFRAGMENTED_BODY_LENGTH, publication.maxPayloadLength() - FRAMED_MESSAGE_SIZE);

final GatewayPublication gatewayPublication = new GatewayPublication(
publication,
fails,
NoOpIdleStrategy.INSTANCE,
new SystemEpochNanoClock(),
5
);

// leave enough space in the term for just over one full frame
final int startingPosition = (termBufferLength - driver.context().ipcMtuLength() - 1) &
-DataHeaderFlyweight.HEADER_LENGTH; // align down to min frame length
advanceToPosition(startingPosition, publication, subscription);

final byte[] body = new byte[bodyLength];
ThreadLocalRandom.current().nextBytes(body);
final DirectBuffer srcBuffer = new UnsafeBuffer(body);

int attempt = 1;
do
{
final long result = gatewayPublication.saveMessage(
srcBuffer,
0,
body.length,
5000,
68,
1,
0,
1234,
MessageStatus.OK,
42
);
if (result > 0)
{
break;
}
if (attempt >= 2)
{
fail("failed to save message: " + result);
}
attempt++;
}
while (true);

final MessageCapturingProtocolHandler protocolHandler = new MessageCapturingProtocolHandler();
final ProtocolSubscription protocolSubscription = ProtocolSubscription.of(protocolHandler);
final ControlledFragmentHandler fragmentHandler = new ControlledFragmentAssembler(protocolSubscription);

subscription.controlledPoll(fragmentHandler, 5);
subscription.controlledPoll(fragmentHandler, 5);

final CapturedMessage capturedMessage = protocolHandler.capturedMessages.get(0);
assertArrayEquals(body, capturedMessage.body());
assertEquals(68, capturedMessage.messageType());
assertEquals(42, capturedMessage.sequenceNumber());
}
}

private void advanceToPosition(
final long position,
final ExclusivePublication publication,
final Subscription subscription)
{
if (position % FRAME_ALIGNMENT != 0)
{
fail("position is not frame aligned: " + position);
}

long lastPubPos = 0;
final MutableLong lastSubPos = new MutableLong();
final FragmentHandler fragmentHandler = (buffer1, offset, length, header) -> lastSubPos.set(header.position());
final DirectBuffer buffer = new UnsafeBuffer();

while (lastPubPos < position || lastSubPos.get() < position)
{
if (lastPubPos < position)
{
final long result = publication.offer(buffer);
if (result > 0)
{
lastPubPos = result;
}
}

subscription.poll(fragmentHandler, 5);
}
}

private record CapturedMessage(byte[] body, long messageType, int sequenceNumber)
{
}

private static final class MessageCapturingProtocolHandler implements ProtocolHandler
{
private final List<CapturedMessage> capturedMessages = new ArrayList<>();

public ControlledFragmentHandler.Action onMessage(
final DirectBuffer buffer,
final int offset,
final int length,
final int libraryId,
final long connectionId,
final long sessionId,
final int sequenceIndex,
final long messageType,
final long timestamp,
final MessageStatus status,
final int sequenceNumber,
final Header header,
final int metaDataLength)
{
final byte[] body = new byte[length];
buffer.getBytes(offset, body);

capturedMessages.add(new CapturedMessage(
body,
messageType,
sequenceNumber
));

return ControlledFragmentHandler.Action.CONTINUE;
}

public ControlledFragmentHandler.Action onDisconnect(
final int libraryId,
final long connectionId,
final DisconnectReason reason)
{
throw new IllegalStateException();
}

public ControlledFragmentHandler.Action onFixPMessage(
final long connectionId,
final DirectBuffer buffer,
final int offset)
{
throw new IllegalStateException();
}
}
}

0 comments on commit c82a939

Please sign in to comment.