diff --git a/bam_filter/filter.py b/bam_filter/filter.py index 7a1e655..54b19c0 100644 --- a/bam_filter/filter.py +++ b/bam_filter/filter.py @@ -166,6 +166,7 @@ def filter_references(args): if args.min_norm_entropy is None or args.min_norm_gini is None: filter_conditions = { "min_read_length": args.min_read_length, + "max_read_length": args.max_read_length, "min_read_count": args.min_read_count, "min_expected_breadth_ratio": args.min_expected_breadth_ratio, "min_breadth": args.min_breadth, @@ -186,6 +187,7 @@ def filter_references(args): ) filter_conditions = { "min_read_length": args.min_read_length, + "max_read_length": args.max_read_length, "min_read_count": args.min_read_count, "min_expected_breadth_ratio": args.min_expected_breadth_ratio, "min_breadth": args.min_breadth, @@ -197,6 +199,7 @@ def filter_references(args): else: filter_conditions = { "min_read_length": args.min_read_length, + "max_read_length": args.max_read_length, "min_read_count": args.min_read_count, "min_expected_breadth_ratio": args.min_expected_breadth_ratio, "min_breadth": args.min_breadth, @@ -215,6 +218,7 @@ def filter_references(args): ) filter_conditions = { "min_read_length": args.min_read_length, + "max_read_length": args.max_read_length, "min_read_count": args.min_read_count, "min_expected_breadth_ratio": args.min_expected_breadth_ratio, "min_breadth": args.min_breadth, @@ -229,6 +233,7 @@ def filter_references(args): min_norm_gini, min_norm_entropy = args.min_norm_gini, args.min_norm_entropy filter_conditions = { "min_read_length": args.min_read_length, + "max_read_length": args.max_read_length, "min_read_count": args.min_read_count, "min_expected_breadth_ratio": args.min_expected_breadth_ratio, "min_breadth": args.min_breadth, diff --git a/bam_filter/reassign.py b/bam_filter/reassign.py index aed0eaf..cffc901 100644 --- a/bam_filter/reassign.py +++ b/bam_filter/reassign.py @@ -25,213 +25,224 @@ import warnings from bam_filter.sam_utils import check_bam_file import shutil - -# import cProfile as prof -# import pstats +import uuid +import psutil log = logging.getLogger("my_logger") -def initialize_subject_weights(data): +def estimate_array_size(dtype, shape): + """ + Estimate the size of a numpy array in bytes given its dtype and shape. 
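+
+    Example: estimate_array_size(np.float64, (1_000,)) returns 8000 bytes.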
+ """ + return np.dtype(dtype).itemsize * np.prod(shape) + + +def initialize_subject_weights(data, mmap_dir=None, max_memory=None): if data.shape[0] > 0: - # Add a new column for inverse sequence length - data["s_W"] = 1 / data["slen"] + total_memory = max_memory if max_memory else psutil.virtual_memory().total + max_source = np.max(data["source"]) + + # Estimate sizes + sum_weights_size = estimate_array_size(np.float64, (max_source + 1,)) + + # Use memory-mapped arrays if estimated size exceeds available memory + if sum_weights_size > total_memory * 0.8: + sum_weights = np.memmap( + os.path.join(mmap_dir, "sum_weights.mmap"), + dtype="float64", + mode="w+", + shape=(max_source + 1,), + ) + else: + sum_weights = np.zeros(max_source + 1, dtype="float64") - # Calculate the sum of weights for each unique source - sum_weights = np.zeros(np.int64(np.max(data["source"])) + 1) + # Calculate the sum of weights for each unique source using np.add.at for efficiency np.add.at(sum_weights, data["source"], data["var"]) - # Calculate the normalized weights based on the sum for each query - query_sum_var = sum_weights[data["source"]] - data["prob"] = data["var"] / query_sum_var + # Calculate the normalized weights directly + data["prob"] = data["var"] / sum_weights[data["source"]] + + del sum_weights return data else: return None -def resolve_multimaps(data, scale=0.9, iters=10): +def resolve_multimaps(data, scale=0.9, iters=10, mmap_dir=None, max_memory=None): + total_memory = max_memory if max_memory else psutil.virtual_memory().total current_iter = 0 + while True: progress_bar = tqdm.tqdm( total=9, desc=f"Iter {current_iter + 1}", unit=" step", - disable=False, # Replace with your logic or a boolean value + disable=False, leave=False, ncols=80, ) log.debug(f"::: Iter: {current_iter + 1} - Getting scores") + # step 1 progress_bar.update(1) n_alns = data.shape[0] log.debug(f"::: Iter: {current_iter + 1} - Total alignment: {n_alns:,}") - # Calculate the weights for each subject log.debug(f"::: Iter: {current_iter + 1} - Calculating weights...") + # step 2 progress_bar.update(1) - subject_weights = np.zeros(np.int64(np.max(data["subject"])) + 1) + max_subject = np.max(data["subject"]) + subject_weights_size = estimate_array_size( + np.float64, (np.int64(max_subject) + 1,) + ) + + if subject_weights_size > total_memory * 0.8: + # Use memory-mapped arrays + subject_weights = np.memmap( + os.path.join(mmap_dir, f"subject_weights_{current_iter}.mmap"), + dtype="float64", + mode="w+", + shape=(np.int64(max_subject) + 1,), + ) + else: + # Use in-memory arrays + subject_weights = np.zeros(np.int64(max_subject) + 1, dtype="float64") + np.add.at(subject_weights, data["subject"], data["prob"]) data["s_W"] = subject_weights[data["subject"]] / data["slen"] - subject_weights = None + del subject_weights log.debug(f"::: Iter: {current_iter + 1} - Calculating probabilities") + # step 3 progress_bar.update(1) - # Calculate the alignment probabilities new_prob = data["prob"] * data["s_W"] - log.debug("Calculating sum of probabilities") - progress_bar.update(1) - prob_sum = data["prob"] * data["s_W"] - prob_sum_array = np.zeros(np.int64(np.max(data["source"])) + 1) - np.add.at(prob_sum_array, data["source"], prob_sum) - prob_sum = None + max_source = np.max(data["source"]) + prob_sum_array_size = estimate_array_size( + np.float64, (np.int64(max_source) + 1,) + ) - # data["prob_sum"] = prob_sum_array[data["source"]] + if prob_sum_array_size > total_memory * 0.8: + # Use memory-mapped arrays + prob_sum_array = np.memmap( + 
os.path.join(mmap_dir, f"prob_sum_array_{current_iter}.mmap"), + dtype="float64", + mode="w+", + shape=(np.int64(max_source) + 1,), + ) + else: + # Use in-memory arrays + prob_sum_array = np.zeros(np.int64(max_source) + 1, dtype="float64") + + np.add.at(prob_sum_array, data["source"], new_prob) data["prob"] = new_prob / prob_sum_array[data["source"]] - prob_sum_array = None + del new_prob, prob_sum_array log.debug("Calculating query counts") + # step 4 progress_bar.update(1) - # Calculate how many alignments are in each query - # query_counts = np.zeros(np.int64(np.max(data["source"])) + 1) - # np.add.at(query_counts, data["source"], 1) - query_counts = np.bincount(data["source"]) log.debug("Calculating query counts array") + # step 5 progress_bar.update(1) - # Use a separate array for query counts - query_counts_array = np.zeros(np.int64(np.max(data["source"])) + 1) - np.add.at( - query_counts_array, - data["source"], - query_counts[data["source"]], + query_counts_array_size = estimate_array_size( + np.int64, (np.int64(max_source) + 1,) ) + if query_counts_array_size > total_memory * 0.8: + # Use memory-mapped arrays + query_counts_array = np.memmap( + os.path.join(mmap_dir, f"query_counts_array_{current_iter}.mmap"), + dtype="int64", + mode="w+", + shape=(np.int64(max_source) + 1,), + ) + else: + # Use in-memory arrays + query_counts_array = np.zeros(np.int64(max_source) + 1, dtype="int64") + + np.add.at(query_counts_array, data["source"], 1) + log.debug( f"::: Iter: {current_iter + 1} - Calculating number of alignments per query" ) + # step 6 progress_bar.update(1) data["n_aln"] = query_counts_array[data["source"]] - log.debug("Calculating unique alignments") - data["n_aln"] = query_counts_array[data["source"]] - data_unique = data[data["n_aln"] == 1] - n_unique = data_unique.shape[0] + unique_mask = data["n_aln"] == 1 + non_unique_mask = data["n_aln"] > 1 - if n_unique == data.shape[0]: + if np.all(unique_mask): + # step 7 progress_bar.close() log.info("::: ::: No more multimapping reads. 
Early stopping.") return data - data = data[(data["n_aln"] > 1) & (data["prob"] > 0)] - - # total_n_unique = np.sum(query_counts_array[data["source"]] <= 1) - - query_counts = None - query_counts_array = None log.debug("Calculating max_prob") - # Keep the ones that have a probability higher than the maximum scaled probability - max_prob = np.zeros(np.int64(np.max(data["source"])) + 1) + max_prob_size = estimate_array_size(np.float64, (np.int64(max_source) + 1,)) + + if max_prob_size > total_memory * 0.8: + # Use memory-mapped arrays + max_prob = np.memmap( + os.path.join(mmap_dir, f"max_prob_{current_iter}.mmap"), + dtype="float64", + mode="w+", + shape=(np.int64(max_source) + 1,), + ) + else: + # Use in-memory arrays + max_prob = np.zeros(np.int64(max_source) + 1, dtype="float64") + np.maximum.at(max_prob, data["source"], data["prob"]) + data["max_prob"] = max_prob[data["source"]] * scale + del max_prob - data["max_prob"] = max_prob[data["source"]] - data["max_prob"] = data["max_prob"] * scale - # data["max_prob"] = max_prob[data["source"]] log.debug( f"::: Iter: {current_iter + 1} - Removing alignments with lower probability" ) + # step 8 progress_bar.update(1) to_remove = np.sum(data["prob"] < data["max_prob"]) - data = data[data["prob"] >= data["max_prob"]] - max_prob = None + filter_mask = data["prob"] >= data["max_prob"] + final_mask = non_unique_mask & filter_mask - # Update the iteration count in the function call current_iter += 1 - data["iter"] = current_iter - data_unique["iter"] = current_iter + data["iter"][final_mask] = current_iter - query_counts = np.bincount(data["source"]) - total_n_unique = np.sum(query_counts[data["source"]] <= 1) - - # data_unique["iter"] = current_iter + # Concatenate unique and filtered non-unique data + data = np.concatenate([data[unique_mask], data[final_mask]]) - # data = np.concatenate([data, data_unique]) - data = np.concatenate([data, data_unique]) - data_unique = None + query_counts = np.bincount(data["source"]) + total_n_unique = np.sum(query_counts <= 1) keep_processing = to_remove != 0 log.debug(f"::: Iter: {current_iter} - Removed {to_remove:,} alignments") - log.debug(f"::: Iter: {current_iter} - Total mapping queries: {n_unique:,}") + log.debug( + f"::: Iter: {current_iter} - Total mapping queries: {np.sum(unique_mask):,}" + ) log.debug( f"::: Iter: {current_iter} - New unique mapping queries: {total_n_unique:,}" ) log.debug(f"::: Iter: {current_iter} - Alns left: {data.shape[0]:,}") + # step 9 progress_bar.update(1) progress_bar.close() log.info( - f"::: Iter: {current_iter} - R: {to_remove:,} | U: {total_n_unique:,} | NU: {n_unique:,} | L: {data.shape[0]:,}" + f"::: Iter: {current_iter} - R: {to_remove:,} | U: {np.sum(unique_mask):,} | NU: {total_n_unique:,} | L: {data.shape[0]:,}" ) log.debug(f"::: Iter: {current_iter} - done!") if iters > 0 and current_iter >= iters: log.info("::: ::: Reached maximum iterations. Stopping.") break - elif not keep_processing: + elif to_remove == 0: log.info("::: ::: No more alignments to remove. 
Stopping.") break - return data - -# def write_reassigned_bam( -# bam, out_files, threads, entries, sort_memory="1G", min_read_ani=90 -# ): -# if out_files["bam_reassigned"] is not None: -# out_bam = out_files["bam_reassigned"] -# else: -# out_bam = out_files["bam_reassigned_sorted"] - -# samfile = pysam.AlignmentFile(bam, "rb", threads=threads) -# references = list(entries.keys()) -# refs_dict = {x: samfile.get_reference_length(x) for x in list(entries.keys())} - -# (ref_names, ref_lengths) = zip(*refs_dict.items()) - -# refs_idx = {sys.intern(str(x)): i for i, x in enumerate(ref_names)} -# if threads > 4: -# write_threads = 4 -# else: -# write_threads = threads - -# out_bam_file = pysam.AlignmentFile( -# out_files["bam_reassigned_tmp"], -# "wb", -# referencenames=list(ref_names), -# referencelengths=list(ref_lengths), -# threads=write_threads, -# ) - -# for reference in tqdm.tqdm( -# references, -# total=len(references), -# leave=False, -# ncols=80, -# desc="References processed", -# ): -# r_ids = entries[reference] -# for aln in samfile.fetch( -# reference=reference, multiple_iterators=False, until_eof=True -# ): -# # ani_read = (1 - ((aln.get_tag("NM") / aln.infer_query_length()))) * 100 -# if (aln.query_name, reference) in r_ids: -# aln.reference_id = refs_idx[aln.reference_name] -# out_bam_file.write(aln) -# out_bam_file.close() - - -# def write_to_file(alns, out_bam_file): -# for aln in tqdm.tqdm(alns, total=len(alns), leave=False, ncols=80, desc="Writing"): -# out_bam_file.write(aln) + return data def write_to_file(alns, out_bam_file, header=None): @@ -239,7 +250,7 @@ def write_to_file(alns, out_bam_file, header=None): out_bam_file.write(pysam.AlignedSegment.fromstring(aln, header)) -def process_references_batch(references, entries, bam, refs_idx, threads=4): +def process_references_batch(references, entries, bam, refs_idx, threads=1): alns = [] s_threads = threads with pysam.AlignmentFile(bam, "rb", threads=s_threads) as samfile: @@ -265,19 +276,14 @@ def write_reassigned_bam( sort_by_name=False, min_read_ani=90, min_read_length=30, + max_read_length=np.Inf, disable_sort=False, ): - # if out_files["bam_reassigned"] is not None: - # out_bam = out_files["bam_reassigned"] - # else: - # out_bam = out_files["bam_reassigned_sorted"] out_bam = out_files["bam_reassigned"] - p_threads, s_threads = allocate_threads(threads, 2, 4) - # s_threads = min(4, threads) + p_threads, s_threads = allocate_threads(threads, 1, 4) with pysam.AlignmentFile(bam, "rb", threads=s_threads) as samfile: references = list(entries.keys()) refs_dict = {x: samfile.get_reference_length(x) for x in references} - # get group reads header = samfile.header log.info("::: Getting reference names and lengths...") @@ -304,11 +310,8 @@ def write_reassigned_bam( header=new_header, ) as out_bam_file: - # num_cores should be multiple of the write_threads num_cores = p_threads - # num_cores = min(threads, cpu_count()) - # batch_size = len(references) // num_cores + 1 # Ensure non-zero batch size - # batch_size = calc_chunksize(n_workers=num_cores, len_iterable=len(references)) + log.info("::: Creating reference chunks with uniform read amounts...") ref_chunks = sort_keys_by_approx_weight( @@ -319,17 +322,14 @@ def write_reassigned_bam( max_entries_per_chunk=1_000_000, ) - # num_cores = min(num_cores, len(ref_chunks)) log.info(f"::: Using {num_cores} processes to write {len(ref_chunks)} chunk(s)") with Manager() as manager: - # Use Manager to create a read-only proxy for the dictionary entries = manager.dict(dict(entries)) with 
concurrent.futures.ProcessPoolExecutor( max_workers=num_cores ) as executor: - # Use ProcessPoolExecutor to parallelize the processing of references in batches futures = [] for batch_references in tqdm.tqdm( ref_chunks, @@ -348,9 +348,8 @@ def write_reassigned_bam( refs_idx, s_threads, ) - futures.append(future) # Store the future + futures.append(future) - # Use a while loop to continuously check for completed futures log.info("::: Collecting batches...") completion_progress_bar = tqdm.tqdm( @@ -363,24 +362,19 @@ def write_reassigned_bam( ) completed_count = 0 - # Use as_completed to iterate over completed futures as they become available for completed_future in concurrent.futures.as_completed(futures): alns = completed_future.result() write_to_file(alns=alns, out_bam_file=out_bam_file, header=header) - # Update the progress bar for each completed write completion_progress_bar.update(1) completed_count += 1 - completed_future.cancel() # Cancel the future to free memory - gc.collect() # Force garbage collection + completed_future.cancel() + gc.collect() completion_progress_bar.close() entries = None gc.collect() - # prof.disable() - # # print profiling output - # stats = pstats.Stats(prof).strip_dirs().sort_stats("tottime") - # stats.print_stats(5) # top 10 rows + if not disable_sort: log.info("::: ::: Sorting BAM file...") w_threads = max(4, s_threads) @@ -422,10 +416,6 @@ def write_reassigned_bam( shutil.move(out_files["bam_reassigned_tmp"], out_bam) -# Values from: -# https://www.ncbi.nlm.nih.gov/IEB/ToolBox/CPP_DOC/lxr/source/src/algo/blast/core/blast_stat.c - - def calculate_alignment_score( num_matches, num_mismatches, @@ -438,7 +428,6 @@ def calculate_alignment_score( precomputed_factor, # This is lambda_value * match_reward / math.log(2) precomputed_log_K, # This is math.log(K_value) / math.log(2) ): - # Calculate the raw alignment score with reduced arithmetic operations S = ( (num_matches * match_reward) - (num_mismatches * mismatch_penalty) @@ -446,7 +435,6 @@ def calculate_alignment_score( - (gap_extensions * gap_extension_penalty) ) - # Use precomputed factors to calculate the approximate bit score bit_score = precomputed_factor * S - precomputed_log_K return bit_score @@ -457,6 +445,7 @@ def get_bam_data( ref_lengths=None, percid=90, min_read_length=30, + max_read_length=np.Inf, threads=1, match_reward=1, mismatch_penalty=-1, @@ -464,45 +453,37 @@ def get_bam_data( gap_extension_penalty=2, lambda_value=1.02, K_value=0.21, + tmpdir=None, ): - # Precompute factors for the score calculation to avoid redundant computation precomputed_factor = lambda_value * match_reward / math.log(2) precomputed_log_K = math.log(K_value) / math.log(2) bam, references = parms dt.options.progress.enabled = False dt.options.progress.clear_on_success = True - dt.options.nthreads = max(1, threads - 1) - s_threads = threads - with pysam.AlignmentFile(bam, "rb", threads=s_threads) as samfile: - results = [] - reads = set() - refs = set() - empty_df = 0 + results = [] + empty_df = 0 - bam_reference_length = { - reference: np.int64(samfile.get_reference_length(reference)) - for reference in references - } + with pysam.AlignmentFile(bam, "rb", threads=threads) as samfile: + if ref_lengths is None: + reference_lengths = { + reference: np.int64(samfile.get_reference_length(reference)) + for reference in references + } + else: + reference_lengths = { + reference: np.int64(ref_lengths[reference]) for reference in references + } for reference in references: - reference_length = ( - 
bam_reference_length[reference] - if ref_lengths is None - else np.int64(ref_lengths[reference]) - ) + reference_length = reference_lengths[reference] aln_data = [] - for aln in samfile.fetch( - contig=reference, multiple_iterators=False, until_eof=True - ): - query_length = ( - aln.query_length - if aln.query_length != 0 - else aln.infer_query_length() - ) + fetch = samfile.fetch(reference, multiple_iterators=False, until_eof=True) - if query_length >= min_read_length: + for aln in fetch: + query_length = aln.query_length or aln.infer_query_length() + if query_length >= min_read_length and query_length <= max_read_length: num_mismatches = aln.get_tag("NM") pident = (1 - (num_mismatches / query_length)) * 100 if pident >= percid: @@ -510,18 +491,14 @@ def get_bam_data( num_gaps = aln.get_tag("XO") if aln.has_tag("XO") else 0 gap_extensions = aln.get_tag("XG") if aln.has_tag("XG") else 0 - bit_score = calculate_alignment_score( - num_matches, - num_mismatches, - num_gaps, - gap_extensions, - match_reward, - mismatch_penalty, - gap_open_penalty, - gap_extension_penalty, - precomputed_factor, - precomputed_log_K, + S = ( + (num_matches * match_reward) + - (num_mismatches * mismatch_penalty) + - (num_gaps * gap_open_penalty) + - (gap_extensions * gap_extension_penalty) ) + bit_score = precomputed_factor * S - precomputed_log_K + aln_data.append( ( aln.query_name, @@ -530,8 +507,6 @@ def get_bam_data( reference_length, ) ) - reads.add(aln.query_name) - refs.add(aln.reference_name) if aln_data: aln_data_dt = dt.Frame( @@ -544,8 +519,18 @@ def get_bam_data( else: empty_df += 1 - combined_results = dt.rbind(results) - return (combined_results, reads, refs, empty_df) + if results: + combined_results = dt.rbind(results) + if tmpdir is not None: + uuid_name = str(uuid.uuid4()) + jay_file = os.path.join(tmpdir, f"{uuid_name}.jay") + combined_results.to_jay(jay_file) + del combined_results + return (jay_file, empty_df) + else: + return (combined_results, empty_df) + else: + return (None, empty_df) def reassign_reads( @@ -562,13 +547,16 @@ def reassign_reads( min_read_count=1, min_read_ani=90, min_read_length=30, + max_read_length=np.Inf, reassign_iters=25, reassign_scale=0.9, sort_memory="4G", disable_sort=False, + tmp_dir=None, + max_memory=None, ): - p_threads, s_threads = allocate_threads(threads, 2, 4) + p_threads, s_threads = allocate_threads(threads, 1, 4) dt.options.progress.enabled = True dt.options.progress.clear_on_success = True if threads > 1: @@ -578,10 +566,6 @@ def reassign_reads( log.info("::: Loading BAM file") save = pysam.set_verbosity(0) - # we need to separate the threads for reading the BAM and process in parallel. 
- # At least 2 threads are needed for reading the BAM file and the rest for multiprocessing - - # s_threads = min(4, threads) log.info(f"::: IO Threads: {s_threads} | Processing Threads: {p_threads}") @@ -596,14 +580,12 @@ def reassign_reads( if reference_lengths is not None: ref_len_dt = dt.fread(reference_lengths) ref_len_dt.names = ["subjectId", "slen"] - # convert to dict ref_len_dict = dict( zip( ref_len_dt["subjectId"].to_list()[0], ref_len_dt["slen"].to_list()[0], ) ) - # check if the dataframe contains all the References in the BAM file if not set(references).issubset(set(ref_len_dict.keys())): logging.error( "The BAM file contains references not found in the reference lengths file" @@ -612,21 +594,22 @@ def reassign_reads( else: ref_len_dict = None - # logging.info(f"Found {samfile.mapped:,} alignments") index_statistics = samfile.get_index_statistics() references_m = { chrom.contig: chrom.mapped for chrom in tqdm.tqdm( [chrom for chrom in index_statistics if chrom.mapped >= min_read_count], - desc="Filtering Chromosomes", + desc="Filtering references", total=len(index_statistics), unit="chrom", leave=False, ncols=80, + unit_scale=True, + unit_divisor=1000, ) } - # get number alignments + del index_statistics n_alns = sum(references_m.values()) log.info(f"::: Kept {n_alns:,} alignments") references = list(references_m.keys()) @@ -639,14 +622,13 @@ def reassign_reads( log.info(f"::: Keeping {len(references):,} references") log.info("::: Creating reference chunks with uniform read amounts...") - # ify the number of chunks ref_chunks = sort_keys_by_approx_weight( input_dict=references_m, scale=1, num_cores=threads, refinement_steps=10, verbose=False, - max_entries_per_chunk=25_000_000, + max_entries_per_chunk=100_000_000, ) log.info(f"::: ::: Created {len(ref_chunks):,} chunks") @@ -654,6 +636,8 @@ def reassign_reads( dt.options.progress.enabled = False dt.options.progress.clear_on_success = True dt.options.nthreads = 1 + del references_m + gc.collect() parms = list(zip([bam] * len(ref_chunks), ref_chunks)) @@ -667,13 +651,15 @@ def reassign_reads( ref_lengths=ref_len_dict, percid=min_read_ani, min_read_length=min_read_length, + max_read_length=max_read_length, match_reward=match_reward, mismatch_penalty=mismatch_penalty, gap_open_penalty=gap_open_penalty, gap_extension_penalty=gap_extension_penalty, lambda_value=lambda_value, K_value=K_value, - threads=4, + threads=s_threads, + tmpdir=out_files["tmp_dir"], ), parms, chunksize=1, @@ -685,9 +671,7 @@ def reassign_reads( ) ) else: - p = Pool( - p_threads, - ) + p = Pool(p_threads) data = list( tqdm.tqdm( p.imap_unordered( @@ -696,6 +680,7 @@ def reassign_reads( ref_lengths=ref_len_dict, percid=min_read_ani, min_read_length=min_read_length, + max_read_length=max_read_length, match_reward=match_reward, mismatch_penalty=mismatch_penalty, gap_open_penalty=gap_open_penalty, @@ -703,6 +688,7 @@ def reassign_reads( lambda_value=lambda_value, K_value=K_value, threads=s_threads, + tmpdir=out_files["tmp_dir"], ), parms, chunksize=1, @@ -725,121 +711,80 @@ def reassign_reads( dt.options.nthreads = 1 log.info("::: Collecting results...") - reads = list() - refs = list() + reads = set() + refs = set() empty_df = 0 - # new_data = list() - # for i in tqdm.tqdm(range(len(data)), total=len(data), leave=False, ncols=80): - # empty_df += data[i][3] - # reads.extend(list(data[i][1])) - # refs.extend(list(data[i][2])) - # data[i] = data[i][0] - new_data = list() - for i in tqdm.tqdm(range(len(data)), total=len(data), leave=False, ncols=80): - empty_df += 
data[i][3] - reads.extend(list(data[i][1])) - refs.extend(list(data[i][2])) - - # Check if the frame has more than 2 billion rows - if data[i][0].nrows > 2e9: - # Calculate the number of chunks needed - num_chunks = (data[i][0].nrows // 1e9) + (data[i][0].nrows % 1e9 > 0) - log.warning( - f"Frame has more than 2 billion rows. Splitting into {num_chunks:,} chunks..." - ) - chunks = [] + for i in tqdm.tqdm(range(len(data)), total=len(data), leave=False, ncols=80): + empty_df += data[i][1] + df = dt.fread(data[i][0]) + data[i] = df + query_ids = df[:, "queryId"].to_list()[0] + subject_ids = df[:, "subjectId"].to_list()[0] - # Create chunks of 1 billion rows each - for chunk_idx in range(num_chunks): - start = np.int64(chunk_idx * 1e9) - end = np.int64(min((chunk_idx + 1) * 1e9, data[i][0].nrows)) - chunks.append(data[i][0][start:end, :]) + reads.update(query_ids) + refs.update(subject_ids) - # Substitute data[i] with the first chunk - data[i] = chunks[0] + del df - # Append the rest of the chunks to new_data - for chunk in chunks[1:]: - new_data.append(chunk) - else: - # If the frame is not larger than 2 billion rows, keep it as is - data[i] = data[i][0] + reads = list(reads) + refs = list(refs) - # Combine data with new_data - data.extend(new_data) log.info(f"::: ::: Removed {empty_df:,} references without alignments") - # data = dt.rbind([x for x in data]) - - # log.info("::: Indexing references...") - # refs = dt.Frame(list(set(refs))) - # refs.names = ["subjectId"] - # refs["sidx"] = dt.Frame(list(range(refs.shape[0]))) - # refs.key = "subjectId" - - # log.info("::: Indexing reads...") - # reads = dt.Frame(list(set(reads))) - # reads.names = ["queryId"] - # reads["qidx"] = dt.Frame([(i + refs.shape[0]) for i in range(reads.shape[0])]) - # reads.key = "queryId" - - # log.info("::: Combining data...") - # data = data[:, :, dt.join(reads)] - # data = data[:, :, dt.join(refs)] - - # Initialize refs DataFrame log.info("::: Indexing references...") refs = dt.Frame(list(set(refs))) refs.names = ["subjectId"] refs["sidx"] = dt.Frame(list(range(refs.shape[0]))) refs.key = "subjectId" - # Initialize reads DataFrame log.info("::: Indexing reads...") reads = dt.Frame(list(set(reads))) reads.names = ["queryId"] reads["qidx"] = dt.Frame([(i + refs.shape[0]) for i in range(reads.shape[0])]) reads.key = "queryId" - n_alns_0 = 0 - # Loop through each DataFrame in the list and update it with the joined version - log.info("::: Combining data...") - # for i, x in tqdm.tqdm( - # enumerate(data), - # total=len(data), - # desc="Processing batches", - # unit="batch", - # disable=is_debug(), - # leave=False, - # ncols=80, - # ): - # # Perform join with reads and then refs - # x = x[:, :, dt.join(reads)] - # x = x[:, :, dt.join(refs)] - # n_alns_0 += x.shape[0] - # del x["queryId"] - # del x["subjectId"] - # x = x[:, [dt.f.qidx, dt.f.sidx, dt.f.var, dt.f.slen]].to_numpy() - # # Substitute the original DataFrame with the joined version in the list - # data[i] = x - - # Calculate the total number of rows in advance - total_rows = sum( - x.shape[0] for x in data - ) # This assumes `data` is a list of DataFrames/NumPy arrays - - # Assuming all `x` arrays have the same number of columns after processing, use the first one to determine this - # IMPORTANT: This line needs to be executed before the loop and assumes all `x` arrays are similar after processing - num_columns = 4 # Adjust based on your actual data structure - - # Preallocate the NumPy array - mat = np.empty( - (total_rows, num_columns), dtype=np.float64 - ) # 
Adjust dtype as necessary + log.info("::: Allocating data...") + total_rows = np.int64(sum(x.shape[0] for x in data)) + + n_reads_0 = reads.shape[0] + n_refs_0 = refs.shape[0] + + n_alns_0 = 0 current_index = 0 - for i, x in tqdm.tqdm( - enumerate(data), + + dtype = np.dtype( + [ + ("source", "int64"), + ("subject", "int64"), + ("var", "float32"), + ("slen", "int64"), + ("s_W", "float32"), + ("prob", "float32"), + ("iter", "int32"), + ("n_aln", "int64"), + ("max_prob", "float32"), + ] + ) + + total_memory = max_memory if max_memory else psutil.virtual_memory().total + array_size = estimate_array_size(dtype, (total_rows,)) + + if array_size > total_memory * 0.8: + log.warning("::: Using memory-mapped arrays") + # Use memory-mapped arrays + m = np.memmap( + os.path.join(tmp_dir.name, "m.mmap"), + dtype=dtype, + mode="w+", + shape=(total_rows,), + ) + else: + # Use in-memory arrays + m = np.zeros(total_rows, dtype=dtype) + + for i in tqdm.tqdm( + range(len(data)), total=len(data), desc="Processing batches", unit="batch", @@ -847,114 +792,44 @@ def reassign_reads( leave=False, ncols=80, ): - # Perform join with reads and then refs + x = data.pop(0) if x.shape[0] > 0: x = x[:, :, dt.join(reads)] x = x[:, :, dt.join(refs)] n_alns_0 += x.shape[0] - # Process `x` as before, but directly update `mat` x_processed = x[ :, [dt.f.qidx, dt.f.sidx, dt.f.bitScore, dt.f.slen] ].to_numpy() num_rows = x_processed.shape[0] - # Fill the preallocated array - mat[current_index : current_index + num_rows, :] = x_processed - - # Update the current index + m["source"][current_index : current_index + num_rows] = x_processed[:, 0] + m["subject"][current_index : current_index + num_rows] = x_processed[:, 1] + m["var"][current_index : current_index + num_rows] = x_processed[:, 2] + m["slen"][current_index : current_index + num_rows] = x_processed[:, 3] current_index += num_rows + del x - # After the loop, `mat` is already the concatenated array, so there's no need for further concatenation or conversion. 
- data = None # Free the memory if `data` is no longer needed + del data + gc.collect() - # Log the final stats - n_reads_0 = reads.shape[0] - n_refs_0 = refs.shape[0] log.info( f"::: References: {n_refs_0:,} | Reads: {n_reads_0:,} | Alignments: {n_alns_0:,}" ) - - # After the loop, use dt.rbind() to combine all the DataFrames in the list - # data = dt.rbind([x for x in data]) - - # del data["queryId"] - # del data["subjectId"] - # n_alns_0 = data.shape[0] - # n_reads_0 = reads.shape[0] - # n_refs_0 = refs.shape[0] - # log.info( - # f"::: References: {n_refs_0:,} | Reads: {n_reads_0:,} | Alignments: {n_alns_0:,}" - # ) - - # log.info("::: Allocating data...") - # # # mat = data[:, [dt.f.qidx, dt.f.sidx, dt.f.var, dt.f.slen]].to_numpy() - # # mat = np.vstack(data) - # # data = None - - # # Create a zeros array with the same number of rows as 'mat' - # zeros_array = np.zeros((mat.shape[0], 5)) - - # # Stack the zeros_array with the original 'mat' - # m = np.column_stack([mat, zeros_array]) - # zeros_array = None - - # dtype = np.dtype( - # [ - # ("source", "int"), - # ("subject", "int"), - # ("var", "float"), - # ("slen", "int"), - # ("s_W", "float"), - # ("prob", "float"), - # ("iter", "int"), - # ("n_aln", "int"), - # ("max_prob", "float"), - # ] - # ) - - # # Convert the unstructured array to structured array - # m = rf.unstructured_to_structured(m, dtype) - # gc.collect() - - log.info("::: Allocating data...") - - # Define the dtype for the structured array - dtype = np.dtype( - [ - ("source", "int64"), - ("subject", "int64"), - ("var", "float64"), - ("slen", "int64"), - ( - "s_W", - "float", - ), # This and following fields are initialized to 0 or a default value - ("prob", "float64"), - ("iter", "int64"), - ("n_aln", "int64"), - ("max_prob", "float64"), - ] - ) - - # Initialize the structured array with zeros directly - m = np.zeros(mat.shape[0], dtype=dtype) - m["source"] = mat[:, 0] - m["subject"] = mat[:, 1] - m["var"] = mat[:, 2] - m["slen"] = mat[:, 3] - - # Force a garbage collection to free up memory from any intermediate arrays that are no longer needed - gc.collect() - log.info("::: Initializing data structures...") - init_data = initialize_subject_weights(m) + init_data = initialize_subject_weights( + m, mmap_dir=tmp_dir.name, max_memory=total_memory + ) if reassign_iters > 0: log.info(f"::: Reassigning reads with {reassign_iters} iterations") else: log.info("::: Reassigning reads until convergence") no_multimaps = resolve_multimaps( - init_data, iters=reassign_iters, scale=reassign_scale + init_data, + iters=reassign_iters, + scale=reassign_scale, + mmap_dir=tmp_dir.name, + max_memory=total_memory, ) n_reads = len(list(set(no_multimaps["source"]))) @@ -968,12 +843,6 @@ def reassign_reads( f'::: Unique mapping reads: {no_multimaps[no_multimaps["n_aln"] == 1].shape[0]:,} | Multimapping reads: {len(np.unique(no_multimaps[no_multimaps["n_aln"] > 1]["source"])):,}' ) - # add this to the array - # no_multimaps["n_aln"] = subject_counts_array[no_multimaps["subject"]] - - # log.info(f"::: Removing references with less than {min_read_count} reads...") - # no_multimaps = no_multimaps[no_multimaps["n_aln"] >= min_read_count] - # log.info(f"{no_multimaps.shape[0]:,} alignments left") log.info("::: Mapping back indices...") if threads > 1: dt.options.nthreads = threads - 1 @@ -991,7 +860,6 @@ def reassign_reads( s = g[:, :, dt.join(refs)] log.info("::: Calculating reads per subject...") - # count how many alignments are in each subjectId s_c = s[:, dt.count(dt.f.subjectId), 
dt.by(dt.f.subjectId)] s_c.names = ["subjectId", "counts"] references_m = dict() @@ -1005,7 +873,7 @@ def reassign_reads( log.warning("::: No reference sequences with alignments found in the BAM file") create_empty_output_files(out_files) sys.exit(0) - # convert columns queryId from q and subjectId from s to a tuple + log.info("::: Creating filtered set...") entries = defaultdict(set) q_query_ids = q[:, "queryId"].to_list()[0] @@ -1014,12 +882,16 @@ def reassign_reads( for query_id, subject_id in zip(q_query_ids, s_subject_ids): if subject_id in references_m: entries[subject_id].add((query_id, subject_id)) - no_multimaps = None - q = None - s = None - q_query_ids = None - s_subject_ids = None + + del m + del init_data + del no_multimaps + del q + del s + del q_query_ids + del s_subject_ids gc.collect() + log.info("::: Writing to BAM file...") write_reassigned_bam( bam=bam, @@ -1041,6 +913,9 @@ def reassign(args): ) args = get_arguments() + if args.max_read_length < args.min_read_length: + logging.error("Maximum read length cannot be less than minimum read length") + sys.exit(1) bam = args.bam tmp_dir = check_tmp_dir_exists(args.tmp_dir) log.info("Temporary directory: %s", tmp_dir.name) @@ -1079,8 +954,10 @@ def reassign(args): min_read_count=args.min_read_count, min_read_ani=args.min_read_ani, min_read_length=args.min_read_length, + max_read_length=args.max_read_length, reassign_iters=args.reassign_iters, reassign_scale=args.reassign_scale, + max_memory=args.max_memory, sort_memory=args.sort_memory, out_files=out_files, match_reward=args.match_reward, @@ -1090,11 +967,10 @@ def reassign(args): lambda_value=args.lambda_value, K_value=args.K_value, disable_sort=args.disable_sort, + tmp_dir=tmp_dir, ) - # check if sorted BAM file exists, if yes remove it if os.path.exists(sorted_bam): os.remove(sorted_bam) - # check if sorted BAM index file exists, if yes remove it if os.path.exists(sorted_bam + ".bai"): os.remove(sorted_bam + ".bai") elif os.path.exists(sorted_bam + ".csi"): diff --git a/bam_filter/sam_utils.py b/bam_filter/sam_utils.py index 92585c4..4c67c83 100644 --- a/bam_filter/sam_utils.py +++ b/bam_filter/sam_utils.py @@ -792,15 +792,10 @@ def check_bam_file( logging.info("Checking BAM file status") save = pysam.set_verbosity(0) - # Use a with statement to ensure proper closing of the samfile - try: - # p_threads, s_threads = allocate_threads(threads, 2, 4) - s_threads = min(threads, 4) + def process_bam(bam, s_threads): with pysam.AlignmentFile(bam, "rb", threads=s_threads) as samfile: references = samfile.references log.info(f"::: Found {samfile.nreferences:,} reference sequences") - references = samfile.references - pysam.set_verbosity(save) ref_lengths = None if reference_lengths is not None: @@ -810,15 +805,15 @@ def check_bam_file( index_col=0, names=["reference", "length"], ) - # check if the dataframe contains all the References in the BAM file if not set(references).issubset(set(ref_lengths.index)): logging.error( "::: The BAM file contains references not found in the reference lengths file" ) sys.exit(1) - # max_chr_length = np.max(ref_lengths["length"].tolist()) - # Check if BAM files is not sorted by coordinates, sort it by coordinates + del references + gc.collect() + if samfile.header["HD"]["SO"] != "coordinate": log.info("::: BAM file is not sorted by coordinates, sorting it...") sorted_bam = bam.replace(".bam", ".bf-sorted.bam") @@ -827,18 +822,26 @@ def check_bam_file( ) bam = sorted_bam pysam.index("-c", "-@", str(threads), bam) - - samfile = pysam.AlignmentFile(bam, 
"rb", threads=s_threads) + return bam, True # Indicate that the BAM file was sorted and reopened if not samfile.has_index(): logging.info("::: BAM index not found. Indexing...") - # if max_chr_length > 536870912: - # logging.info("A reference is longer than 2^29") pysam.index("-c", "-@", str(threads), bam) logging.info("::: BAM file looks good.") + return bam, False # Indicate that the BAM file was not sorted and reopened + + try: + s_threads = min(threads, 4) + bam, reopened = process_bam(bam, s_threads) + + # If the BAM file was sorted and reopened, check it again + if reopened: + bam, _ = process_bam(bam, s_threads) + + pysam.set_verbosity(save) + return bam - return bam # No need to reload the samfile after creating index, thanks to the with statement except ValueError as ve: if "file has no sequences defined (mode='rb')" in str(ve): return None @@ -1076,6 +1079,7 @@ def filter_reference_BAM( logging.info( f"::: min_read_count >= {filter_conditions['min_read_count']} " f"& min_read_length >= {filter_conditions['min_read_length']} " + f"& max_read_length <= {filter_conditions['max_read_length']}" f"& min_avg_read_ani >= {filter_conditions['min_avg_read_ani']} " f"& min_expected_breadth_ratio >= {filter_conditions['min_expected_breadth_ratio']} " f"& min_breadth >= {filter_conditions['min_breadth']} " @@ -1096,6 +1100,7 @@ def filter_reference_BAM( df_filtered = df.loc[ (df["n_reads"] >= filter_conditions["min_read_count"]) & (df["read_length_mean"] >= filter_conditions["min_read_length"]) + & (df["read_length_mean"] <= filter_conditions["max_read_length"]) & (df["read_ani_mean"] >= filter_conditions["min_avg_read_ani"]) & ( df["breadth_exp_ratio"] @@ -1112,6 +1117,7 @@ def filter_reference_BAM( logging.info( f"::: min_read_count >= {filter_conditions['min_read_count']} " f"& min_read_length >= {filter_conditions['min_read_length']} " + f"& max_read_length <= {filter_conditions['max_read_length']}" f"& min_avg_read_ani >= {filter_conditions['min_avg_read_ani']} " f"& min_expected_breadth_ratio >= {filter_conditions['min_expected_breadth_ratio']} " f"& min_breadth >= {filter_conditions['min_breadth']} " @@ -1130,6 +1136,7 @@ def filter_reference_BAM( df_filtered = df.loc[ (df["n_reads"] >= filter_conditions["min_read_count"]) & (df["read_length_mean"] >= filter_conditions["min_read_length"]) + & (df["read_length_mean"] <= filter_conditions["max_read_length"]) & (df["read_ani_mean"] >= filter_conditions["min_avg_read_ani"]) & ( df["breadth_exp_ratio"] diff --git a/bam_filter/utils.py b/bam_filter/utils.py index c21f69d..731c690 100644 --- a/bam_filter/utils.py +++ b/bam_filter/utils.py @@ -232,6 +232,8 @@ def sort_keys_by_approx_weight( unit_scale=True, unit_divisor=500, leave=False, + ncols=80, + unit=" keys", ): min_chunk_idx = chunk_weights.index(min(chunk_weights)) chunks[min_chunk_idx].append(key) @@ -343,24 +345,32 @@ def is_integer(n): # function to check if the input value has K, M or G suffix in it + + def check_suffix(val, parser, var): if var == "--scale": units = ["K", "M"] else: units = ["K", "M", "G"] + unit = val[-1] - value = int(val[:-1]) + value = val[:-1] - if is_integer(value) & (unit in units) & (value > 0): + if is_integer(value) and (unit in units) and (int(value) > 0): + value = int(value) if var == "--scale": if unit == "K": val = value * 1000 elif unit == "M": val = value * 1000000 - elif unit == "G": - val = value * 1000000000 - return val + return str(val) else: + if unit == "K": + val = value * 1024 + elif unit == "M": + val = value * 1024 * 1024 + elif unit == 
"G": + val = value * 1024 * 1024 * 1024 return val else: parser.error( @@ -369,6 +379,26 @@ def check_suffix(val, parser, var): ) +# Example usage with argparse +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="Process some integers.") + parser.add_argument( + "--scale", + type=lambda x: check_suffix(x, parser, "--scale"), + help="Scale value with K or M suffix", + ) + parser.add_argument( + "--memory", + type=lambda x: check_suffix(x, parser, "--memory"), + help="Memory value with K, M, or G suffix", + ) + + args = parser.parse_args() + print("Scale:", args.scale) + print("Memory:", args.memory) + + def get_compression_type(filename): """ Attempts to guess the compression (if any) on a file using the first few bytes. @@ -483,6 +513,7 @@ def check_lca_ranks(val, parser, var): defaults = { "min_read_length": 30, + "max_read_length": np.Inf, "min_read_count": 3, "min_expected_breadth_ratio": 0, "min_norm_entropy": 0, @@ -524,6 +555,7 @@ def check_lca_ranks(val, parser, var): "threads": "Number of threads to use", "prefix": "Prefix used for the output files", "min_read_length": "Minimum read length", + "max_read_length": "Maximum read length", "min_read_count": "Minimum read count", "trim_ends": "Exclude n bases at the ends of the reference sequences", "trim_min": "Remove coverage that are below this percentile. Used for the Truncated Average Depth (TAD) calculation", @@ -576,6 +608,7 @@ def check_lca_ranks(val, parser, var): "lca_stats": "A TSV file from the filter subcommand", "custom": "Use custom taxdump files", "version": "Print program version", + "max_memory": "Maximum memory to use for the EM algorithm", } @@ -750,6 +783,19 @@ def get_arguments(argv=None): dest="min_read_length", help=help_msg["min_read_length"], ) + reassign_optional_args.add_argument( + "-L", + "--max-read-length", + type=lambda x: int( + check_values( + x, minval=1, maxval=np.Inf, parser=parser, var="--max-read-length" + ) + ), + default=defaults["max_read_length"], + metavar="INT", + dest="max_read_length", + help=help_msg["max_read_length"], + ) reassign_optional_args.add_argument( "-n", "--min-read-count", @@ -851,6 +897,15 @@ def get_arguments(argv=None): dest="sort_memory", help=help_msg["sort_memory"], ) + reassign_optional_args.add_argument( + "-M", + "--max-memory", + type=lambda x: check_suffix(x, parser=parser, var="--max-memory"), + default=None, + metavar="INT", + dest="max_memory", + help=help_msg["max_memory"], + ) reassign_optional_args.add_argument( "-N", "--sort-by-name", @@ -928,6 +983,19 @@ def get_arguments(argv=None): dest="min_read_length", help=help_msg["min_read_length"], ) + filtering_filt_args.add_argument( + "-L", + "--max-read-length", + type=lambda x: int( + check_values( + x, minval=1, maxval=np.Inf, parser=parser, var="--max-read-length" + ) + ), + default=defaults["max_read_length"], + metavar="INT", + dest="max_read_length", + help=help_msg["max_read_length"], + ) filtering_filt_args.add_argument( "-n", "--min-read-count", @@ -1451,6 +1519,7 @@ def create_output_files( else: log.error("Mode not recognized") exit(1) + out_files["tmp_dir"] = tmp_dir return out_files # out_files = { diff --git a/setup.py b/setup.py index 987f638..c5b7c10 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,7 @@ "taxopy>=0.12.0", "python-datatable>=1.1.3", "networkx>=3.2.1", + "psutil>=5.9.8", ] setup(