Skip to content

Commit

Permalink
ALS-4978: Refactor to support variants not found. Still broken tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ramari16 committed Oct 25, 2023
1 parent 9139708 commit 0279cab
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ public String[] getPatientIds() {
return patientIds;
}

public VariantMasks getMasks(String variant, VariantBucketHolder<VariantMasks> bucketCache) throws IOException {
public Optional<VariantMasks> getMasks(String variant, VariantBucketHolder<VariantMasks> bucketCache) throws IOException {
String[] segments = variant.split(",");
if (segments.length < 2) {
log.error("Less than 2 segments found in this variant : " + variant);
Expand All @@ -133,11 +133,16 @@ public VariantMasks getMasks(String variant, VariantBucketHolder<VariantMasks> b
&& chrOffset == bucketCache.lastChunkOffset) {
// TODO : This is a temporary efficiency hack, NOT THREADSAFE!!!
} else {
bucketCache.lastValue = variantMaskStorage.get(contig).get(chrOffset);
// todo: don't bother doing a lookup if this node does not have the chromosome specified
FileBackedJsonIndexStorage<Integer, ConcurrentHashMap<String, VariantMasks>> indexedStorage = variantMaskStorage.get(contig);
if (indexedStorage == null) {
return Optional.empty();
}
bucketCache.lastValue = indexedStorage.get(chrOffset);
bucketCache.lastContig = contig;
bucketCache.lastChunkOffset = chrOffset;
}
return bucketCache.lastValue == null ? null : bucketCache.lastValue.get(variant);
return bucketCache.lastValue == null ? Optional.empty() : Optional.of(bucketCache.lastValue.get(variant));
}

public String[] getHeaders() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ public String[] getPatientIds() {
return genomicProcessor.getPatientIds();
}

public VariantMasks getMasks(String path, VariantBucketHolder<VariantMasks> variantMasksVariantBucketHolder) {
public Optional<VariantMasks> getMasks(String path, VariantBucketHolder<VariantMasks> variantMasksVariantBucketHolder) {
return variantService.getMasks(path, variantMasksVariantBucketHolder);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,22 +90,34 @@ public BigInteger getPatientMaskForVariantInfoFilters(DistributableQuery distrib
intersectionOfInfoFilters = new SparseVariantIndex(Set.of());
}
}

// todo: handle empty getVariantInfoFilters()

// add filteredIdSet for patients who have matching variants, heterozygous or homozygous for now.
BigInteger patientMask = patientVariantJoinHandler.getPatientIdsForIntersectionOfVariantSets(distributableQuery.getPatientIds(), intersectionOfInfoFilters);
BigInteger patientMask = null;
if (intersectionOfInfoFilters != null ){
patientMask = patientVariantJoinHandler.getPatientIdsForIntersectionOfVariantSets(distributableQuery.getPatientIds(), intersectionOfInfoFilters);
}


VariantBucketHolder<VariantMasks> variantMasksVariantBucketHolder = new VariantBucketHolder<>();
if (!distributableQuery.getRequiredFields().isEmpty() ) {
for (String variantSpec : distributableQuery.getRequiredFields()) {
BigInteger patientsForVariantSpec = getIdSetForVariantSpecCategoryFilter(new String[]{"0/1", "1/1"}, variantSpec, null);
patientMask = patientMask.and(patientsForVariantSpec);
BigInteger patientsForVariantSpec = getIdSetForVariantSpecCategoryFilter(new String[]{"0/1", "1/1"}, variantSpec, variantMasksVariantBucketHolder);
if (patientMask == null) {
patientMask = patientsForVariantSpec;
} else {
patientMask = patientMask.and(patientsForVariantSpec);
}
}
}
if (!distributableQuery.getCategoryFilters().isEmpty()) {
for (Map.Entry<String, String[]> categoryFilterEntry : distributableQuery.getCategoryFilters().entrySet()) {
BigInteger patientsForVariantSpec = getIdSetForVariantSpecCategoryFilter(categoryFilterEntry.getValue(), categoryFilterEntry.getKey(), null);
patientMask = patientMask.and(patientsForVariantSpec);
if (patientMask == null) {
patientMask = patientsForVariantSpec;
} else {
patientMask = patientMask.and(patientsForVariantSpec);
}
}
}

Expand Down Expand Up @@ -248,16 +260,17 @@ public Collection<String> processVariantList(Set<Integer> patientSubsetForQuery,
//NC - this is the original variant filtering, which checks the patient mask from each variant against the patient mask from the query
if(variantsInScope.size()<100000) {
ConcurrentSkipListSet<String> variantsWithPatients = new ConcurrentSkipListSet<String>();
variantsInScope.parallelStream().forEach((String variantKey)->{
VariantMasks masks = variantService.getMasks(variantKey, new VariantBucketHolder<VariantMasks>());
if ( masks.heterozygousMask != null && masks.heterozygousMask.and(patientMasks).bitCount()>4) {
variantsWithPatients.add(variantKey);
} else if ( masks.homozygousMask != null && masks.homozygousMask.and(patientMasks).bitCount()>4) {
variantsWithPatients.add(variantKey);
} else if ( masks.heterozygousNoCallMask != null && masks.heterozygousNoCallMask.and(patientMasks).bitCount()>4) {
//so heterozygous no calls we want, homozygous no calls we don't
variantsWithPatients.add(variantKey);
}
variantsInScope.parallelStream().forEach(variantKey -> {
variantService.getMasks(variantKey, new VariantBucketHolder<>()).ifPresent(masks -> {
if ( masks.heterozygousMask != null && masks.heterozygousMask.and(patientMasks).bitCount()>4) {
variantsWithPatients.add(variantKey);
} else if ( masks.homozygousMask != null && masks.homozygousMask.and(patientMasks).bitCount()>4) {
variantsWithPatients.add(variantKey);
} else if ( masks.heterozygousNoCallMask != null && masks.heterozygousNoCallMask.and(patientMasks).bitCount()>4) {
//so heterozygous no calls we want, homozygous no calls we don't
variantsWithPatients.add(variantKey);
}
});
});
return variantsWithPatients;
}else {
Expand Down Expand Up @@ -286,10 +299,9 @@ private ArrayList<BigInteger> getBitmasksForVariantSpecCategoryFilter(String[] z
ArrayList<BigInteger> variantBitmasks = new ArrayList<>();
variantName = variantName.replaceAll(",\\d/\\d$", "");
log.debug("looking up mask for : " + variantName);
VariantMasks masks;
masks = variantService.getMasks(variantName, bucketCache);
Optional<VariantMasks> optionalMasks = variantService.getMasks(variantName, bucketCache);
Arrays.stream(zygosities).forEach((zygosity) -> {
if(masks!=null) {
optionalMasks.ifPresent(masks -> {
if(zygosity.equals(HOMOZYGOUS_REFERENCE)) {
BigInteger homozygousReferenceBitmask = calculateIndiscriminateBitmask(masks);
for(int x = 2;x<homozygousReferenceBitmask.bitLength()-2;x++) {
Expand All @@ -303,7 +315,8 @@ private ArrayList<BigInteger> getBitmasksForVariantSpecCategoryFilter(String[] z
}else if(zygosity.equals("")) {
variantBitmasks.add(calculateIndiscriminateBitmask(masks));
}
} else {
});
if (optionalMasks.isEmpty()) {
variantBitmasks.add(variantService.emptyBitmask());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,32 +66,21 @@ public BigInteger getPatientIdsForIntersectionOfVariantSets(Set<Integer> patient
x < variantBucketPartitions.size() && matchingPatients[0].bitCount() < patientsInScopeSize + 4;
x++) {
List<List<String>> variantBuckets = variantBucketPartitions.get(x);
variantBuckets.parallelStream().forEach((variantBucket)->{
VariantBucketHolder<VariantMasks> bucketCache = new VariantBucketHolder<VariantMasks>();
variantBucket.stream().forEach((variantSpec)->{
VariantMasks masks;
masks = variantService.getMasks(variantSpec, bucketCache);
if(masks != null) {
variantBuckets.parallelStream().forEach(variantBucket -> {
VariantBucketHolder<VariantMasks> bucketCache = new VariantBucketHolder<>();
variantBucket.forEach(variantSpec -> {
variantService.getMasks(variantSpec, bucketCache).ifPresent(masks -> {
BigInteger heteroMask = masks.heterozygousMask == null ? variantService.emptyBitmask() : masks.heterozygousMask;
BigInteger homoMask = masks.homozygousMask == null ? variantService.emptyBitmask() : masks.homozygousMask;
BigInteger orMasks = heteroMask.or(homoMask);
BigInteger andMasks = orMasks.and(patientsInScopeMask);
synchronized(matchingPatients) {
matchingPatients[0] = matchingPatients[0].or(andMasks);
}
}
});
});
});
}
/* Set<Integer> ids = new TreeSet<Integer>();
String bitmaskString = matchingPatients[0].toString(2);
for(int x = 2;x < bitmaskString.length()-2;x++) {
if('1'==bitmaskString.charAt(x)) {
String patientId = variantService.getPatientIds()[x-2].trim();
ids.add(Integer.parseInt(patientId));
}
}
return ids;*/
return matchingPatients[0];
}else {
log.error("No matches found for info filters.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ private void processColumn(List<String> paths, TreeSet<Integer> ids, ResultStore
Integer x) {
String path = paths.get(x-1);
if(VariantUtils.pathIsVariantSpec(path)) {
VariantMasks masks = abstractProcessor.getMasks(path, new VariantBucketHolder<VariantMasks>());
Optional<VariantMasks> masks = abstractProcessor.getMasks(path, new VariantBucketHolder<>());
String[] patientIds = abstractProcessor.getPatientIds();
int idPointer = 0;

Expand All @@ -117,8 +117,7 @@ private void processColumn(List<String> paths, TreeSet<Integer> ids, ResultStore
if(key < id) {
idPointer++;
} else if(key == id){
idPointer = writeVariantResultField(results, x, masks, idPointer, doubleBuffer,
idInSubsetPointer);
idPointer = writeVariantResultField(results, x, masks, idPointer, idInSubsetPointer);
break;
} else {
writeVariantNullResultField(results, x, doubleBuffer, idInSubsetPointer);
Expand Down Expand Up @@ -162,17 +161,17 @@ private void writeVariantNullResultField(ResultStore results, Integer x, ByteBuf
results.writeField(x,idInSubsetPointer, valueBuffer);
}

private int writeVariantResultField(ResultStore results, Integer x, VariantMasks masks, int idPointer,
ByteBuffer doubleBuffer, int idInSubsetPointer) {
byte[] valueBuffer;
if(masks.heterozygousMask != null && masks.heterozygousMask.testBit(idPointer)) {
valueBuffer = "0/1".getBytes();
}else if(masks.homozygousMask != null && masks.homozygousMask.testBit(idPointer)) {
valueBuffer = "1/1".getBytes();
}else {
valueBuffer = "0/0".getBytes();
}
valueBuffer = masks.toString().getBytes();
private int writeVariantResultField(ResultStore results, Integer x, Optional<VariantMasks> variantMasks, int idPointer,
int idInSubsetPointer) {
byte[] valueBuffer = variantMasks.map(masks -> {
if(masks.heterozygousMask != null && masks.heterozygousMask.testBit(idPointer)) {
return "0/1".getBytes();
} else if(masks.homozygousMask != null && masks.homozygousMask.testBit(idPointer)) {
return "1/1".getBytes();
}else {
return "0/0".getBytes();
}
}).orElse("".getBytes());
results.writeField(x,idInSubsetPointer, valueBuffer);
return idPointer;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,8 @@ public String runVcfExcerptQuery(Query query, boolean includePatientData) throws
}
}

VariantMasks masks = abstractProcessor.getMasks(variantSpec, variantMaskBucketHolder);
// todo: deal with empty return
VariantMasks masks = abstractProcessor.getMasks(variantSpec, variantMaskBucketHolder).get();

//make strings of 000100 so we can just check 'char at'
//so heterozygous no calls we want, homozygous no calls we don't
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ public String[] getPatientIds() {
return variantStore.getPatientIds();
}

public VariantMasks getMasks(String variantName, VariantBucketHolder<VariantMasks> bucketCache) {
public Optional<VariantMasks> getMasks(String variantName, VariantBucketHolder<VariantMasks> bucketCache) {
try {
return variantStore.getMasks(variantName, bucketCache);
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import java.math.BigInteger;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -47,9 +48,31 @@ public void getPatientIdsForIntersectionOfVariantSets_allPatientsMatchOneVariant
variantMasks.heterozygousMask = maskForAllPatients;
VariantMasks emptyVariantMasks = new VariantMasks(new String[0]);
emptyVariantMasks.heterozygousMask = maskForNoPatients;
when(variantService.getMasks(eq(VARIANT_INDEX[0]), any())).thenReturn(variantMasks);
when(variantService.getMasks(eq(VARIANT_INDEX[2]), any())).thenReturn(emptyVariantMasks);
when(variantService.getMasks(eq(VARIANT_INDEX[4]), any())).thenReturn(emptyVariantMasks);
when(variantService.getMasks(eq(VARIANT_INDEX[0]), any())).thenReturn(Optional.of(variantMasks));
when(variantService.getMasks(eq(VARIANT_INDEX[2]), any())).thenReturn(Optional.of(emptyVariantMasks));
when(variantService.getMasks(eq(VARIANT_INDEX[4]), any())).thenReturn(Optional.of(emptyVariantMasks));

Set<Integer> patientIdsForIntersectionOfVariantSets = patientMaskToPatientIdSet(patientVariantJoinHandler.getPatientIdsForIntersectionOfVariantSets(Set.of(), intersectionOfInfoFilters));
// this should be all patients, as all patients match one of the variants
assertEquals(PATIENT_IDS_INTEGERS, patientIdsForIntersectionOfVariantSets);
}

@Test
public void getPatientIdsForIntersectionOfVariantSets_allPatientsMatchOneVariantWithNoVariantFound() {
VariantIndex intersectionOfInfoFilters = new SparseVariantIndex(Set.of(0, 2, 4));
when(variantService.getPatientIds()).thenReturn(PATIENT_IDS);
when(variantService.emptyBitmask()).thenReturn(emptyBitmask(PATIENT_IDS));

BigInteger maskForAllPatients = patientVariantJoinHandler.createMaskForPatientSet(PATIENT_IDS_INTEGERS);
BigInteger maskForNoPatients = patientVariantJoinHandler.createMaskForPatientSet(Set.of());

VariantMasks variantMasks = new VariantMasks(new String[0]);
variantMasks.heterozygousMask = maskForAllPatients;
VariantMasks emptyVariantMasks = new VariantMasks(new String[0]);
emptyVariantMasks.heterozygousMask = maskForNoPatients;
when(variantService.getMasks(eq(VARIANT_INDEX[0]), any())).thenReturn(Optional.of(variantMasks));
when(variantService.getMasks(eq(VARIANT_INDEX[2]), any())).thenReturn(Optional.empty());
when(variantService.getMasks(eq(VARIANT_INDEX[4]), any())).thenReturn(Optional.empty());

Set<Integer> patientIdsForIntersectionOfVariantSets = patientMaskToPatientIdSet(patientVariantJoinHandler.getPatientIdsForIntersectionOfVariantSets(Set.of(), intersectionOfInfoFilters));
// this should be all patients, as all patients match one of the variants
Expand All @@ -65,9 +88,9 @@ public void getPatientIdsForIntersectionOfVariantSets_noPatientsMatchVariants()
BigInteger maskForNoPatients = patientVariantJoinHandler.createMaskForPatientSet(Set.of());
VariantMasks emptyVariantMasks = new VariantMasks(new String[0]);
emptyVariantMasks.heterozygousMask = maskForNoPatients;
when(variantService.getMasks(eq(VARIANT_INDEX[0]), any())).thenReturn(emptyVariantMasks);
when(variantService.getMasks(eq(VARIANT_INDEX[2]), any())).thenReturn(emptyVariantMasks);
when(variantService.getMasks(eq(VARIANT_INDEX[4]), any())).thenReturn(emptyVariantMasks);
when(variantService.getMasks(eq(VARIANT_INDEX[0]), any())).thenReturn(Optional.of(emptyVariantMasks));
when(variantService.getMasks(eq(VARIANT_INDEX[2]), any())).thenReturn(Optional.of(emptyVariantMasks));
when(variantService.getMasks(eq(VARIANT_INDEX[4]), any())).thenReturn(Optional.of(emptyVariantMasks));

Set<Integer> patientIdsForIntersectionOfVariantSets = patientMaskToPatientIdSet(patientVariantJoinHandler.getPatientIdsForIntersectionOfVariantSets(Set.of(), intersectionOfInfoFilters));
// this should be empty because all variants masks have no matching patients
Expand All @@ -87,8 +110,8 @@ public void getPatientIdsForIntersectionOfVariantSets_somePatientsMatchVariants(
variantMasks.heterozygousMask = maskForPatients1;
VariantMasks variantMasks2 = new VariantMasks(new String[0]);
variantMasks2.heterozygousMask = maskForPatients2;
when(variantService.getMasks(eq(VARIANT_INDEX[0]), any())).thenReturn(variantMasks);
when(variantService.getMasks(eq(VARIANT_INDEX[2]), any())).thenReturn(variantMasks2);
when(variantService.getMasks(eq(VARIANT_INDEX[0]), any())).thenReturn(Optional.of(variantMasks));
when(variantService.getMasks(eq(VARIANT_INDEX[2]), any())).thenReturn(Optional.of(variantMasks2));

Set<Integer> patientIdsForIntersectionOfVariantSets = patientMaskToPatientIdSet(patientVariantJoinHandler.getPatientIdsForIntersectionOfVariantSets(Set.of(), intersectionOfInfoFilters));
// this should be all patients who match at least one variant
Expand Down Expand Up @@ -116,8 +139,8 @@ public void getPatientIdsForIntersectionOfVariantSets_patientSubsetPassed() {
variantMasks.heterozygousMask = maskForPatients1;
VariantMasks variantMasks2 = new VariantMasks(new String[0]);
variantMasks2.heterozygousMask = maskForPatients2;
when(variantService.getMasks(eq(VARIANT_INDEX[0]), any())).thenReturn(variantMasks);
when(variantService.getMasks(eq(VARIANT_INDEX[2]), any())).thenReturn(variantMasks2);
when(variantService.getMasks(eq(VARIANT_INDEX[0]), any())).thenReturn(Optional.of(variantMasks));
when(variantService.getMasks(eq(VARIANT_INDEX[2]), any())).thenReturn(Optional.of(variantMasks2));

Set<Integer> patientIdsForIntersectionOfVariantSets = patientMaskToPatientIdSet(patientVariantJoinHandler.getPatientIdsForIntersectionOfVariantSets(Set.of(102, 103, 104, 105, 106), intersectionOfInfoFilters));
// this should be the union of patients matching variants (101, 103, 105, 107), intersected with the patient subset parameter (103, 104, 105) which is (103, 105)
Expand Down

0 comments on commit 0279cab

Please sign in to comment.