run_pipeline → make_runner
Create the runner before building the actual pipeline. This also starts the
ReaderProcess earlier so that we can get the input file format.
marcelm committed Jan 28, 2024
1 parent 2436abe commit 3e6a428
Showing 3 changed files with 55 additions and 61 deletions.
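The commit changes the API from a single run_pipeline() call to a two-step pattern: build a runner from the input paths first, then hand it the pipeline and output files. A minimal usage sketch, assuming that pipeline and outfiles have been constructed elsewhere (e.g. via make_pipeline_from_args and open_output_files, as in cli.py and tests/test_api.py); the input file name is illustrative:

    from cutadapt.files import InputPaths
    from cutadapt.runners import make_runner
    from cutadapt.utils import DummyProgress

    inpaths = InputPaths("reads.fastq")  # illustrative input file
    with make_runner(inpaths, cores=4) as runner:
        # pipeline and outfiles are assumed to be built as in cli.py or
        # tests/test_api.py; the runner handles serial vs. parallel execution
        stats = runner.run(pipeline, outfiles, DummyProgress())

Creating the runner before the pipeline means that, in the multi-core case, the ReaderProcess is already running while the pipeline is being assembled, so the detected input file format is available at that point.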
43 changes: 22 additions & 21 deletions src/cutadapt/cli.py
@@ -94,7 +94,7 @@
)
from cutadapt.report import full_report, minimal_report, Statistics
from cutadapt.pipeline import SingleEndPipeline, PairedEndPipeline
from cutadapt.runners import run_pipeline
from cutadapt.runners import make_runner
from cutadapt.files import InputPaths, OutputFiles, FileOpener
from cutadapt.steps import (
InfoFileWriter,
@@ -1128,26 +1128,27 @@ def main(cmdlineargs, default_outfile=sys.stdout.buffer) -> Statistics:
adapters, adapters2 = adapters_from_args(args)
log_adapters(adapters, adapters2 if paired else None)

adapter_names: List[Optional[str]] = [a.name for a in adapters]
adapter_names2: List[Optional[str]] = [a.name for a in adapters2]
outfiles = open_output_files(
args,
default_outfile,
file_opener,
adapter_names,
adapter_names2,
proxied=cores > 1,
)
pipeline = make_pipeline_from_args(args, outfiles, paired, adapters, adapters2)
logger.info(
"Processing %s reads on %d core%s ...",
{False: "single-end", True: "paired-end"}[pipeline.paired],
cores,
"s" if cores > 1 else "",
)
stats = run_pipeline(
pipeline, input_paths, outfiles, cores, progress, args.buffer_size
)
with make_runner(input_paths, cores, args.buffer_size) as runner:
adapter_names: List[Optional[str]] = [a.name for a in adapters]
adapter_names2: List[Optional[str]] = [a.name for a in adapters2]
outfiles = open_output_files(
args,
default_outfile,
file_opener,
adapter_names,
adapter_names2,
proxied=cores > 1,
)
pipeline = make_pipeline_from_args(
args, outfiles, paired, adapters, adapters2
)
logger.info(
"Processing %s reads on %d core%s ...",
{False: "single-end", True: "paired-end"}[pipeline.paired],
cores,
"s" if cores > 1 else "",
)
stats = runner.run(pipeline, outfiles, progress)
except KeyboardInterrupt:
if args.debug:
raise
62 changes: 25 additions & 37 deletions src/cutadapt/runners.py
@@ -7,14 +7,14 @@
from abc import ABC, abstractmethod
from contextlib import ExitStack
from multiprocessing.connection import Connection
from typing import Any, List, Optional, Tuple, Sequence, Iterator, TYPE_CHECKING, Union
from typing import Any, List, Optional, Tuple, Sequence, Iterator, TYPE_CHECKING

import dnaio

from cutadapt.files import InputFiles, OutputFiles, InputPaths, xopen_rb_raise_limit
from cutadapt.pipeline import Pipeline
from cutadapt.report import Statistics
from cutadapt.utils import Progress, DummyProgress
from cutadapt.utils import Progress

logger = logging.getLogger()

@@ -235,8 +235,11 @@ class PipelineRunner(ABC):
"""

@abstractmethod
def run(self, pipeline, progress) -> Statistics:
pass
def run(self, pipeline, outfiles: OutputFiles, progress: Progress) -> Statistics:
"""
progress: Use an object that supports .update() and .close() such
as DummyProgress, cutadapt.utils.Progress or a tqdm instance
"""

@abstractmethod
def close(self):
@@ -275,14 +278,12 @@ class ParallelPipelineRunner(PipelineRunner):
def __init__(
self,
inpaths: InputPaths,
outfiles: OutputFiles,
n_workers: int,
buffer_size: Optional[int] = None,
):
self._n_workers = n_workers
self._need_work_queue: multiprocessing.Queue = mpctx.Queue()
self._buffer_size = 4 * 1024**2 if buffer_size is None else buffer_size
self._outfiles = outfiles
self._inpaths = inpaths
# the workers read from these connections
connections = [mpctx.Pipe(duplex=False) for _ in range(self._n_workers)]
@@ -303,7 +304,9 @@ def __init__(
self._reader_process.daemon = True
self._reader_process.start()

def _start_workers(self, pipeline) -> Tuple[List[WorkerProcess], List[Connection]]:
def _start_workers(
self, pipeline, outfiles
) -> Tuple[List[WorkerProcess], List[Connection]]:
workers = []
connections = []
for index in range(self._n_workers):
@@ -313,7 +316,7 @@ def _start_workers(self, pipeline) -> Tuple[List[WorkerProcess], List[Connection
index,
pipeline,
self._inpaths,
self._outfiles,
outfiles,
self._connections[index],
conn_w,
self._need_work_queue,
@@ -323,10 +326,10 @@ def _start_workers(self, pipeline) -> Tuple[List[WorkerProcess], List[Connection
workers.append(worker)
return workers, connections

def run(self, pipeline, progress) -> Statistics:
workers, connections = self._start_workers(pipeline)
def run(self, pipeline, outfiles: OutputFiles, progress) -> Statistics:
workers, connections = self._start_workers(pipeline, outfiles)
writers = []
for f in self._outfiles:
for f in outfiles:
writers.append(OrderedChunkWriter(f))
stats = Statistics()
while connections:
@@ -351,6 +354,7 @@ def run(self, pipeline, progress) -> Statistics:
w.join()
self._reader_process.join()
progress.close()
outfiles.close()
return stats

@staticmethod
@@ -372,7 +376,7 @@ def _try_receive(connection):
return result

def close(self) -> None:
self._outfiles.close()
pass


class SerialPipelineRunner(PipelineRunner):
@@ -383,15 +387,15 @@ class SerialPipelineRunner(PipelineRunner):
def __init__(
self,
infiles: InputFiles,
outfiles: OutputFiles,
):
self._infiles = infiles
self._outfiles = outfiles

def run(self, pipeline: Pipeline, progress: Progress) -> Statistics:
def run(
self, pipeline: Pipeline, outfiles: OutputFiles, progress: Progress
) -> Statistics:
try:
(n, total1_bp, total2_bp) = pipeline.process_reads(
self._infiles, self._outfiles, progress=progress
self._infiles, outfiles, progress=progress
)
finally:
pipeline.close()
@@ -403,49 +407,33 @@ def run(self, pipeline: Pipeline, progress: Progress) -> Statistics:
return Statistics().collect(n, total1_bp, total2_bp, modifiers, pipeline._steps)

def close(self):
pass
self._infiles.close()


def run_pipeline(
pipeline: Pipeline,
def make_runner(
inpaths: InputPaths,
outfiles: OutputFiles,
cores: int,
progress: Union[bool, Progress, None] = None,
buffer_size: Optional[int] = None,
) -> Statistics:
) -> PipelineRunner:
"""
Run a pipeline.
This uses a SerialPipelineRunner if cores is 1 and a ParallelPipelineRunner otherwise.
Args:
inpaths:
outfiles:
cores: number of cores to run the pipeline on (this is actually the number of worker
processes, there will be one extra process for reading the input file(s))
progress: Set to False for no progress bar, True for Cutadapt’s default progress bar,
or use an object that supports .update() and .close() (e.g. a tqdm instance)
buffer_size: Forwarded to `ParallelPipelineRunner()`. Ignored if cores is 1.
Returns:
A Statistics object
"""
if progress is None or progress is False:
progress = DummyProgress()
elif progress is True:
progress = Progress()
runner: PipelineRunner
if cores > 1:
runner = ParallelPipelineRunner(
inpaths,
outfiles,
n_workers=cores,
buffer_size=buffer_size,
)
else:
runner = SerialPipelineRunner(inpaths.open(), outfiles)
runner = SerialPipelineRunner(inpaths.open())

with runner:
statistics = runner.run(pipeline, progress)
return statistics
return runner
11 changes: 8 additions & 3 deletions tests/test_api.py
@@ -10,8 +10,9 @@
import json
import os

from cutadapt.runners import run_pipeline
from cutadapt.runners import make_runner
from cutadapt.steps import InfoFileWriter, PairedSingleEndStep
from cutadapt.utils import DummyProgress
from utils import datapath


@@ -64,7 +65,9 @@ def test_pipeline_single(tmp_path, cores):
pipeline.minimum_length = (10,)
pipeline.discard_untrimmed = True
inpaths = InputPaths(datapath("small.fastq"))
stats = run_pipeline(pipeline, inpaths, outfiles, cores=cores)
runner = make_runner(inpaths, cores)
with runner:
stats = runner.run(pipeline, outfiles, DummyProgress())
assert stats is not None
assert info_path.exists()
json.dumps(stats.as_json())
@@ -115,7 +118,9 @@ def test_pipeline_paired(tmp_path, cores):
pipeline.minimum_length = (10, None)
pipeline.discard_untrimmed = True

stats = run_pipeline(pipeline, inpaths, outfiles, cores=cores, progress=True)
runner = make_runner(inpaths, cores=cores)
with runner:
stats = runner.run(pipeline, outfiles, DummyProgress())
assert stats is not None
assert info_path.exists()
_ = stats.as_json()
