Skip to content

Commit

Permalink
Add a tag->tag-set migration command
Browse files Browse the repository at this point in the history
  • Loading branch information
ravi-signal committed Dec 6, 2024
1 parent 236b049 commit 1442752
Show file tree
Hide file tree
Showing 7 changed files with 294 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,30 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.function.Consumer;
import javax.annotation.Nonnull;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.subscriptions.PaymentProvider;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Scheduler;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;

public class IssuedReceiptsManager {
private static final Logger log = LoggerFactory.getLogger(IssuedReceiptsManager.class);

public static final String KEY_PROCESSOR_ITEM_ID = "A"; // S (HashKey)
public static final String KEY_ISSUED_RECEIPT_TAG = "B"; // B
Expand Down Expand Up @@ -134,4 +140,55 @@ private byte[] generateHmac(String type, Consumer<Mac> byteConsumer) {
throw new AssertionError(e);
}
}

public CompletableFuture<Void> migrateToTagSet(final IssuedReceipt issuedReceipt) {
UpdateItemRequest updateItemRequest = UpdateItemRequest.builder()
.tableName(table)
.key(Map.of(KEY_PROCESSOR_ITEM_ID, s(issuedReceipt.itemId())))
.conditionExpression("attribute_exists(#key) AND #tag = :tag")
.returnValues(ReturnValue.NONE)
.updateExpression("ADD #tags :singletonTag")
.expressionAttributeNames(Map.of(
"#key", KEY_PROCESSOR_ITEM_ID,
"#tag", KEY_ISSUED_RECEIPT_TAG,
"#tags", KEY_ISSUED_RECEIPT_TAG_SET))
.expressionAttributeValues(Map.of(
":tag", b(issuedReceipt.tag()),
":singletonTag", AttributeValue.fromBs(Collections.singletonList(SdkBytes.fromByteArray(issuedReceipt.tag())))))
.build();
return dynamoDbAsyncClient.updateItem(updateItemRequest)
.thenRun(Util.NOOP)
.exceptionally(ExceptionUtils.exceptionallyHandler(ConditionalCheckFailedException.class, e -> {
log.info("Not migrating item {}, because when we tried to migrate it was already deleted", issuedReceipt.itemId());
return null;
}));
}

public record IssuedReceipt(String itemId, byte[] tag) {}
public Flux<IssuedReceipt> receiptsWithoutTagSet(final int segments, final Scheduler scheduler) {
if (segments < 1) {
throw new IllegalArgumentException("Total number of segments must be positive");
}

return Flux.range(0, segments)
.parallel()
.runOn(scheduler)
.flatMap(segment -> dynamoDbAsyncClient.scanPaginator(ScanRequest.builder()
.tableName(table)
.consistentRead(true)
.segment(segment)
.totalSegments(segments)
.filterExpression("attribute_not_exists(#tags)")
.expressionAttributeNames(Map.of("#tags", KEY_ISSUED_RECEIPT_TAG_SET))
.build())
.items()
.flatMapIterable(item -> {
if (!item.containsKey(KEY_ISSUED_RECEIPT_TAG)) {
log.error("Skipping item {} that was missing a receipt tag", item.get(KEY_PROCESSOR_ITEM_ID).s());
return Collections.emptySet();
}
return List.of(new IssuedReceipt(item.get(KEY_PROCESSOR_ITEM_ID).s(), item.get(KEY_ISSUED_RECEIPT_TAG).b().asByteArray()));
}))
.sequential();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.whispersystems.textsecuregcm.storage.ClientPublicKeys;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.IssuedReceiptsManager;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagesCache;
import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb;
Expand Down Expand Up @@ -89,6 +90,7 @@ record CommandDependencies(
FaultTolerantRedisClusterClient pushSchedulerCluster,
ClientResources.Builder redisClusterClientResourcesBuilder,
BackupManager backupManager,
IssuedReceiptsManager issuedReceiptsManager,
DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
DynamoDbAsyncClient dynamoDbAsyncClient,
PhoneNumberIdentifiers phoneNumberIdentifiers) {
Expand Down Expand Up @@ -261,6 +263,13 @@ static CommandDependencies build(
remoteStorageRetryExecutor,
configuration.getCdn3StorageManagerConfiguration()),
clock);

final IssuedReceiptsManager issuedReceiptsManager = new IssuedReceiptsManager(
configuration.getDynamoDbTables().getIssuedReceipts().getTableName(),
configuration.getDynamoDbTables().getIssuedReceipts().getExpiration(),
dynamoDbAsyncClient,
configuration.getDynamoDbTables().getIssuedReceipts().getGenerator());

APNSender apnSender = new APNSender(apnSenderExecutor, configuration.getApnConfiguration());
FcmSender fcmSender = new FcmSender(fcmSenderExecutor, configuration.getFcmConfiguration().credentials().value());
PushNotificationScheduler pushNotificationScheduler = new PushNotificationScheduler(pushSchedulerCluster,
Expand Down Expand Up @@ -296,6 +305,7 @@ static CommandDependencies build(
pushSchedulerCluster,
redisClientResourcesBuilder,
backupManager,
issuedReceiptsManager,
dynamicConfigurationManager,
dynamoDbAsyncClient,
phoneNumberIdentifiers
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/

package org.whispersystems.textsecuregcm.workers;

import io.dropwizard.core.Application;
import io.dropwizard.core.setup.Environment;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import net.sourceforge.argparse4j.inf.Namespace;
import net.sourceforge.argparse4j.inf.Subparser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.IssuedReceiptsManager;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import java.time.Clock;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;

public class IssuedReceiptMigrationCommand extends AbstractCommandWithDependencies {

private final Logger logger = LoggerFactory.getLogger(getClass());

private static final String SEGMENT_COUNT_ARGUMENT = "segments";
private static final String DRY_RUN_ARGUMENT = "dry-run";
private static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency";
private static final String BUFFER_ARGUMENT = "buffer";

private static final String INSPECTED_ISSUED_RECEIPTS = MetricsUtil.name(IssuedReceiptMigrationCommand.class,
"inspectedIssuedReceipts");
private static final String MIGRATED_ISSUED_RECEIPTS = MetricsUtil.name(IssuedReceiptMigrationCommand.class,
"migratedIssuedReceipts");

private final Clock clock;

public IssuedReceiptMigrationCommand(final Clock clock) {
super(new Application<>() {
@Override
public void run(final WhisperServerConfiguration configuration, final Environment environment) {
}
}, "migrate-issued-receipts", "Migrates columns in the issued receipts table");
this.clock = clock;
}

@Override
public void configure(final Subparser subparser) {
super.configure(subparser);

subparser.addArgument("--segments")
.type(Integer.class)
.dest(SEGMENT_COUNT_ARGUMENT)
.required(false)
.setDefault(1)
.help("The total number of segments for a DynamoDB scan");

subparser.addArgument("--max-concurrency")
.type(Integer.class)
.dest(MAX_CONCURRENCY_ARGUMENT)
.required(false)
.setDefault(16)
.help("Max concurrency for migration operations");

subparser.addArgument("--dry-run")
.type(Boolean.class)
.dest(DRY_RUN_ARGUMENT)
.required(false)
.setDefault(true)
.help("If true, don’t actually perform migration");

subparser.addArgument("--buffer")
.type(Integer.class)
.dest(BUFFER_ARGUMENT)
.setDefault(16_384)
.help("Records to buffer");
}

@Override
protected void run(final Environment environment, final Namespace namespace,
final WhisperServerConfiguration configuration, final CommandDependencies commandDependencies) throws Exception {
final int bufferSize = namespace.getInt(BUFFER_ARGUMENT);
final int segments = Objects.requireNonNull(namespace.getInt(SEGMENT_COUNT_ARGUMENT));
final int concurrency = Objects.requireNonNull(namespace.getInt(MAX_CONCURRENCY_ARGUMENT));
final boolean dryRun = namespace.getBoolean(DRY_RUN_ARGUMENT);

logger.info("Crawling issuedReceipts with {} segments and {} processors",
segments,
Runtime.getRuntime().availableProcessors());

final Counter inspected = Metrics.counter(INSPECTED_ISSUED_RECEIPTS,
"dryRun", Boolean.toString(dryRun));
final Counter migrated = Metrics.counter(MIGRATED_ISSUED_RECEIPTS,
"dryRun", Boolean.toString(dryRun));

final IssuedReceiptsManager issuedReceiptsManager = commandDependencies.issuedReceiptsManager();
final Flux<IssuedReceiptsManager.IssuedReceipt> receipts =
issuedReceiptsManager.receiptsWithoutTagSet(segments, Schedulers.parallel());
final long count = bufferShuffle(receipts, bufferSize)
.doOnNext(issuedReceipt -> inspected.increment())
.flatMap(issuedReceipt -> Mono
.fromCompletionStage(() -> dryRun
? CompletableFuture.completedFuture(null)
: issuedReceiptsManager.migrateToTagSet(issuedReceipt))
.thenReturn(true)
.retry(3)
.onErrorResume(throwable -> {
logger.error("Failed to migrate {} after 3 attempts, giving up", issuedReceipt.itemId(), throwable);
return Mono.just(false);
}),
concurrency)
.doOnNext(success ->
Metrics.counter(MIGRATED_ISSUED_RECEIPTS,
"dryRun", Boolean.toString(dryRun),
"success", Boolean.toString(success)))
.count()
.block();
logger.info("Attempted to migrate {} issued receipts", count);
}

private static <T> Flux<T> bufferShuffle(Flux<T> f, int bufferSize) {
return f.buffer(bufferSize)
.map(source -> {
final ArrayList<T> shuffled = new ArrayList<>(source);
Collections.shuffle(shuffled);
return shuffled;
})
.limitRate(2)
.flatMapIterable(Function.identity());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.assertj.core.api.Condition;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand All @@ -26,11 +28,13 @@
import org.whispersystems.textsecuregcm.subscriptions.PaymentProvider;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import reactor.core.scheduler.Schedulers;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemResponse;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;

class IssuedReceiptsManagerTest {

Expand Down Expand Up @@ -89,6 +93,74 @@ void testRecordIssuance() {
assertThat(future).succeedsWithin(Duration.ofSeconds(3));
}

@Test
void testMigrateToTagSet() {
Instant now = Instant.ofEpochSecond(NOW_EPOCH_SECONDS);

issuedReceiptsManager
.recordIssuance("itemId", PaymentProvider.STRIPE, randomReceiptCredentialRequest(), now)
.join();
removeTagSet("itemId");

assertThat(getItem("itemId").item()).doesNotContainKey(IssuedReceiptsManager.KEY_ISSUED_RECEIPT_TAG_SET);

final IssuedReceiptsManager.IssuedReceipt issuedReceipt = issuedReceiptsManager
.receiptsWithoutTagSet(1, Schedulers.immediate())
.blockFirst();

issuedReceiptsManager.migrateToTagSet(issuedReceipt).join();

final Map<String, AttributeValue> item = getItem("itemId").item();
assertThat(item)
.containsKey(IssuedReceiptsManager.KEY_ISSUED_RECEIPT_TAG_SET)
.containsKey(IssuedReceiptsManager.KEY_ISSUED_RECEIPT_TAG);

final List<byte[]> tags = item
.get(IssuedReceiptsManager.KEY_ISSUED_RECEIPT_TAG_SET).bs()
.stream()
.map(SdkBytes::asByteArray)
.toList();
assertThat(tags).hasSize(1);

final byte[] tag = item.get(IssuedReceiptsManager.KEY_ISSUED_RECEIPT_TAG).b().asByteArray();
assertThat(tags).first().isEqualTo(tag);
}


@Test
void testReceiptsWithoutTagSet() {
Instant now = Instant.ofEpochSecond(NOW_EPOCH_SECONDS);

final int numItems = 100;
final List<String> expectedNoTagSet = IntStream.range(0, numItems)
.boxed()
.flatMap(i -> {
final String itemId = "item-%s".formatted(i);
issuedReceiptsManager.recordIssuance(itemId, PaymentProvider.STRIPE, randomReceiptCredentialRequest(), now).join();

if (i % 2 == 0) {
removeTagSet(itemId);
return Stream.of(itemId);
} else {
return Stream.empty();
}
}).toList();
final List<String> items = issuedReceiptsManager
.receiptsWithoutTagSet(1, Schedulers.immediate())
.map(IssuedReceiptsManager.IssuedReceipt::itemId)
.collectList().block();
assertThat(items).hasSize(numItems / 2);
assertThat(items).containsExactlyInAnyOrderElementsOf(expectedNoTagSet);
}

@Test
void testMigrateAfterRecordExpires() {
final IssuedReceiptsManager.IssuedReceipt issued = new IssuedReceiptsManager.IssuedReceipt("itemId",
TestRandomUtil.nextBytes(32));
// We should succeed but do nothing if the item is deleted by the time we try to migrate it
issuedReceiptsManager.migrateToTagSet(issued).join();
assertThat(getItem("itemId").hasItem()).isFalse();
}

private GetItemResponse getItem(final String itemId) {
final DynamoDbClient client = DYNAMO_DB_EXTENSION.getDynamoDbClient();
Expand All @@ -104,4 +176,15 @@ private static ReceiptCredentialRequest randomReceiptCredentialRequest() {
when(request.serialize()).thenReturn(bytes);
return request;
}

private void removeTagSet(final String itemId) {
final DynamoDbClient client = DYNAMO_DB_EXTENSION.getDynamoDbClient();
// Simulate an entry that was written before we wrote the tag set field
client.updateItem(UpdateItemRequest.builder()
.tableName(Tables.ISSUED_RECEIPTS.tableName())
.key(Map.of(IssuedReceiptsManager.KEY_PROCESSOR_ITEM_ID, AttributeValues.s(itemId)))
.updateExpression("REMOVE #tags")
.expressionAttributeNames(Map.of("#tags", IssuedReceiptsManager.KEY_ISSUED_RECEIPT_TAG_SET))
.build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ void setUp() {
null,
null,
null,
null,
null);

//noinspection unchecked
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ private TestNotifyIdleDevicesCommand(final MessagesManager messagesManager,
null,
null,
null,
null,
null);

this.idleDeviceNotificationScheduler = idleDeviceNotificationScheduler;
Expand Down
Loading

0 comments on commit 1442752

Please sign in to comment.