diff --git a/src/cutadapt/cli.py b/src/cutadapt/cli.py
index 152ba6a5..1ba357d7 100644
--- a/src/cutadapt/cli.py
+++ b/src/cutadapt/cli.py
@@ -94,7 +94,7 @@
 )
 from cutadapt.report import full_report, minimal_report, Statistics
 from cutadapt.pipeline import SingleEndPipeline, PairedEndPipeline
-from cutadapt.runners import run_pipeline
+from cutadapt.runners import make_runner
 from cutadapt.files import InputPaths, OutputFiles, FileOpener
 from cutadapt.steps import (
     InfoFileWriter,
@@ -1128,26 +1128,27 @@ def main(cmdlineargs, default_outfile=sys.stdout.buffer) -> Statistics:
         adapters, adapters2 = adapters_from_args(args)
         log_adapters(adapters, adapters2 if paired else None)
 
-        adapter_names: List[Optional[str]] = [a.name for a in adapters]
-        adapter_names2: List[Optional[str]] = [a.name for a in adapters2]
-        outfiles = open_output_files(
-            args,
-            default_outfile,
-            file_opener,
-            adapter_names,
-            adapter_names2,
-            proxied=cores > 1,
-        )
-        pipeline = make_pipeline_from_args(args, outfiles, paired, adapters, adapters2)
-        logger.info(
-            "Processing %s reads on %d core%s ...",
-            {False: "single-end", True: "paired-end"}[pipeline.paired],
-            cores,
-            "s" if cores > 1 else "",
-        )
-        stats = run_pipeline(
-            pipeline, input_paths, outfiles, cores, progress, args.buffer_size
-        )
+        with make_runner(input_paths, cores, args.buffer_size) as runner:
+            adapter_names: List[Optional[str]] = [a.name for a in adapters]
+            adapter_names2: List[Optional[str]] = [a.name for a in adapters2]
+            outfiles = open_output_files(
+                args,
+                default_outfile,
+                file_opener,
+                adapter_names,
+                adapter_names2,
+                proxied=cores > 1,
+            )
+            pipeline = make_pipeline_from_args(
+                args, outfiles, paired, adapters, adapters2
+            )
+            logger.info(
+                "Processing %s reads on %d core%s ...",
+                {False: "single-end", True: "paired-end"}[pipeline.paired],
+                cores,
+                "s" if cores > 1 else "",
+            )
+            stats = runner.run(pipeline, outfiles, progress)
     except KeyboardInterrupt:
         if args.debug:
             raise
diff --git a/src/cutadapt/runners.py b/src/cutadapt/runners.py
index 6d379de1..9bfa8e3f 100644
--- a/src/cutadapt/runners.py
+++ b/src/cutadapt/runners.py
@@ -7,14 +7,14 @@
 from abc import ABC, abstractmethod
 from contextlib import ExitStack
 from multiprocessing.connection import Connection
-from typing import Any, List, Optional, Tuple, Sequence, Iterator, TYPE_CHECKING, Union
+from typing import Any, List, Optional, Tuple, Sequence, Iterator, TYPE_CHECKING
 
 import dnaio
 
 from cutadapt.files import InputFiles, OutputFiles, InputPaths, xopen_rb_raise_limit
 from cutadapt.pipeline import Pipeline
 from cutadapt.report import Statistics
-from cutadapt.utils import Progress, DummyProgress
+from cutadapt.utils import Progress
 
 
 logger = logging.getLogger()
@@ -235,8 +235,11 @@ class PipelineRunner(ABC):
     """
 
     @abstractmethod
-    def run(self, pipeline, progress) -> Statistics:
-        pass
+    def run(self, pipeline, outfiles: OutputFiles, progress: Progress) -> Statistics:
+        """
+        progress: Use an object that supports .update() and .close() such
+        as DummyProgress, cutadapt.utils.Progress or a tqdm instance
+        """
 
     @abstractmethod
     def close(self):
@@ -275,14 +278,12 @@ class ParallelPipelineRunner(PipelineRunner):
     def __init__(
         self,
         inpaths: InputPaths,
-        outfiles: OutputFiles,
         n_workers: int,
         buffer_size: Optional[int] = None,
     ):
         self._n_workers = n_workers
         self._need_work_queue: multiprocessing.Queue = mpctx.Queue()
         self._buffer_size = 4 * 1024**2 if buffer_size is None else buffer_size
-        self._outfiles = outfiles
         self._inpaths = inpaths
         # the workers read from these connections
         connections = [mpctx.Pipe(duplex=False) for _ in range(self._n_workers)]
@@ -303,7 +304,9 @@ def __init__(
         self._reader_process.daemon = True
         self._reader_process.start()
 
-    def _start_workers(self, pipeline) -> Tuple[List[WorkerProcess], List[Connection]]:
+    def _start_workers(
+        self, pipeline, outfiles
+    ) -> Tuple[List[WorkerProcess], List[Connection]]:
         workers = []
         connections = []
         for index in range(self._n_workers):
@@ -313,7 +316,7 @@ def _start_workers(self, pipeline) -> Tuple[List[WorkerProcess], List[Connection
                 index,
                 pipeline,
                 self._inpaths,
-                self._outfiles,
+                outfiles,
                 self._connections[index],
                 conn_w,
                 self._need_work_queue,
@@ -323,10 +326,10 @@ def _start_workers(self, pipeline) -> Tuple[List[WorkerProcess], List[Connection
             workers.append(worker)
         return workers, connections
 
-    def run(self, pipeline, progress) -> Statistics:
-        workers, connections = self._start_workers(pipeline)
+    def run(self, pipeline, outfiles: OutputFiles, progress) -> Statistics:
+        workers, connections = self._start_workers(pipeline, outfiles)
         writers = []
-        for f in self._outfiles:
+        for f in outfiles:
             writers.append(OrderedChunkWriter(f))
         stats = Statistics()
         while connections:
@@ -351,6 +354,7 @@ def run(self, pipeline, progress) -> Statistics:
             w.join()
         self._reader_process.join()
         progress.close()
+        outfiles.close()
         return stats
 
     @staticmethod
@@ -372,7 +376,7 @@ def _try_receive(connection):
         return result
 
     def close(self) -> None:
-        self._outfiles.close()
+        pass
 
 
 class SerialPipelineRunner(PipelineRunner):
@@ -383,15 +387,15 @@ class SerialPipelineRunner(PipelineRunner):
     def __init__(
         self,
         infiles: InputFiles,
-        outfiles: OutputFiles,
     ):
         self._infiles = infiles
-        self._outfiles = outfiles
 
-    def run(self, pipeline: Pipeline, progress: Progress) -> Statistics:
+    def run(
+        self, pipeline: Pipeline, outfiles: OutputFiles, progress: Progress
+    ) -> Statistics:
         try:
             (n, total1_bp, total2_bp) = pipeline.process_reads(
-                self._infiles, self._outfiles, progress=progress
+                self._infiles, outfiles, progress=progress
             )
         finally:
             pipeline.close()
@@ -403,17 +407,14 @@ def run(self, pipeline: Pipeline, progress: Progress) -> Statistics:
         return Statistics().collect(n, total1_bp, total2_bp, modifiers, pipeline._steps)
 
     def close(self):
-        pass
+        self._infiles.close()
 
 
-def run_pipeline(
-    pipeline: Pipeline,
+def make_runner(
     inpaths: InputPaths,
-    outfiles: OutputFiles,
     cores: int,
-    progress: Union[bool, Progress, None] = None,
     buffer_size: Optional[int] = None,
-) -> Statistics:
+) -> PipelineRunner:
     """
     Run a pipeline.
 
@@ -421,31 +422,18 @@
 
     Args:
         inpaths:
-        outfiles:
         cores: number of cores to run the pipeline on (this is actually the number of worker
             processes, there will be one extra process for reading the input file(s))
-        progress: Set to False for no progress bar, True for Cutadapt’s default progress bar,
-            or use an object that supports .update() and .close() (e.g. a tqdm instance)
         buffer_size: Forwarded to `ParallelPipelineRunner()`. Ignored if cores is 1.
-
-    Returns:
-        A Statistics object
     """
-    if progress is None or progress is False:
-        progress = DummyProgress()
-    elif progress is True:
-        progress = Progress()
     runner: PipelineRunner
     if cores > 1:
         runner = ParallelPipelineRunner(
             inpaths,
-            outfiles,
             n_workers=cores,
             buffer_size=buffer_size,
         )
     else:
-        runner = SerialPipelineRunner(inpaths.open(), outfiles)
+        runner = SerialPipelineRunner(inpaths.open())
 
-    with runner:
-        statistics = runner.run(pipeline, progress)
-        return statistics
+    return runner
diff --git a/tests/test_api.py b/tests/test_api.py
index 5850eaf9..f9e19369 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -10,8 +10,9 @@
 import json
 import os
 
-from cutadapt.runners import run_pipeline
+from cutadapt.runners import make_runner
 from cutadapt.steps import InfoFileWriter, PairedSingleEndStep
+from cutadapt.utils import DummyProgress
 
 from utils import datapath
 
@@ -64,7 +65,9 @@ def test_pipeline_single(tmp_path, cores):
     pipeline.minimum_length = (10,)
     pipeline.discard_untrimmed = True
     inpaths = InputPaths(datapath("small.fastq"))
-    stats = run_pipeline(pipeline, inpaths, outfiles, cores=cores)
+    runner = make_runner(inpaths, cores)
+    with runner:
+        stats = runner.run(pipeline, outfiles, DummyProgress())
     assert stats is not None
     assert info_path.exists()
     json.dumps(stats.as_json())
@@ -115,7 +118,9 @@ def test_pipeline_paired(tmp_path, cores):
     pipeline.minimum_length = (10, None)
     pipeline.discard_untrimmed = True
 
-    stats = run_pipeline(pipeline, inpaths, outfiles, cores=cores, progress=True)
+    runner = make_runner(inpaths, cores=cores)
+    with runner:
+        stats = runner.run(pipeline, outfiles, DummyProgress())
     assert stats is not None
     assert info_path.exists()
     _ = stats.as_json()
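Taken together, the diff splits the old run_pipeline() call in two: make_runner() selects and opens a SerialPipelineRunner or ParallelPipelineRunner depending on cores, and runner.run() then takes the pipeline, the OutputFiles and an explicit progress object and returns a Statistics object. The sketch below is illustrative only and mirrors the updated tests; the input path is a placeholder, and pipeline and outfiles are assumed to be constructed as in tests/test_api.py.

from cutadapt.files import InputPaths
from cutadapt.runners import make_runner
from cutadapt.utils import DummyProgress

inpaths = InputPaths("reads.fastq")  # placeholder path; see tests/test_api.py
# `pipeline` and `outfiles` are assumed to be built as in test_pipeline_single().
with make_runner(inpaths, cores=4) as runner:  # cores > 1 selects ParallelPipelineRunner
    stats = runner.run(pipeline, outfiles, DummyProgress())
print(stats.as_json())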