Skip to content

Commit

Permalink
ALS-6330: Add unit tests, some minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
ramari16 committed Jun 17, 2024
1 parent 396606b commit a798763
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public interface GenomicProcessor {

VariantMask createMaskForPatientSet(Set<Integer> patientSubset);

Mono<Collection<String>> getVariantList(DistributableQuery distributableQuery);
Mono<Set<String>> getVariantList(DistributableQuery distributableQuery);

List<String> getPatientIds();

Expand All @@ -29,5 +29,6 @@ public interface GenomicProcessor {

List<InfoColumnMeta> getInfoColumnMeta();

// TODO: change the map value type from String[] to Set<String>
Map<String, String[]> getVariantMetadata(Collection<String> variantList);
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public VariantMask createMaskForPatientSet(Set<Integer> patientSubset) {
}

@Override
public Mono<Collection<String>> getVariantList(DistributableQuery distributableQuery) {
public Mono<Set<String>> getVariantList(DistributableQuery distributableQuery) {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,10 @@ private VariantIndex addVariantsForInfoFilter(VariantIndex unionOfInfoFilters, Q
}

@Override
public Mono<Collection<String>> getVariantList(DistributableQuery query) {
public Mono<Set<String>> getVariantList(DistributableQuery query) {
return Mono.fromCallable(() -> runGetVariantList(query)).subscribeOn(Schedulers.boundedElastic());
}
public Collection<String> runGetVariantList(DistributableQuery query) {
public Set<String> runGetVariantList(DistributableQuery query) {
boolean queryContainsVariantInfoFilters = query.getVariantInfoFilters().stream().anyMatch(variantInfoFilter ->
!variantInfoFilter.categoryVariantInfoFilters.isEmpty() || !variantInfoFilter.numericVariantInfoFilters.isEmpty()
);
Expand Down Expand Up @@ -297,7 +297,7 @@ public Collection<String> runGetVariantList(DistributableQuery query) {
return unionOfInfoFiltersVariantSpecs;
}
}
return new ArrayList<>();
return new HashSet<>();
}

private VariantMask getIdSetForVariantSpecCategoryFilter(String[] zygosities, String key, VariantBucketHolder<VariableVariantMasks> bucketCache) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import java.math.BigInteger;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/**
Expand Down Expand Up @@ -61,19 +62,19 @@ public Set<Integer> patientMaskToPatientIdSet(VariantMask patientMask) {

@Override
public VariantMask createMaskForPatientSet(Set<Integer> patientSubset) {
VariantMask result = nodes.parallelStream()
// all nodes have the same patient set --
VariantMask result = nodes.stream().findFirst()
.map(node -> node.createMaskForPatientSet(patientSubset))
.reduce(VariantMask::union)
.orElseGet(VariantMask::emptyInstance);
return result;
}

@Override
public Mono<Collection<String>> getVariantList(DistributableQuery distributableQuery) {
Mono<Collection<String>> result = Flux.just(nodes.toArray(GenomicProcessor[]::new))
public Mono<Set<String>> getVariantList(DistributableQuery distributableQuery) {
Mono<Set<String>> result = Flux.just(nodes.toArray(GenomicProcessor[]::new))
.flatMap(node -> node.getVariantList(distributableQuery))
.reduce((variantList1, variantList2) -> {
List<String> mergedResult = new ArrayList<>(variantList1.size() + variantList2.size());
Set<String> mergedResult = new HashSet<>(variantList1.size() + variantList2.size());
mergedResult.addAll(variantList1);
mergedResult.addAll(variantList2);
return mergedResult;
Expand Down Expand Up @@ -141,13 +142,26 @@ public List<InfoColumnMeta> getInfoColumnMeta() {

@Override
public Map<String, String[]> getVariantMetadata(Collection<String> variantList) {
return nodes.parallelStream()
// this is overly complicated because of the array type.
// todo: update this when we change the method signature from array to set
ConcurrentHashMap<String, Set<String>> result = new ConcurrentHashMap<>();
nodes.stream()
.map(node -> node.getVariantMetadata(variantList))
.reduce((p1, p2) -> {
Map<String, String[]> mapCopy = new HashMap<>(p1);
mapCopy.putAll(p2);
return mapCopy;
}).orElseGet(Map::of);
.forEach(variantMap -> {
variantMap.entrySet().forEach(entry -> {
Set<String> metadata = result.get(entry.getKey());
if (metadata != null) {
metadata.addAll(Set.of(entry.getValue()));
} else {
result.put(entry.getKey(), new HashSet<>(Set.of(entry.getValue())));
}
});
});
return result.entrySet().stream()
.collect(Collectors.toMap(
Map.Entry::getKey,
entry -> entry.getValue().toArray(new String[] {})
));
}

private List<InfoColumnMeta> initInfoColumnsMeta() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,11 @@ public VariantMask createMaskForPatientSet(Set<Integer> patientSubset) {
}

@Override
public Mono<Collection<String>> getVariantList(DistributableQuery distributableQuery) {
Mono<Collection<String>> result = Flux.just(nodes.toArray(GenomicProcessor[]::new))
public Mono<Set<String>> getVariantList(DistributableQuery distributableQuery) {
Mono<Set<String>> result = Flux.just(nodes.toArray(GenomicProcessor[]::new))
.flatMap(node -> node.getVariantList(distributableQuery))
.reduce((variantList1, variantList2) -> {
List<String> mergedResult = new ArrayList<>(variantList1.size() + variantList2.size());
Set<String> mergedResult = new HashSet<>(variantList1.size() + variantList2.size());
mergedResult.addAll(variantList1);
mergedResult.addAll(variantList2);
return mergedResult;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import edu.harvard.hms.dbmi.avillach.hpds.data.genotype.InfoColumnMeta;
import edu.harvard.hms.dbmi.avillach.hpds.data.genotype.VariableVariantMasks;
import edu.harvard.hms.dbmi.avillach.hpds.data.genotype.VariantMask;
import edu.harvard.hms.dbmi.avillach.hpds.data.genotype.VariantMasks;
import edu.harvard.hms.dbmi.avillach.hpds.data.genotype.caching.VariantBucketHolder;
import edu.harvard.hms.dbmi.avillach.hpds.processing.DistributableQuery;
import edu.harvard.hms.dbmi.avillach.hpds.processing.GenomicProcessor;
Expand All @@ -13,14 +12,13 @@
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;

import java.math.BigInteger;
import java.util.*;

public class GenomicProcessorRestClient implements GenomicProcessor {

private final WebClient webClient;

private static final ParameterizedTypeReference<Collection<String>> VARIANT_LIST_TYPE_REFERENCE = new ParameterizedTypeReference<>(){};
private static final ParameterizedTypeReference<Set<String>> VARIANT_SET_TYPE_REFERENCE = new ParameterizedTypeReference<>(){};
private static final ParameterizedTypeReference<List<InfoColumnMeta>> INFO_COLUMNS_META_TYPE_REFERENCE = new ParameterizedTypeReference<>(){};
private static final ParameterizedTypeReference<List<String>> LIST_OF_STRING_TYPE_REFERENCE = new ParameterizedTypeReference<>(){};
private static final ParameterizedTypeReference<Set<String>> SET_OF_STRING_TYPE_REFERENCE = new ParameterizedTypeReference<>(){};
Expand Down Expand Up @@ -55,13 +53,13 @@ public VariantMask createMaskForPatientSet(Set<Integer> patientSubset) {

@SuppressWarnings("unchecked")
@Override
public Mono<Collection<String>> getVariantList(DistributableQuery distributableQuery) {
Mono<Collection<String>> result = webClient.post()
public Mono<Set<String>> getVariantList(DistributableQuery distributableQuery) {
Mono<Set<String>> result = webClient.post()
.uri("/variants")
.contentType(MediaType.APPLICATION_JSON)
.body(Mono.just(distributableQuery), DistributableQuery.class)
.retrieve()
.bodyToMono(VARIANT_LIST_TYPE_REFERENCE);
.bodyToMono(VARIANT_SET_TYPE_REFERENCE);
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,22 @@

import edu.harvard.hms.dbmi.avillach.hpds.data.genotype.VariantMask;
import edu.harvard.hms.dbmi.avillach.hpds.data.genotype.VariantMaskBitmaskImpl;
import edu.harvard.hms.dbmi.avillach.hpds.data.genotype.VariantMaskSparseImpl;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import reactor.core.publisher.Mono;

import java.math.BigInteger;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.junit.jupiter.api.Assertions.*;

import static org.mockito.Mockito.when;
import static org.mockito.Mockito.*;

@ExtendWith(MockitoExtension.class)
class GenomicProcessorParentImplTest {
Expand All @@ -26,6 +30,13 @@ class GenomicProcessorParentImplTest {

private GenomicProcessorParentImpl parentProcessor;

/** Rebuilds the parent processor over the three mocked nodes before every test. */
@BeforeEach
public void setup() {
    parentProcessor = new GenomicProcessorParentImpl(List.of(mockProcessor1, mockProcessor2, mockProcessor3));
}

@Test
public void patientIdInit_patientsMatch_noException() {
when(mockProcessor1.getPatientIds()).thenReturn(List.of("1", "42", "99"));
Expand Down Expand Up @@ -55,12 +66,79 @@ public void getPatientMask_validResponses_returnMerged() {
when(mockProcessor1.getPatientMask(distributableQuery)).thenReturn(Mono.just(new VariantMaskBitmaskImpl(new BigInteger("110110000011", 2))));
when(mockProcessor2.getPatientMask(distributableQuery)).thenReturn(Mono.just(new VariantMaskBitmaskImpl(new BigInteger("110001100011", 2))));
when(mockProcessor3.getPatientMask(distributableQuery)).thenReturn(Mono.just(new VariantMaskBitmaskImpl(new BigInteger("110000000111", 2))));
parentProcessor = new GenomicProcessorParentImpl(List.of(
mockProcessor1, mockProcessor2, mockProcessor3
));

VariantMask patientMask = parentProcessor.getPatientMask(distributableQuery).block();
VariantMask expectedPatientMask = new VariantMaskBitmaskImpl(new BigInteger("110111100111", 2));
assertEquals(expectedPatientMask, patientMask);
}
/** A parent wrapping a single node returns that node's patient mask. */
@Test
public void getPatientMask_oneNode_returnPatients() {
    DistributableQuery query = new DistributableQuery();
    BigInteger onlyNodeBits = new BigInteger("110110000011", 2);
    when(mockProcessor1.getPatientMask(query)).thenReturn(Mono.just(new VariantMaskBitmaskImpl(onlyNodeBits)));
    parentProcessor = new GenomicProcessorParentImpl(List.of(mockProcessor1));

    VariantMask actualMask = parentProcessor.getPatientMask(query).block();

    assertEquals(new VariantMaskBitmaskImpl(onlyNodeBits), actualMask);
}

/** With a single node, the parent returns exactly the mask that node creates. */
@Test
public void createMaskForPatientSet_oneNode_returnFirst() {
    Set<Integer> requestedPatients = Set.of(7, 8, 9);
    BigInteger nodeBits = new BigInteger("110100000011", 2);
    when(mockProcessor1.createMaskForPatientSet(requestedPatients)).thenReturn(new VariantMaskBitmaskImpl(nodeBits));
    parentProcessor = new GenomicProcessorParentImpl(List.of(mockProcessor1));

    VariantMask actualMask = parentProcessor.createMaskForPatientSet(requestedPatients);

    assertEquals(new VariantMaskBitmaskImpl(nodeBits), actualMask);
}

/** With several nodes, only the first node is asked to create the mask. */
@Test
public void createMaskForPatientSet_multipleNodes_returnFirst() {
    Set<Integer> requestedPatients = Set.of(7, 8, 9);
    BigInteger firstNodeBits = new BigInteger("110100000011", 2);
    when(mockProcessor1.createMaskForPatientSet(requestedPatients)).thenReturn(new VariantMaskBitmaskImpl(firstNodeBits));

    VariantMask actualMask = parentProcessor.createMaskForPatientSet(requestedPatients);

    assertEquals(new VariantMaskBitmaskImpl(firstNodeBits), actualMask);
    // All nodes have identical patient sets, so the remaining nodes must never be consulted.
    verify(mockProcessor2, never()).createMaskForPatientSet(any());
    verify(mockProcessor3, never()).createMaskForPatientSet(any());
}

/** Variants reported by more than one node appear once in the merged result. */
@Test
public void getVariantList_overlappingVariants_mergeCorrectly() {
    DistributableQuery query = new DistributableQuery();
    when(mockProcessor1.getVariantList(query)).thenReturn(Mono.just(Set.of("variant1", "variant2")));
    when(mockProcessor2.getVariantList(query)).thenReturn(Mono.just(Set.of("variant2", "variant3")));
    when(mockProcessor3.getVariantList(query)).thenReturn(Mono.just(Set.of("variant3", "variant4")));

    Set<String> mergedVariants = parentProcessor.getVariantList(query).block();

    Set<String> expectedVariants = Set.of("variant1", "variant2", "variant3", "variant4");
    assertEquals(expectedVariants, mergedVariants);
}

/** A parent wrapping a single node returns that node's variants. */
@Test
public void getVariantList_oneNode_returnVariants() {
    DistributableQuery query = new DistributableQuery();
    when(mockProcessor1.getVariantList(query)).thenReturn(Mono.just(Set.of("variant1", "variant2")));
    parentProcessor = new GenomicProcessorParentImpl(List.of(mockProcessor1));

    Set<String> returnedVariants = parentProcessor.getVariantList(query).block();

    Set<String> expectedVariants = Set.of("variant1", "variant2");
    assertEquals(expectedVariants, returnedVariants);
}

/** Metadata for a variant reported by several nodes is unioned; unreported variants are absent. */
@Test
public void getVariantMetadata_overlappingVariants_mergedCorrectly() {
    List<String> requestedVariants = List.of("variant1", "variant2", "variant3");
    when(mockProcessor1.getVariantMetadata(requestedVariants)).thenReturn(Map.of("variant1", new String[]{"metadata1", "metadata2"}));
    when(mockProcessor2.getVariantMetadata(requestedVariants)).thenReturn(Map.of("variant1", new String[]{"metadata1", "metadata3"}));
    when(mockProcessor3.getVariantMetadata(requestedVariants)).thenReturn(Map.of("variant3", new String[]{"metadata31", "metadata32"}));

    Map<String, String[]> mergedMetadata = parentProcessor.getVariantMetadata(requestedVariants);

    Set<String> variant1Metadata = Set.of(mergedMetadata.get("variant1"));
    Set<String> variant3Metadata = Set.of(mergedMetadata.get("variant3"));
    assertEquals(Set.of("metadata1", "metadata2", "metadata3"), variant1Metadata);
    assertEquals(Set.of("metadata31", "metadata32"), variant3Metadata);
    assertEquals(2, mergedMetadata.size());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import edu.harvard.hms.dbmi.avillach.hpds.data.genotype.VariantMask;
import edu.harvard.hms.dbmi.avillach.hpds.data.genotype.VariantMaskBitmaskImpl;
import edu.harvard.hms.dbmi.avillach.hpds.data.genotype.VariantMaskSparseImpl;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand All @@ -13,6 +14,7 @@

import java.math.BigInteger;
import java.util.List;
import java.util.Set;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -114,4 +116,33 @@ public void patientIdInit_multipleInvalidPatients_warnMessage(CapturedOutput out

assertTrue(output.getOut().contains("3 duplicate patients found in patient partitions"));
}


/** Masks from three patient partitions (two bitmask, one sparse) combine into the expected mask. */
@Test
public void createMaskForPatientSet_validResponses_returnMerged() {
    Set<Integer> requestedPatients = Set.of(2, 3, 8, 9, 15);

    when(mockProcessor1.createMaskForPatientSet(requestedPatients)).thenReturn(new VariantMaskBitmaskImpl(new BigInteger("11011011", 2)));
    when(mockProcessor1.getPatientIds()).thenReturn(List.of("1", "2", "3", "4"));

    when(mockProcessor2.createMaskForPatientSet(requestedPatients)).thenReturn(new VariantMaskSparseImpl(Set.of(3, 4)));
    when(mockProcessor2.getPatientIds()).thenReturn(List.of("5", "6", "7", "8", "9", "10", "11", "12"));

    when(mockProcessor3.createMaskForPatientSet(requestedPatients)).thenReturn(new VariantMaskBitmaskImpl(new BigInteger("11000111", 2)));
    when(mockProcessor3.getPatientIds()).thenReturn(List.of("15", "16", "17", "18"));

    VariantMask actualMask = patientMergingParent.createMaskForPatientSet(requestedPatients);

    BigInteger expectedBits = new BigInteger("11000100011000011011", 2);
    assertEquals(new VariantMaskBitmaskImpl(expectedBits), actualMask);
}

/** An empty mask from the middle partition still yields the expected combined mask. */
@Test
public void createMaskForPatientSet_validResponsesOneEmpty_returnMerged() {
    Set<Integer> requestedPatients = Set.of(2, 3, 15);

    when(mockProcessor1.createMaskForPatientSet(requestedPatients)).thenReturn(new VariantMaskBitmaskImpl(new BigInteger("11011011", 2)));
    when(mockProcessor1.getPatientIds()).thenReturn(List.of("1", "2", "3", "4"));

    when(mockProcessor2.createMaskForPatientSet(requestedPatients)).thenReturn(VariantMask.emptyInstance());
    when(mockProcessor2.getPatientIds()).thenReturn(List.of("5", "6", "7", "8", "9", "10", "11", "12"));

    when(mockProcessor3.createMaskForPatientSet(requestedPatients)).thenReturn(new VariantMaskBitmaskImpl(new BigInteger("11000111", 2)));
    when(mockProcessor3.getPatientIds()).thenReturn(List.of("15", "16", "17", "18"));

    VariantMask actualMask = patientMergingParent.createMaskForPatientSet(requestedPatients);

    BigInteger expectedBits = new BigInteger("11000100000000011011", 2);
    assertEquals(new VariantMaskBitmaskImpl(expectedBits), actualMask);
}
}

0 comments on commit a798763

Please sign in to comment.