Commit 0e89557

Balanced IO and processing threads
genomewalker committed May 8, 2024
1 parent 683d57b commit 0e89557
Showing 3 changed files with 164 additions and 105 deletions.
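
The core of the change is a new `allocate_threads` helper that splits the global thread budget into a processing share (`p_threads`) and an htslib IO share (`s_threads`), replacing the fixed `min(4, threads)` caps scattered through `reassign.py`. The helper itself is imported from the package's utilities and its body is not part of this diff; the sketch below is a hypothetical reading of its semantics, where the second and third arguments bound the IO share:

```python
# Hypothetical sketch of allocate_threads; the real implementation is imported
# from bam_filter's utilities and is not shown in this commit. Assumed
# semantics: reserve between min_io and max_io threads for pysam/htslib IO,
# hand the remainder to worker processes, and never let either side be zero.
def allocate_threads(threads, min_io=2, max_io=4):
    if threads <= min_io:
        # Too small a budget to split: IO and processing share all threads.
        return max(1, threads), max(1, threads)
    io_threads = max(min_io, min(max_io, threads // 4))
    p_threads = max(1, threads - io_threads)
    return p_threads, io_threads
```

Under this sketch, `allocate_threads(16, 2, 4)` returns `(12, 4)`, which is consistent with the `p_threads, s_threads = allocate_threads(threads, 2, 4)` call sites in the hunks below.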
164 changes: 89 additions & 75 deletions bam_filter/reassign.py
@@ -13,6 +13,7 @@
     check_tmp_dir_exists,
     handle_warning,
     create_output_files,
+    allocate_threads,
 )
 from multiprocessing import Pool, Manager
 from functools import partial
@@ -240,7 +241,7 @@ def write_to_file(alns, out_bam_file, header=None):
 
 def process_references_batch(references, entries, bam, refs_idx, threads=4):
     alns = []
-    s_threads = min(4, threads)
+    s_threads = threads
     with pysam.AlignmentFile(bam, "rb", threads=s_threads) as samfile:
         for reference in references:
             r_ids = entries[reference]
@@ -271,7 +272,8 @@ def write_reassigned_bam(
     # else:
     #     out_bam = out_files["bam_reassigned_sorted"]
     out_bam = out_files["bam_reassigned"]
-    s_threads = min(4, threads)
+    p_threads, s_threads = allocate_threads(threads, 2, 4)
+    # s_threads = min(4, threads)
     with pysam.AlignmentFile(bam, "rb", threads=s_threads) as samfile:
         references = list(entries.keys())
         refs_dict = {x: samfile.get_reference_length(x) for x in references}
@@ -281,7 +283,7 @@ def write_reassigned_bam(
     log.info("::: Getting reference names and lengths...")
     (ref_names, ref_lengths) = zip(*refs_dict.items())
     refs_idx = {sys.intern(str(x)): i for i, x in enumerate(ref_names)}
-    write_threads = min(4, threads)
+    write_threads = s_threads
 
     new_header = header.to_dict()
 
@@ -293,80 +295,86 @@ def write_reassigned_bam(
     new_header["SQ"].sort(key=lambda x: name_index[x["SN"]])
     new_header["HD"]["SO"] = "unsorted"
 
-    out_bam_file = pysam.AlignmentFile(
-        out_files["bam_reassigned_tmp"],
-        "wb",
-        referencenames=list(ref_names),
-        referencelengths=list(ref_lengths),
-        threads=write_threads,
-        header=new_header,
-    )
-
-    # num_cores should be multiple of the write_threads
-    num_cores = threads // write_threads
-    # num_cores = min(threads, cpu_count())
-    # batch_size = len(references) // num_cores + 1 # Ensure non-zero batch size
-    # batch_size = calc_chunksize(n_workers=num_cores, len_iterable=len(references))
-    log.info("::: Creating reference chunks with uniform read amounts...")
-
-    ref_chunks = sort_keys_by_approx_weight(
-        input_dict=ref_counts,
-        scale=1,
-        num_cores=num_cores,
-        verbose=False,
-        max_entries_per_chunk=1_000_000,
-    )
-
-    num_cores = min(num_cores, len(ref_chunks))
-    log.info(f"::: Using {num_cores} processes to write {len(ref_chunks)} chunk(s)")
-
-    with Manager() as manager:
-        # Use Manager to create a read-only proxy for the dictionary
-        entries = manager.dict(dict(entries))
-
-        with concurrent.futures.ProcessPoolExecutor(max_workers=num_cores) as executor:
-            # Use ProcessPoolExecutor to parallelize the processing of references in batches
-            futures = []
-            for batch_references in tqdm.tqdm(
-                ref_chunks,
-                total=len(ref_chunks),
-                desc="Submitted batches",
-                unit="batch",
-                leave=False,
-                ncols=80,
-                disable=is_debug(),
-            ):
-                future = executor.submit(
-                    process_references_batch, batch_references, entries, bam, refs_idx
-                )
-                futures.append(future)  # Store the future
-
-            # Use a while loop to continuously check for completed futures
-            log.info("::: Collecting batches...")
-
-            completion_progress_bar = tqdm.tqdm(
-                total=len(futures),
-                desc="Completed",
-                unit="batch",
-                leave=False,
-                ncols=80,
-                disable=is_debug(),
-            )
-            completed_count = 0
-
-            # Use as_completed to iterate over completed futures as they become available
-            for completed_future in concurrent.futures.as_completed(futures):
-                alns = completed_future.result()
-                write_to_file(alns=alns, out_bam_file=out_bam_file, header=header)
-
-                # Update the progress bar for each completed write
-                completion_progress_bar.update(1)
-                completed_count += 1
-                completed_future.cancel()  # Cancel the future to free memory
-                gc.collect()  # Force garbage collection
-
-            completion_progress_bar.close()
-    out_bam_file.close()
+    with pysam.AlignmentFile(
+        out_files["bam_reassigned_tmp"],
+        "wb",
+        referencenames=list(ref_names),
+        referencelengths=list(ref_lengths),
+        threads=write_threads,
+        header=new_header,
+    ) as out_bam_file:
+
+        # num_cores should be multiple of the write_threads
+        num_cores = p_threads
+        # num_cores = min(threads, cpu_count())
+        # batch_size = len(references) // num_cores + 1 # Ensure non-zero batch size
+        # batch_size = calc_chunksize(n_workers=num_cores, len_iterable=len(references))
+        log.info("::: Creating reference chunks with uniform read amounts...")
+
+        ref_chunks = sort_keys_by_approx_weight(
+            input_dict=ref_counts,
+            scale=1,
+            num_cores=num_cores,
+            verbose=False,
+            max_entries_per_chunk=1_000_000,
+        )
+
+        # num_cores = min(num_cores, len(ref_chunks))
+        log.info(f"::: Using {num_cores} processes to write {len(ref_chunks)} chunk(s)")
+
+        with Manager() as manager:
+            # Use Manager to create a read-only proxy for the dictionary
+            entries = manager.dict(dict(entries))
+
+            with concurrent.futures.ProcessPoolExecutor(
+                max_workers=num_cores
+            ) as executor:
+                # Use ProcessPoolExecutor to parallelize the processing of references in batches
+                futures = []
+                for batch_references in tqdm.tqdm(
+                    ref_chunks,
+                    total=len(ref_chunks),
+                    desc="Submitted batches",
+                    unit="batch",
+                    leave=False,
+                    ncols=80,
+                    disable=is_debug(),
+                ):
+                    future = executor.submit(
+                        process_references_batch,
+                        batch_references,
+                        entries,
+                        bam,
+                        refs_idx,
+                        s_threads,
+                    )
+                    futures.append(future)  # Store the future
+
+                # Use a while loop to continuously check for completed futures
+                log.info("::: Collecting batches...")
+
+                completion_progress_bar = tqdm.tqdm(
+                    total=len(futures),
+                    desc="Completed",
+                    unit="batch",
+                    leave=False,
+                    ncols=80,
+                    disable=is_debug(),
+                )
+                completed_count = 0
+
+                # Use as_completed to iterate over completed futures as they become available
+                for completed_future in concurrent.futures.as_completed(futures):
+                    alns = completed_future.result()
+                    write_to_file(alns=alns, out_bam_file=out_bam_file, header=header)
+
+                    # Update the progress bar for each completed write
+                    completion_progress_bar.update(1)
+                    completed_count += 1
+                    completed_future.cancel()  # Cancel the future to free memory
+                    gc.collect()  # Force garbage collection
+
+                completion_progress_bar.close()
     entries = None
     gc.collect()
     # prof.disable()
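
The chunking above leans on `sort_keys_by_approx_weight`, which is untouched by this commit, so its body never appears in the diff. A hedged sketch of the greedy weighted binning it presumably performs (parameter names copied from the call site; the logic is illustrative only):

```python
# Illustrative only: pack references into roughly num_cores chunks so each
# chunk carries a similar total read count, with a cap on entries per chunk.
# The real bam_filter helper may treat scale, ordering, and verbose differently.
def sort_keys_by_approx_weight(
    input_dict, scale=1, num_cores=1, verbose=False, max_entries_per_chunk=1_000_000
):
    chunks = [[] for _ in range(num_cores)]
    weights = [0] * num_cores
    # Heaviest references first keeps the bins balanced as they fill.
    for ref, count in sorted(input_dict.items(), key=lambda kv: kv[1], reverse=True):
        i = min(range(len(chunks)), key=lambda j: weights[j])
        if len(chunks[i]) >= max_entries_per_chunk:
            # The least-loaded bin is full: open a fresh one.
            chunks.append([])
            weights.append(0)
            i = len(chunks) - 1
        chunks[i].append(ref)
        weights[i] += count * scale
    return [chunk for chunk in chunks if chunk]
```

Because the number of chunks can end up below `p_threads`, the old code clamped with `num_cores = min(num_cores, len(ref_chunks))`; this commit comments that guard out and sizes the pool directly from `p_threads`.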
@@ -375,13 +383,13 @@ def write_reassigned_bam(
     # stats.print_stats(5) # top 10 rows
     if not disable_sort:
         log.info("::: ::: Sorting BAM file...")
-        s_threads = min(4, threads)
+        w_threads = max(4, s_threads)
         if sort_by_name:
             log.info("::: ::: Sorting by name...")
             pysam.sort(
                 "-n",
                 "-@",
-                str(s_threads),
+                str(w_threads),
                 "-m",
                 str(sort_memory),
                 "-o",
@@ -391,7 +399,7 @@ def write_reassigned_bam(
         else:
             pysam.sort(
                 "-@",
-                str(s_threads),
+                str(w_threads),
                 "-m",
                 str(sort_memory),
                 "-o",
@@ -401,7 +409,6 @@ def write_reassigned_bam(
 
     logging.info("BAM index not found. Indexing...")
 
-    s_threads = min(4, threads)
     pysam.index(
         "-c",
         "-@",
@@ -466,7 +473,7 @@ def get_bam_data(
     dt.options.progress.enabled = False
     dt.options.progress.clear_on_success = True
     dt.options.nthreads = max(1, threads - 1)
-    s_threads = min(4, threads)
+    s_threads = threads
 
     with pysam.AlignmentFile(bam, "rb", threads=s_threads) as samfile:
         results = []
@@ -560,16 +567,23 @@ def reassign_reads(
     sort_memory="4G",
     disable_sort=False,
 ):
 
+    p_threads, s_threads = allocate_threads(threads, 2, 4)
     dt.options.progress.enabled = True
     dt.options.progress.clear_on_success = True
     if threads > 1:
-        dt.options.nthreads = threads - 1
+        dt.options.nthreads = p_threads
     else:
         dt.options.nthreads = 1
 
     log.info("::: Loading BAM file")
     save = pysam.set_verbosity(0)
-    s_threads = min(4, threads)
+    # we need to separate the threads for reading the BAM and process in parallel.
+    # At least 2 threads are needed for reading the BAM file and the rest for multiprocessing
+
+    # s_threads = min(4, threads)
+
+    log.info(f"::: IO Threads: {s_threads} | Processing Threads: {p_threads}")
+
     with pysam.AlignmentFile(bam, "rb", threads=s_threads) as samfile:
         references = samfile.references
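
Taken together, the pattern in `reassign_reads` is now: open the BAM with the IO share and fan the heavy work out to a pool sized by the processing share. A condensed, illustrative sketch of that pattern (not the literal function; `count_alignments` and `run` are hypothetical names, and `allocate_threads` is the helper sketched earlier):

```python
# Condensed illustration of the IO/processing split; not the literal
# reassign_reads body. Assumes allocate_threads as sketched above.
import pysam
from multiprocessing import Pool

def count_alignments(args):
    bam, reference, s_threads = args
    # Each worker reopens the BAM with its own htslib IO threads; counting
    # by contig requires the BAM index to be present.
    with pysam.AlignmentFile(bam, "rb", threads=s_threads) as samfile:
        return reference, samfile.count(contig=reference)

def run(bam, threads):
    p_threads, s_threads = allocate_threads(threads, 2, 4)
    with pysam.AlignmentFile(bam, "rb", threads=s_threads) as samfile:
        references = samfile.references  # cheap metadata pass on IO threads
    with Pool(p_threads) as p:  # heavy per-reference work in worker processes
        work = [(bam, r, s_threads) for r in references]
        return dict(p.map(count_alignments, work))
```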
@@ -672,7 +686,7 @@ def reassign_reads(
         )
     else:
         p = Pool(
-            threads,
+            p_threads,
         )
         data = list(
             tqdm.tqdm(
@@ -688,7 +702,7 @@ def reassign_reads(
                     gap_extension_penalty=gap_extension_penalty,
                     lambda_value=lambda_value,
                     K_value=K_value,
-                    threads=4,
+                    threads=s_threads,
                 ),
                 parms,
                 chunksize=1,