From 0279cab3b2827434ff7a4b8e1c74c046a36a3fde Mon Sep 17 00:00:00 2001 From: Ryan Amari Date: Wed, 25 Oct 2023 11:33:40 -0400 Subject: [PATCH] ALS-4978: Refactor to support variants not found. Still broken tests --- .../hpds/data/genotype/VariantStore.java | 11 ++-- .../hpds/processing/AbstractProcessor.java | 2 +- .../processing/GenomicProcessorNodeImpl.java | 51 ++++++++++++------- .../processing/PatientVariantJoinHandler.java | 21 ++------ .../hpds/processing/QueryProcessor.java | 27 +++++----- .../hpds/processing/VariantListProcessor.java | 3 +- .../hpds/processing/VariantService.java | 2 +- .../PatientVariantJoinHandlerTest.java | 43 ++++++++++++---- 8 files changed, 95 insertions(+), 65 deletions(-) diff --git a/data/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/data/genotype/VariantStore.java b/data/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/data/genotype/VariantStore.java index 6541c9dc..822ea447 100644 --- a/data/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/data/genotype/VariantStore.java +++ b/data/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/data/genotype/VariantStore.java @@ -115,7 +115,7 @@ public String[] getPatientIds() { return patientIds; } - public VariantMasks getMasks(String variant, VariantBucketHolder bucketCache) throws IOException { + public Optional getMasks(String variant, VariantBucketHolder bucketCache) throws IOException { String[] segments = variant.split(","); if (segments.length < 2) { log.error("Less than 2 segments found in this variant : " + variant); @@ -133,11 +133,16 @@ public VariantMasks getMasks(String variant, VariantBucketHolder b && chrOffset == bucketCache.lastChunkOffset) { // TODO : This is a temporary efficiency hack, NOT THREADSAFE!!! 
} else { - bucketCache.lastValue = variantMaskStorage.get(contig).get(chrOffset); + // todo: don't bother doing a lookup if this node does not have the chromosome specified + FileBackedJsonIndexStorage> indexedStorage = variantMaskStorage.get(contig); + if (indexedStorage == null) { + return Optional.empty(); + } + bucketCache.lastValue = indexedStorage.get(chrOffset); bucketCache.lastContig = contig; bucketCache.lastChunkOffset = chrOffset; } - return bucketCache.lastValue == null ? null : bucketCache.lastValue.get(variant); + return bucketCache.lastValue == null ? Optional.empty() : Optional.ofNullable(bucketCache.lastValue.get(variant)); } public String[] getHeaders() { diff --git a/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/AbstractProcessor.java b/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/AbstractProcessor.java index 168bf477..a21949f1 100644 --- a/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/AbstractProcessor.java +++ b/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/AbstractProcessor.java @@ -427,7 +427,7 @@ public String[] getPatientIds() { return genomicProcessor.getPatientIds(); } - public VariantMasks getMasks(String path, VariantBucketHolder variantMasksVariantBucketHolder) { + public Optional getMasks(String path, VariantBucketHolder variantMasksVariantBucketHolder) { return variantService.getMasks(path, variantMasksVariantBucketHolder); } diff --git a/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/GenomicProcessorNodeImpl.java b/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/GenomicProcessorNodeImpl.java index ab84f167..b3dc7235 100644 --- a/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/GenomicProcessorNodeImpl.java +++ b/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/GenomicProcessorNodeImpl.java @@ -90,22 +90,34 @@ public BigInteger 
getPatientMaskForVariantInfoFilters(DistributableQuery distrib intersectionOfInfoFilters = new SparseVariantIndex(Set.of()); } } - // todo: handle empty getVariantInfoFilters() // add filteredIdSet for patients who have matching variants, heterozygous or homozygous for now. - BigInteger patientMask = patientVariantJoinHandler.getPatientIdsForIntersectionOfVariantSets(distributableQuery.getPatientIds(), intersectionOfInfoFilters); + BigInteger patientMask = null; + if (intersectionOfInfoFilters != null ){ + patientMask = patientVariantJoinHandler.getPatientIdsForIntersectionOfVariantSets(distributableQuery.getPatientIds(), intersectionOfInfoFilters); + } + + VariantBucketHolder variantMasksVariantBucketHolder = new VariantBucketHolder<>(); if (!distributableQuery.getRequiredFields().isEmpty() ) { for (String variantSpec : distributableQuery.getRequiredFields()) { - BigInteger patientsForVariantSpec = getIdSetForVariantSpecCategoryFilter(new String[]{"0/1", "1/1"}, variantSpec, null); - patientMask = patientMask.and(patientsForVariantSpec); + BigInteger patientsForVariantSpec = getIdSetForVariantSpecCategoryFilter(new String[]{"0/1", "1/1"}, variantSpec, variantMasksVariantBucketHolder); + if (patientMask == null) { + patientMask = patientsForVariantSpec; + } else { + patientMask = patientMask.and(patientsForVariantSpec); + } } } if (!distributableQuery.getCategoryFilters().isEmpty()) { for (Map.Entry categoryFilterEntry : distributableQuery.getCategoryFilters().entrySet()) { BigInteger patientsForVariantSpec = getIdSetForVariantSpecCategoryFilter(categoryFilterEntry.getValue(), categoryFilterEntry.getKey(), null); - patientMask = patientMask.and(patientsForVariantSpec); + if (patientMask == null) { + patientMask = patientsForVariantSpec; + } else { + patientMask = patientMask.and(patientsForVariantSpec); + } } } @@ -248,16 +260,17 @@ public Collection processVariantList(Set patientSubsetForQuery, //NC - this is the original variant filtering, which checks the 
patient mask from each variant against the patient mask from the query if(variantsInScope.size()<100000) { ConcurrentSkipListSet variantsWithPatients = new ConcurrentSkipListSet(); - variantsInScope.parallelStream().forEach((String variantKey)->{ - VariantMasks masks = variantService.getMasks(variantKey, new VariantBucketHolder()); - if ( masks.heterozygousMask != null && masks.heterozygousMask.and(patientMasks).bitCount()>4) { - variantsWithPatients.add(variantKey); - } else if ( masks.homozygousMask != null && masks.homozygousMask.and(patientMasks).bitCount()>4) { - variantsWithPatients.add(variantKey); - } else if ( masks.heterozygousNoCallMask != null && masks.heterozygousNoCallMask.and(patientMasks).bitCount()>4) { - //so heterozygous no calls we want, homozygous no calls we don't - variantsWithPatients.add(variantKey); - } + variantsInScope.parallelStream().forEach(variantKey -> { + variantService.getMasks(variantKey, new VariantBucketHolder<>()).ifPresent(masks -> { + if ( masks.heterozygousMask != null && masks.heterozygousMask.and(patientMasks).bitCount()>4) { + variantsWithPatients.add(variantKey); + } else if ( masks.homozygousMask != null && masks.homozygousMask.and(patientMasks).bitCount()>4) { + variantsWithPatients.add(variantKey); + } else if ( masks.heterozygousNoCallMask != null && masks.heterozygousNoCallMask.and(patientMasks).bitCount()>4) { + //so heterozygous no calls we want, homozygous no calls we don't + variantsWithPatients.add(variantKey); + } + }); }); return variantsWithPatients; }else { @@ -286,10 +299,9 @@ private ArrayList getBitmasksForVariantSpecCategoryFilter(String[] z ArrayList variantBitmasks = new ArrayList<>(); variantName = variantName.replaceAll(",\\d/\\d$", ""); log.debug("looking up mask for : " + variantName); - VariantMasks masks; - masks = variantService.getMasks(variantName, bucketCache); + Optional optionalMasks = variantService.getMasks(variantName, bucketCache); Arrays.stream(zygosities).forEach((zygosity) -> { - 
if(masks!=null) { + optionalMasks.ifPresent(masks -> { if(zygosity.equals(HOMOZYGOUS_REFERENCE)) { BigInteger homozygousReferenceBitmask = calculateIndiscriminateBitmask(masks); for(int x = 2;x getBitmasksForVariantSpecCategoryFilter(String[] z }else if(zygosity.equals("")) { variantBitmasks.add(calculateIndiscriminateBitmask(masks)); } - } else { + }); + if (optionalMasks.isEmpty()) { variantBitmasks.add(variantService.emptyBitmask()); } diff --git a/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/PatientVariantJoinHandler.java b/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/PatientVariantJoinHandler.java index 73a7a603..346c94bb 100644 --- a/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/PatientVariantJoinHandler.java +++ b/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/PatientVariantJoinHandler.java @@ -66,12 +66,10 @@ public BigInteger getPatientIdsForIntersectionOfVariantSets(Set patient x < variantBucketPartitions.size() && matchingPatients[0].bitCount() < patientsInScopeSize + 4; x++) { List> variantBuckets = variantBucketPartitions.get(x); - variantBuckets.parallelStream().forEach((variantBucket)->{ - VariantBucketHolder bucketCache = new VariantBucketHolder(); - variantBucket.stream().forEach((variantSpec)->{ - VariantMasks masks; - masks = variantService.getMasks(variantSpec, bucketCache); - if(masks != null) { + variantBuckets.parallelStream().forEach(variantBucket -> { + VariantBucketHolder bucketCache = new VariantBucketHolder<>(); + variantBucket.forEach(variantSpec -> { + variantService.getMasks(variantSpec, bucketCache).ifPresent(masks -> { BigInteger heteroMask = masks.heterozygousMask == null ? variantService.emptyBitmask() : masks.heterozygousMask; BigInteger homoMask = masks.homozygousMask == null ? 
variantService.emptyBitmask() : masks.homozygousMask; BigInteger orMasks = heteroMask.or(homoMask); @@ -79,19 +77,10 @@ public BigInteger getPatientIdsForIntersectionOfVariantSets(Set patient synchronized(matchingPatients) { matchingPatients[0] = matchingPatients[0].or(andMasks); } - } + }); }); }); } -/* Set ids = new TreeSet(); - String bitmaskString = matchingPatients[0].toString(2); - for(int x = 2;x < bitmaskString.length()-2;x++) { - if('1'==bitmaskString.charAt(x)) { - String patientId = variantService.getPatientIds()[x-2].trim(); - ids.add(Integer.parseInt(patientId)); - } - } - return ids;*/ return matchingPatients[0]; }else { log.error("No matches found for info filters."); diff --git a/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/QueryProcessor.java b/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/QueryProcessor.java index 71ffcec8..f2373181 100644 --- a/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/QueryProcessor.java +++ b/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/QueryProcessor.java @@ -105,7 +105,7 @@ private void processColumn(List paths, TreeSet ids, ResultStore Integer x) { String path = paths.get(x-1); if(VariantUtils.pathIsVariantSpec(path)) { - VariantMasks masks = abstractProcessor.getMasks(path, new VariantBucketHolder()); + Optional masks = abstractProcessor.getMasks(path, new VariantBucketHolder<>()); String[] patientIds = abstractProcessor.getPatientIds(); int idPointer = 0; @@ -117,8 +117,7 @@ private void processColumn(List paths, TreeSet ids, ResultStore if(key < id) { idPointer++; } else if(key == id){ - idPointer = writeVariantResultField(results, x, masks, idPointer, doubleBuffer, - idInSubsetPointer); + idPointer = writeVariantResultField(results, x, masks, idPointer, idInSubsetPointer); break; } else { writeVariantNullResultField(results, x, doubleBuffer, idInSubsetPointer); @@ -162,17 +161,17 @@ private void 
writeVariantNullResultField(ResultStore results, Integer x, ByteBuf results.writeField(x,idInSubsetPointer, valueBuffer); } - private int writeVariantResultField(ResultStore results, Integer x, VariantMasks masks, int idPointer, - ByteBuffer doubleBuffer, int idInSubsetPointer) { - byte[] valueBuffer; - if(masks.heterozygousMask != null && masks.heterozygousMask.testBit(idPointer)) { - valueBuffer = "0/1".getBytes(); - }else if(masks.homozygousMask != null && masks.homozygousMask.testBit(idPointer)) { - valueBuffer = "1/1".getBytes(); - }else { - valueBuffer = "0/0".getBytes(); - } - valueBuffer = masks.toString().getBytes(); + private int writeVariantResultField(ResultStore results, Integer x, Optional variantMasks, int idPointer, + int idInSubsetPointer) { + byte[] valueBuffer = variantMasks.map(masks -> { + if(masks.heterozygousMask != null && masks.heterozygousMask.testBit(idPointer)) { + return "0/1".getBytes(); + } else if(masks.homozygousMask != null && masks.homozygousMask.testBit(idPointer)) { + return "1/1".getBytes(); + }else { + return "0/0".getBytes(); + } + }).orElse("".getBytes()); results.writeField(x,idInSubsetPointer, valueBuffer); return idPointer; } diff --git a/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/VariantListProcessor.java b/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/VariantListProcessor.java index debc142b..925a2394 100644 --- a/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/VariantListProcessor.java +++ b/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/VariantListProcessor.java @@ -261,7 +261,8 @@ public String runVcfExcerptQuery(Query query, boolean includePatientData) throws } } - VariantMasks masks = abstractProcessor.getMasks(variantSpec, variantMaskBucketHolder); + // todo: deal with empty return + VariantMasks masks = abstractProcessor.getMasks(variantSpec, variantMaskBucketHolder).get(); //make strings of 000100 so we can just 
check 'char at' //so heterozygous no calls we want, homozygous no calls we don't diff --git a/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/VariantService.java b/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/VariantService.java index bdf3c046..166bf1ee 100644 --- a/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/VariantService.java +++ b/processing/src/main/java/edu/harvard/hms/dbmi/avillach/hpds/processing/VariantService.java @@ -188,7 +188,7 @@ public String[] getPatientIds() { return variantStore.getPatientIds(); } - public VariantMasks getMasks(String variantName, VariantBucketHolder bucketCache) { + public Optional getMasks(String variantName, VariantBucketHolder bucketCache) { try { return variantStore.getMasks(variantName, bucketCache); } catch (IOException e) { diff --git a/processing/src/test/java/edu/harvard/hms/dbmi/avillach/hpds/processing/PatientVariantJoinHandlerTest.java b/processing/src/test/java/edu/harvard/hms/dbmi/avillach/hpds/processing/PatientVariantJoinHandlerTest.java index 266f6af0..de1feb1e 100644 --- a/processing/src/test/java/edu/harvard/hms/dbmi/avillach/hpds/processing/PatientVariantJoinHandlerTest.java +++ b/processing/src/test/java/edu/harvard/hms/dbmi/avillach/hpds/processing/PatientVariantJoinHandlerTest.java @@ -9,6 +9,7 @@ import java.math.BigInteger; import java.util.List; +import java.util.Optional; import java.util.Set; import java.util.TreeSet; import java.util.stream.Collectors; @@ -47,9 +48,31 @@ public void getPatientIdsForIntersectionOfVariantSets_allPatientsMatchOneVariant variantMasks.heterozygousMask = maskForAllPatients; VariantMasks emptyVariantMasks = new VariantMasks(new String[0]); emptyVariantMasks.heterozygousMask = maskForNoPatients; - when(variantService.getMasks(eq(VARIANT_INDEX[0]), any())).thenReturn(variantMasks); - when(variantService.getMasks(eq(VARIANT_INDEX[2]), any())).thenReturn(emptyVariantMasks); - 
when(variantService.getMasks(eq(VARIANT_INDEX[4]), any())).thenReturn(emptyVariantMasks); + when(variantService.getMasks(eq(VARIANT_INDEX[0]), any())).thenReturn(Optional.of(variantMasks)); + when(variantService.getMasks(eq(VARIANT_INDEX[2]), any())).thenReturn(Optional.of(emptyVariantMasks)); + when(variantService.getMasks(eq(VARIANT_INDEX[4]), any())).thenReturn(Optional.of(emptyVariantMasks)); + + Set patientIdsForIntersectionOfVariantSets = patientMaskToPatientIdSet(patientVariantJoinHandler.getPatientIdsForIntersectionOfVariantSets(Set.of(), intersectionOfInfoFilters)); + // this should be all patients, as all patients match one of the variants + assertEquals(PATIENT_IDS_INTEGERS, patientIdsForIntersectionOfVariantSets); + } + + @Test + public void getPatientIdsForIntersectionOfVariantSets_allPatientsMatchOneVariantWithNoVariantFound() { + VariantIndex intersectionOfInfoFilters = new SparseVariantIndex(Set.of(0, 2, 4)); + when(variantService.getPatientIds()).thenReturn(PATIENT_IDS); + when(variantService.emptyBitmask()).thenReturn(emptyBitmask(PATIENT_IDS)); + + BigInteger maskForAllPatients = patientVariantJoinHandler.createMaskForPatientSet(PATIENT_IDS_INTEGERS); + BigInteger maskForNoPatients = patientVariantJoinHandler.createMaskForPatientSet(Set.of()); + + VariantMasks variantMasks = new VariantMasks(new String[0]); + variantMasks.heterozygousMask = maskForAllPatients; + VariantMasks emptyVariantMasks = new VariantMasks(new String[0]); + emptyVariantMasks.heterozygousMask = maskForNoPatients; + when(variantService.getMasks(eq(VARIANT_INDEX[0]), any())).thenReturn(Optional.of(variantMasks)); + when(variantService.getMasks(eq(VARIANT_INDEX[2]), any())).thenReturn(Optional.empty()); + when(variantService.getMasks(eq(VARIANT_INDEX[4]), any())).thenReturn(Optional.empty()); Set patientIdsForIntersectionOfVariantSets = patientMaskToPatientIdSet(patientVariantJoinHandler.getPatientIdsForIntersectionOfVariantSets(Set.of(), intersectionOfInfoFilters)); // this 
should be all patients, as all patients match one of the variants @@ -65,9 +88,9 @@ public void getPatientIdsForIntersectionOfVariantSets_noPatientsMatchVariants() BigInteger maskForNoPatients = patientVariantJoinHandler.createMaskForPatientSet(Set.of()); VariantMasks emptyVariantMasks = new VariantMasks(new String[0]); emptyVariantMasks.heterozygousMask = maskForNoPatients; - when(variantService.getMasks(eq(VARIANT_INDEX[0]), any())).thenReturn(emptyVariantMasks); - when(variantService.getMasks(eq(VARIANT_INDEX[2]), any())).thenReturn(emptyVariantMasks); - when(variantService.getMasks(eq(VARIANT_INDEX[4]), any())).thenReturn(emptyVariantMasks); + when(variantService.getMasks(eq(VARIANT_INDEX[0]), any())).thenReturn(Optional.of(emptyVariantMasks)); + when(variantService.getMasks(eq(VARIANT_INDEX[2]), any())).thenReturn(Optional.of(emptyVariantMasks)); + when(variantService.getMasks(eq(VARIANT_INDEX[4]), any())).thenReturn(Optional.of(emptyVariantMasks)); Set patientIdsForIntersectionOfVariantSets = patientMaskToPatientIdSet(patientVariantJoinHandler.getPatientIdsForIntersectionOfVariantSets(Set.of(), intersectionOfInfoFilters)); // this should be empty because all variants masks have no matching patients @@ -87,8 +110,8 @@ public void getPatientIdsForIntersectionOfVariantSets_somePatientsMatchVariants( variantMasks.heterozygousMask = maskForPatients1; VariantMasks variantMasks2 = new VariantMasks(new String[0]); variantMasks2.heterozygousMask = maskForPatients2; - when(variantService.getMasks(eq(VARIANT_INDEX[0]), any())).thenReturn(variantMasks); - when(variantService.getMasks(eq(VARIANT_INDEX[2]), any())).thenReturn(variantMasks2); + when(variantService.getMasks(eq(VARIANT_INDEX[0]), any())).thenReturn(Optional.of(variantMasks)); + when(variantService.getMasks(eq(VARIANT_INDEX[2]), any())).thenReturn(Optional.of(variantMasks2)); Set patientIdsForIntersectionOfVariantSets = 
patientMaskToPatientIdSet(patientVariantJoinHandler.getPatientIdsForIntersectionOfVariantSets(Set.of(), intersectionOfInfoFilters)); // this should be all patients who match at least one variant @@ -116,8 +139,8 @@ public void getPatientIdsForIntersectionOfVariantSets_patientSubsetPassed() { variantMasks.heterozygousMask = maskForPatients1; VariantMasks variantMasks2 = new VariantMasks(new String[0]); variantMasks2.heterozygousMask = maskForPatients2; - when(variantService.getMasks(eq(VARIANT_INDEX[0]), any())).thenReturn(variantMasks); - when(variantService.getMasks(eq(VARIANT_INDEX[2]), any())).thenReturn(variantMasks2); + when(variantService.getMasks(eq(VARIANT_INDEX[0]), any())).thenReturn(Optional.of(variantMasks)); + when(variantService.getMasks(eq(VARIANT_INDEX[2]), any())).thenReturn(Optional.of(variantMasks2)); Set patientIdsForIntersectionOfVariantSets = patientMaskToPatientIdSet(patientVariantJoinHandler.getPatientIdsForIntersectionOfVariantSets(Set.of(102, 103, 104, 105, 106), intersectionOfInfoFilters)); // this should be the union of patients matching variants (101, 103, 105, 107), intersected with the patient subset parameter (103, 104, 105) which is (103, 105)