Skip to content

Commit

Permalink
ALS-5905: Cleanup patient merging genomic parent, add tests for that.…
Browse files Browse the repository at this point in the history
… Cleanup a bunch of other tests
  • Loading branch information
ramari16 committed Feb 16, 2024
1 parent 94db00b commit 7b4991e
Show file tree
Hide file tree
Showing 11 changed files with 272 additions and 125 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@


import edu.harvard.hms.dbmi.avillach.hpds.etl.genotype.NewVCFLoader;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.test.context.event.annotation.BeforeTestClass;
Expand Down Expand Up @@ -66,7 +67,7 @@ public class BucketIndexBySampleTest {
Set<String> variantSet;
List<Integer> patientSet;

@BeforeTestClass
@BeforeAll
public static void initializeBinfile() throws Exception {
//load variant data
NewVCFLoader.main(new String[] {VCF_INDEX_FILE, STORAGE_DIR, MERGED_DIR});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import edu.harvard.hms.dbmi.avillach.hpds.data.genotype.caching.VariantBucketHolder;
import edu.harvard.hms.dbmi.avillach.hpds.etl.genotype.VariantMetadataLoader;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.springframework.test.context.event.annotation.BeforeTestClass;

Expand Down Expand Up @@ -37,7 +38,7 @@ public class VariantMetadataIndexTest {
private static final String spec5 = "4,9856624,CAAAAA,CA"; private static final String spec5Info = "AC=3033;AF=6.05631e-01;NS=2504;AN=5008;EAS_AF=5.23800e-01;EUR_AF=7.54500e-01;AFR_AF=4.28900e-01;AMR_AF=7.82400e-01;SAS_AF=6.50300e-01;DP=20851;VT=INDEL";


@BeforeTestClass
@BeforeAll
public static void initializeBinfile() throws Exception {
VariantMetadataLoader.main(new String[] {"./src/test/resources/test_vcfIndex.tsv", binFile, "target/VariantMetadataStorage.bin"});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import java.math.BigInteger;
import java.util.*;
Expand All @@ -34,12 +35,16 @@ public Set<String> load(String conceptPath) {
}
});

private List<InfoColumnMeta> infoColumnsMeta;

private List<String> patientIds;
private final List<InfoColumnMeta> infoColumnsMeta;
private final List<String> patientIds;
private final Set<String> infoStoreColumns;

/**
 * Builds a parent processor that fans out to the given partition nodes.
 * Patient ids, info store columns, and info column metadata are loaded
 * eagerly here so the corresponding getters can return cached final fields.
 *
 * @param nodes the genomic partition processors this parent delegates to
 */
public GenomicProcessorParentImpl(List<GenomicProcessor> nodes) {
this.nodes = nodes;

// Eagerly initialize cached, partition-wide state (fields are final).
patientIds = initializePatientIds();
infoStoreColumns = initializeInfoStoreColumns();
infoColumnsMeta = initInfoColumnsMeta();
}

@Override
Expand Down Expand Up @@ -83,14 +88,26 @@ public Mono<Collection<String>> getVariantList(DistributableQuery distributableQ

@Override
public List<String> getPatientIds() {
if (patientIds != null) {
return patientIds;
} else {
// todo: verify all nodes have the same patients
List<String> result = nodes.get(0).getPatientIds();
patientIds = result;
return result;
}
return patientIds;
}

/**
 * Loads the patient id list from every partition node in parallel and verifies
 * that all partitions report the identical list (same size, same order).
 *
 * @return the shared patient id list reported by the partitions
 * @throws IllegalStateException if any two partitions disagree on the patient list
 */
private List<String> initializePatientIds() {
    List<String> ids = Flux.just(nodes.toArray(GenomicProcessor[]::new))
            .flatMap(node -> Mono.fromCallable(node::getPatientIds).subscribeOn(Schedulers.boundedElastic()))
            .reduce((left, right) -> {
                // All partitions must agree: same length and same id at every index.
                boolean sameList = left.size() == right.size();
                for (int i = 0; sameList && i < left.size(); i++) {
                    sameList = left.get(i).equals(right.get(i));
                }
                if (!sameList) {
                    throw new IllegalStateException("Patient lists from partitions do not match");
                }
                return left;
            }).block();

    return ids;
}

@Override
Expand All @@ -107,7 +124,10 @@ public Optional<VariantMasks> getMasks(String path, VariantBucketHolder<VariantM

@Override
public Set<String> getInfoStoreColumns() {
// todo: cache this
return infoStoreColumns;
}

private Set<String> initializeInfoStoreColumns() {
return nodes.parallelStream()
.map(GenomicProcessor::getInfoStoreColumns)
.flatMap(Set::stream)
Expand All @@ -121,10 +141,14 @@ public Set<String> getInfoStoreValues(String conceptPath) {

@Override
public List<InfoColumnMeta> getInfoColumnMeta() {
// todo: initialize on startup?
if (infoColumnsMeta == null) {
infoColumnsMeta = nodes.get(0).getInfoColumnMeta();
}
return infoColumnsMeta;
}

/**
 * Aggregates the info column metadata from all partition nodes.
 *
 * The per-node {@code HashSet} wrap only removes duplicates within a single
 * partition; partitions that report the same metadata would each contribute a
 * copy to the result. {@code distinct()} removes those cross-partition
 * duplicates (relying on the same equals/hashCode the HashSet already uses).
 *
 * @return the combined, de-duplicated info column metadata
 */
private List<InfoColumnMeta> initInfoColumnsMeta() {
    return nodes.parallelStream()
            .map(GenomicProcessor::getInfoColumnMeta)
            .map(HashSet::new)          // dedupe within one partition
            .flatMap(Set::stream)
            .distinct()                 // dedupe across partitions
            .collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import edu.harvard.hms.dbmi.avillach.hpds.data.genotype.InfoColumnMeta;
import edu.harvard.hms.dbmi.avillach.hpds.data.genotype.VariantMasks;
import edu.harvard.hms.dbmi.avillach.hpds.data.genotype.caching.VariantBucketHolder;
Expand Down Expand Up @@ -32,12 +33,17 @@ public Set<String> load(String conceptPath) {
}
});

private List<InfoColumnMeta> infoColumnsMeta;
private final List<InfoColumnMeta> infoColumnsMeta;

private List<String> patientIds;
private final List<String> patientIds;
private final Set<String> infoStoreColumns;

/**
 * Builds a patient-merging parent processor over the given partition nodes.
 * Patient ids, info store columns, and info column metadata are loaded
 * eagerly here so the corresponding getters can return cached final fields.
 *
 * @param nodes the genomic partition processors whose patients are concatenated
 */
public GenomicProcessorPatientMergingParentImpl(List<GenomicProcessor> nodes) {
this.nodes = nodes;

// Eagerly initialize cached, partition-wide state (fields are final).
patientIds = initializePatientIds();
infoStoreColumns = initializeInfoStoreColumns();
infoColumnsMeta = initInfoColumnsMeta();
}

@Override
Expand Down Expand Up @@ -90,20 +96,23 @@ public Mono<Collection<String>> getVariantList(DistributableQuery distributableQ

@Override
public List<String> getPatientIds() {
if (patientIds != null) {
return patientIds;
} else {
// todo: verify all nodes have distinct patients
List<String> result = Flux.just(nodes.toArray(GenomicProcessor[]::new))
.flatMapSequential(node -> Mono.fromCallable(node::getPatientIds).subscribeOn(Schedulers.boundedElastic()))
.reduce((list1, list2) -> {
List<String> concatenatedList = new ArrayList<>(list1);
concatenatedList.addAll(list2);
return concatenatedList;
}).block();
patientIds = result;
return result;
return patientIds;
}

/**
 * Concatenates the patient id lists of all partition nodes, preserving
 * partition order ({@code flatMapSequential}), and returns them as an
 * immutable list.
 *
 * Duplicate ids across partitions are NOT removed — they are only reported
 * with a warning, since partitions are expected to hold distinct patients.
 *
 * @return immutable concatenation of every partition's patient ids
 */
private List<String> initializePatientIds() {
    List<String> result = Flux.just(nodes.toArray(GenomicProcessor[]::new))
            .flatMapSequential(node -> Mono.fromCallable(node::getPatientIds).subscribeOn(Schedulers.boundedElastic()))
            .reduce((list1, list2) -> {
                List<String> concatenatedList = new ArrayList<>(list1);
                concatenatedList.addAll(list2);
                return concatenatedList;
            }).block();
    Set<String> distinctPatientIds = new HashSet<>(result);
    if (distinctPatientIds.size() != result.size()) {
        // SLF4J parameterized logging avoids eager string concatenation
        log.warn("{} duplicate patients found in patient partitions", result.size() - distinctPatientIds.size());
    }
    log.info("{} patient ids loaded from patient partitions", distinctPatientIds.size());
    return ImmutableList.copyOf(result);
}

@Override
Expand All @@ -114,7 +123,10 @@ public Optional<VariantMasks> getMasks(String path, VariantBucketHolder<VariantM

@Override
public Set<String> getInfoStoreColumns() {
// todo: cache this
return infoStoreColumns;
}

private Set<String> initializeInfoStoreColumns() {
return nodes.parallelStream()
.map(GenomicProcessor::getInfoStoreColumns)
.flatMap(Set::stream)
Expand All @@ -128,10 +140,14 @@ public Set<String> getInfoStoreValues(String conceptPath) {

@Override
public List<InfoColumnMeta> getInfoColumnMeta() {
// todo: initialize on startup?
if (infoColumnsMeta == null) {
infoColumnsMeta = nodes.get(0).getInfoColumnMeta();
}
return infoColumnsMeta;
}

/**
 * Aggregates the info column metadata from all partition nodes.
 *
 * The per-node {@code HashSet} wrap only removes duplicates within a single
 * partition; partitions that report the same metadata would each contribute a
 * copy to the result. {@code distinct()} removes those cross-partition
 * duplicates (relying on the same equals/hashCode the HashSet already uses).
 *
 * @return the combined, de-duplicated info column metadata
 */
private List<InfoColumnMeta> initInfoColumnsMeta() {
    return nodes.parallelStream()
            .map(GenomicProcessor::getInfoColumnMeta)
            .map(HashSet::new)          // dedupe within one partition
            .flatMap(Set::stream)
            .distinct()                 // dedupe across partitions
            .collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,56 +3,37 @@

import edu.harvard.hms.dbmi.avillach.hpds.data.genotype.FileBackedByteIndexedInfoStore;
import edu.harvard.hms.dbmi.avillach.hpds.data.query.Query;
import edu.harvard.hms.dbmi.avillach.hpds.storage.FileBackedByteIndexedStorage;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import reactor.core.publisher.Mono;

import java.math.BigInteger;
import java.util.*;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.mockito.Mockito.*;

@ExtendWith(MockitoExtension.class)
public class AbstractProcessorTest {

private AbstractProcessor abstractProcessor;

private Map<String, FileBackedByteIndexedInfoStore> infoStores;

@Mock
private VariantService variantService;
private Map<String, FileBackedByteIndexedInfoStore> infoStores;

@Mock
private GenomicProcessor genomicProcessor;

public static final String GENE_WITH_VARIANT_KEY = "Gene_with_variant";
private static final String VARIANT_SEVERITY_KEY = "Variant_severity";
public static final List<String> EXAMPLE_GENES_WITH_VARIANT = List.of("CDH8", "CDH9", "CDH10");
public static final List<String> EXAMPLE_VARIANT_SEVERITIES = List.of("HIGH", "MODERATE", "LOW");


@BeforeEach
public void setup() {
FileBackedByteIndexedInfoStore mockInfoStore = mock(FileBackedByteIndexedInfoStore.class);
FileBackedByteIndexedStorage<String, Integer[]> mockIndexedStorage = mock(FileBackedByteIndexedStorage.class);
when(mockIndexedStorage.keys()).thenReturn(new HashSet<>(EXAMPLE_GENES_WITH_VARIANT));
when(mockInfoStore.getAllValues()).thenReturn(mockIndexedStorage);

FileBackedByteIndexedInfoStore mockInfoStore2 = mock(FileBackedByteIndexedInfoStore.class);
FileBackedByteIndexedStorage<String, Integer[]> mockIndexedStorage2 = mock(FileBackedByteIndexedStorage.class);
when(mockIndexedStorage2.keys()).thenReturn(new HashSet<>(EXAMPLE_VARIANT_SEVERITIES));
when(mockInfoStore2.getAllValues()).thenReturn(mockIndexedStorage2);

infoStores = Map.of(
GENE_WITH_VARIANT_KEY, mockInfoStore,
VARIANT_SEVERITY_KEY, mockInfoStore2
);

abstractProcessor = new AbstractProcessor(
new PhenotypeMetaStore(
new TreeMap<>(),
Expand All @@ -77,63 +58,15 @@ public void getPatientSubsetForQuery_oneVariantCategoryFilter_indexFound() {
Query.VariantInfoFilter variantInfoFilter = new Query.VariantInfoFilter();
variantInfoFilter.categoryVariantInfoFilters = categoryVariantInfoFilters;

List<Query.VariantInfoFilter> variantInfoFilters = List.of(variantInfoFilter);

Query query = new Query();
query.setVariantInfoFilters(variantInfoFilters);

Set<Integer> patientSubsetForQuery = abstractProcessor.getPatientSubsetForQuery(query);
assertFalse(patientSubsetForQuery.isEmpty());
assertEquals(argumentCaptor.getValue(), new SparseVariantIndex(Set.of(2,4,6)));
}

// Verifies that two values for the same variant category filter are unioned.
// NOTE(review): the when(...) stubs below are commented out, so
// argumentCaptor.capture() is never registered with any mock and
// argumentCaptor.getValue() cannot return a value — confirm this test's
// mocking setup before relying on it.
@Test
public void getPatientSubsetForQuery_oneVariantCategoryFilterTwoValues_unionFilters() {
//when(variantIndexCache.get(GENE_WITH_VARIANT_KEY, EXAMPLE_GENES_WITH_VARIANT.get(0))).thenReturn(new SparseVariantIndex(Set.of(2, 4)));
//when(variantIndexCache.get(GENE_WITH_VARIANT_KEY, EXAMPLE_GENES_WITH_VARIANT.get(1))).thenReturn(new SparseVariantIndex(Set.of(6)));

ArgumentCaptor<VariantIndex> argumentCaptor = ArgumentCaptor.forClass(VariantIndex.class);
//when(patientVariantJoinHandler.getPatientIdsForIntersectionOfVariantSets(any(), argumentCaptor.capture())).thenReturn(List.of(Set.of(42)));

// One filter, two gene values for the same category key.
Map<String, String[]> categoryVariantInfoFilters =
Map.of(GENE_WITH_VARIANT_KEY, new String[] {EXAMPLE_GENES_WITH_VARIANT.get(0), EXAMPLE_GENES_WITH_VARIANT.get(1)});
Query.VariantInfoFilter variantInfoFilter = new Query.VariantInfoFilter();
variantInfoFilter.categoryVariantInfoFilters = categoryVariantInfoFilters;

List<Query.VariantInfoFilter> variantInfoFilters = List.of(variantInfoFilter);

Query query = new Query();
query.setVariantInfoFilters(variantInfoFilters);

Set<Integer> patientSubsetForQuery = abstractProcessor.getPatientSubsetForQuery(query);
assertFalse(patientSubsetForQuery.isEmpty());
// Expected result is the union of the two values
assertEquals(argumentCaptor.getValue(), new SparseVariantIndex(Set.of(2,4,6)));
}

@Test
public void getPatientSubsetForQuery_twoVariantCategoryFilters_intersectFilters() {
//when(variantIndexCache.get(GENE_WITH_VARIANT_KEY, EXAMPLE_GENES_WITH_VARIANT.get(0))).thenReturn(new SparseVariantIndex(Set.of(2, 4, 6)));
//when(variantIndexCache.get(VARIANT_SEVERITY_KEY, EXAMPLE_VARIANT_SEVERITIES.get(0))).thenReturn(new SparseVariantIndex(Set.of(4, 5, 6, 7)));

ArgumentCaptor<VariantIndex> argumentCaptor = ArgumentCaptor.forClass(VariantIndex.class);
//when(patientVariantJoinHandler.getPatientIdsForIntersectionOfVariantSets(any(), argumentCaptor.capture())).thenReturn(List.of(Set.of(42)));

Map<String, String[]> categoryVariantInfoFilters = Map.of(
GENE_WITH_VARIANT_KEY, new String[] {EXAMPLE_GENES_WITH_VARIANT.get(0)},
VARIANT_SEVERITY_KEY, new String[] {EXAMPLE_VARIANT_SEVERITIES.get(0)}
);
Query.VariantInfoFilter variantInfoFilter = new Query.VariantInfoFilter();
variantInfoFilter.categoryVariantInfoFilters = categoryVariantInfoFilters;
when(genomicProcessor.getPatientMask(isA(DistributableQuery.class))).thenReturn(Mono.just(new BigInteger("1100110011")));
when(genomicProcessor.patientMaskToPatientIdSet(eq(new BigInteger("1100110011")))).thenReturn(Set.of(42, 99));

List<Query.VariantInfoFilter> variantInfoFilters = List.of(variantInfoFilter);

Query query = new Query();
query.setVariantInfoFilters(variantInfoFilters);

Set<Integer> patientSubsetForQuery = abstractProcessor.getPatientSubsetForQuery(query);
assertFalse(patientSubsetForQuery.isEmpty());
// Expected result is the intersection of the two filters
assertEquals(argumentCaptor.getValue(), new SparseVariantIndex(Set.of(4, 6)));
assertEquals(Set.of(42, 99), patientSubsetForQuery);
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package edu.harvard.hms.dbmi.avillach.hpds.processing;

import edu.harvard.hms.dbmi.avillach.hpds.data.query.Query;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.test.context.event.annotation.BeforeTestClass;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -18,14 +18,13 @@
@ExtendWith(MockitoExtension.class)
public class CountProcessorTest {

private CountProcessor countProcessor;
private final CountProcessor countProcessor;

@Mock
private AbstractProcessor mockAbstractProcessor;
private final AbstractProcessor mockAbstractProcessor;

@BeforeTestClass
public void before() {
countProcessor = new CountProcessor(mockAbstractProcessor);
public CountProcessorTest(@Mock AbstractProcessor mockAbstractProcessor) {
this.mockAbstractProcessor = mockAbstractProcessor;
this.countProcessor = new CountProcessor(mockAbstractProcessor);
}

@Test
Expand Down
Loading

0 comments on commit 7b4991e

Please sign in to comment.