diff --git a/anglerfish/anglerfish.py b/anglerfish/anglerfish.py index 006c158..4638ef6 100755 --- a/anglerfish/anglerfish.py +++ b/anglerfish/anglerfish.py @@ -3,6 +3,7 @@ import glob import gzip import logging +import multiprocessing import os import uuid from collections import Counter @@ -27,6 +28,8 @@ def run_demux(args): + multiprocessing.set_start_method("spawn") + if args.debug: log.setLevel(logging.DEBUG) run_uuid = str(uuid.uuid4()) @@ -151,6 +154,7 @@ def run_demux(args): adaptors_sorted[key], fragments, args.max_distance ) + out_pool = [] for k, v in groupby(sorted(matches, key=lambda x: x[3]), key=lambda y: y[3]): # To avoid collisions in fastq filenames, we add the ONT barcode to the sample name fq_prefix = k @@ -179,7 +183,21 @@ def run_demux(args): ) report.add_sample_stat(sample_stat) if not args.skip_demux: - write_demuxedfastq(sample_dict, fastq_path, fq_name) + out_pool.append((sample_dict, fastq_path, fq_name)) + + # Write demuxed fastq files + pool = multiprocessing.Pool(processes=args.threads) + results = [] + for out in out_pool: + log.debug(f" Writing {out[2]}") + spawn = pool.starmap_async(write_demuxedfastq, [out]) + results.append((spawn, out[2])) + pool.close() + pool.join() + for result in results: + log.debug( + f" PID-{result[0].get()}: wrote {result[1]}, size {os.path.getsize(result[1])} bytes" + ) # Top unmatched indexes nomatch_count = Counter([x[3] for x in no_matches]) @@ -224,7 +242,11 @@ def anglerfish(): help="Analysis output folder (default: Current dir)", ) parser.add_argument( - "--threads", "-t", default=4, help="Number of threads to use (default: 4)" + "--threads", + "-t", + default=4, + type=int, + help="Number of threads to use (default: 4)", ) parser.add_argument( "--skip_demux", diff --git a/anglerfish/demux/demux.py b/anglerfish/demux/demux.py index f6220ea..5abb564 100644 --- a/anglerfish/demux/demux.py +++ b/anglerfish/demux/demux.py @@ -233,8 +233,11 @@ def cluster_matches( def write_demuxedfastq(beds, fastq_in, fastq_out): """ + Intended for multiprocessing Take a set of coordinates in bed format [[seq1, start, end, ..][seq2, ..]] from over a set of fastq entries in the input files and do extraction. + + Return: PID of the process """ gz_buf = 131072 fq_files = glob.glob(fastq_in) @@ -263,4 +266,5 @@ def write_demuxedfastq(beds, fastq_in, fastq_out): outfqs += "+\n" outfqs += f"{qual[bed[1] : bed[2]]}\n" oz.stdin.write(outfqs.encode("utf-8")) - log.debug(f" Wrote {fastq_out}, size: {os.path.getsize(fastq_out)} bytes") + + return os.getpid()