diff --git a/src/shorah/b2w.py b/src/shorah/b2w.py index 90f76b6..33c3959 100644 --- a/src/shorah/b2w.py +++ b/src/shorah/b2w.py @@ -14,8 +14,18 @@ def _write_to_file(lines, file_name): with open(file_name, "w") as f: f.writelines("%s\n" % l for l in lines) +def _calc_location_maximum_reads(samfile, reference_name, maximum_reads): + budget = dict() + for pileupcolumn in samfile.pileup(reference_name, multiple_iterators=False): + budget[pileupcolumn.reference_pos] = min( + pileupcolumn.nsegments, + maximum_reads-1 # minus 1 because maximum_reads is exclusive + ) + + return budget + def _run_one_window(samfile, region_start, reference_name, window_length, - minimum_overlap, maximum_reads, counter): + minimum_overlap, permitted_reads_per_location, counter): arr = [] arr_read_summary = [] @@ -27,14 +37,18 @@ def _run_one_window(samfile, region_start, reference_name, window_length, ) for idx, read in enumerate(iter): - # For loop limited by maximum_reads - # TODO might not be random - #if read_idx > maximum_reads: - # break first_aligned_pos = read.reference_start last_aligned_post = read.reference_end - 1 #reference_end is exclusive + + if permitted_reads_per_location[first_aligned_pos] == 0: + continue + else: + permitted_reads_per_location[first_aligned_pos] = ( + permitted_reads_per_location[first_aligned_pos] - 1 + ) + # 0- vs 1-based correction start_cut_out = region_start - first_aligned_pos - 1 @@ -74,7 +88,7 @@ def _run_one_window(samfile, region_start, reference_name, window_length, f'>{read.query_name} {first_aligned_pos}\n{cut_out_read}' ) - if read.reference_start >= counter and len(full_read) >= minimum_overlap: # TODO does not bind + if read.reference_start >= counter and len(full_read) >= minimum_overlap: arr_read_summary.append( (read.query_name, read.reference_start + 1, read.reference_end, full_read) ) @@ -103,9 +117,10 @@ def build_windows(alignment_file: str, region: str, minimum_overlap: Minimum number of bases to overlap between reference 
and read to be considered in a window. The rest (i.e. non-overlapping part) will be filled with Ns. - maximum_reads: Upper (inclusive) limit of reads allowed in a window. - Serves to reduce computational load. - minimum_reads: Lower (inclusive) limit of reads allowed in a window. + maximum_reads: Upper (exclusive) limit of reads allowed to start at the + same position in the reference genome. Serves to reduce + computational load. + minimum_reads: Lower (exclusive) limit of reads allowed in a window. Serves to omit windows with low coverage. reference_filename: Path to a FASTA file of the reference sequence. Only necessary if this information is not included in the CRAM file. @@ -125,6 +140,13 @@ def build_windows(alignment_file: str, region: str, counter = 0 tiling = tiling_strategy.get_window_tilings(start, end) print(tiling) + + permitted_reads_per_location = _calc_location_maximum_reads( + samfile, + reference_name, + maximum_reads + ) + for idx, (region_start, window_length) in enumerate(tiling): arr, arr_read_summary, counter = _run_one_window( samfile, @@ -132,7 +154,7 @@ def build_windows(alignment_file: str, region: str, reference_name, window_length, minimum_overlap, - maximum_reads, + dict(permitted_reads_per_location), # copies dict ("pass by value") counter ) region_end = region_start + window_length - 1 @@ -150,16 +172,16 @@ def build_windows(alignment_file: str, region: str, f'{read[0]}\t{tiling[0][0]-1}\t{end_extended_by_a_window}\t{read[1]}\t{read[2]}\t{read[3]}\n' ) - # except last - if len(arr) >= max(minimum_reads, 1) and idx != len(tiling) - 1: + if idx != len(tiling) - 1: # except last _write_to_file(arr, file_name) - line = ( - f'{file_name}\t{reference_name}\t{region_start}\t' - f'{region_end}\t{len(arr)}' - ) - cov_arr.append(line) + if len(arr) > minimum_reads: + line = ( + f'{file_name}\t{reference_name}\t{region_start}\t' + f'{region_end}\t{len(arr)}' + ) + cov_arr.append(line) samfile.close() reads.close() diff --git a/tests/test_b2w.py 
b/tests/test_b2w.py index 16815d1..4dce95e 100644 --- a/tests/test_b2w.py +++ b/tests/test_b2w.py @@ -17,33 +17,35 @@ def _collect_files(base_path): spec_files.extend(['coverage.txt', 'reads.fas']) return spec_files -@pytest.mark.parametrize("spec_dir,alignment_file,reference_file,region,window_length,overlap_factor,win_min_ext", [ - ("data_1", "test_aln.cram", "test_ref.fasta", "HXB2:2469-3713", 201, 3, 0.85), - ("data_1", "test_aln.cram", "test_ref.fasta", "HXB2:2469-3713", 204, 3, 0.85), - ("data_1", "test_aln.cram", "test_ref.fasta", "HXB2:2469-3713", 200, 4, 0.85), - ("data_1", "test_aln.cram", "test_ref.fasta", "HXB2:2469-3713", 201, 3, 0.75), - ("data_1", "test_aln.cram", "test_ref.fasta", "HXB2:2469-3713", 200, 4, 0.65), - ("data_2", "REF_aln.bam", "cohort_consensus.fasta", "HXB2:2508-3676", 201, 3, 0.85), - ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:231-276", 201, 3, 0.85), - ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:843-2770", 201, 3, 0.85), - ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:3345-3397", 201, 3, 0.85), - ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:3972-4883", 201, 3, 0.85), - ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:5764-6739", 201, 3, 0.85), - ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:7236-9724", 201, 3, 0.85), - ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:10277-14792", 201, 3, 0.85), - ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:15372-16075", 201, 3, 0.85), - ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:16617-19805", 201, 3, 0.85), - ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:20373-22769", 201, 3, 0.85), - ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:23323-24655", 201, 3, 0.85), - #("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:25480-29732", 201, 3, 0.85), +# Note: maximum_reads = math.floor(1e5 / window_length) # TODO why divide? 
+@pytest.mark.parametrize("spec_dir,alignment_file,reference_file,region,window_length,overlap_factor,win_min_ext,maximum_reads,minimum_reads", [ + ("data_1", "test_aln.cram", "test_ref.fasta", "HXB2:2469-3713", 201, 3, 0.85, 497, 0), + ("data_1", "test_aln.cram", "test_ref.fasta", "HXB2:2469-3713", 201, 3, 0.85, 497, 20), + ("data_1", "test_aln.cram", "test_ref.fasta", "HXB2:2469-3713", 201, 3, 0.85, 3, 0), + ("data_1", "test_aln.cram", "test_ref.fasta", "HXB2:2469-3713", 204, 3, 0.85, 490, 0), + ("data_1", "test_aln.cram", "test_ref.fasta", "HXB2:2469-3713", 200, 4, 0.85, 500, 0), + ("data_1", "test_aln.cram", "test_ref.fasta", "HXB2:2469-3713", 201, 3, 0.75, 497, 0), + ("data_1", "test_aln.cram", "test_ref.fasta", "HXB2:2469-3713", 200, 4, 0.65, 500, 0), + ("data_2", "REF_aln.bam", "cohort_consensus.fasta", "HXB2:2508-3676", 201, 3, 0.85, 497, 0), + ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:231-276", 201, 3, 0.85, 497, 0), + ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:843-2770", 201, 3, 0.85, 497, 0), + ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:3345-3397", 201, 3, 0.85, 497, 0), + ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:3972-4883", 201, 3, 0.85, 497, 0), + ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:5764-6739", 201, 3, 0.85, 497, 0), + ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:7236-9724", 201, 3, 0.85, 497, 0), + ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:10277-14792", 201, 3, 0.85, 497, 0), + ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:15372-16075", 201, 3, 0.85, 497, 0), + ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:16617-19805", 201, 3, 0.85, 497, 0), + ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:20373-22769", 201, 3, 0.85, 497, 0), + ("data_3", "REF_aln.bam", "NC_045512.2.fasta", "NC_045512.2:23323-24655", 201, 3, 0.85, 497, 0), + #("data_3", "REF_aln.bam", "NC_045512.2.fasta", 
"NC_045512.2:25480-29732", 201, 3, 0.85), # TODO incorrect edge ], indirect=["spec_dir"]) -def test_cmp_raw(spec_dir, alignment_file, reference_file, region, window_length,overlap_factor, win_min_ext): +def test_cmp_raw(spec_dir, alignment_file, reference_file, region, window_length,overlap_factor, win_min_ext, maximum_reads, minimum_reads): assert window_length > 0 and window_length%overlap_factor == 0 minimum_overlap = math.floor(window_length * win_min_ext) - maximum_reads = math.floor(1e5 / window_length) # TODO why divide? incr = window_length//overlap_factor - cmd = f"b2w -w {window_length} -i {incr} -m {minimum_overlap} -x {maximum_reads} -c 0 {alignment_file} {reference_file} {region}" + cmd = f"b2w -w {window_length} -i {incr} -m {minimum_overlap} -x {maximum_reads} -c {minimum_reads} {alignment_file} {reference_file} {region}" print(cmd) original = subprocess.run( cmd, @@ -59,7 +61,7 @@ def test_cmp_raw(spec_dir, alignment_file, reference_file, region, window_length tiling_strategy = strategy, minimum_overlap = minimum_overlap, maximum_reads = maximum_reads, - minimum_reads = 0 + minimum_reads = minimum_reads ) spec_files = _collect_files(os.path.join(p, spec_dir)) @@ -87,9 +89,7 @@ def spec_dir(request): created_files = _collect_files(p) print(spec_files) - for file in spec_files: os.remove(os.path.join(p, request.param, file)) for file in created_files: os.remove(os.path.join(p, file)) -