limit memory usage in read filtering
jykr committed Oct 5, 2023
1 parent 6dd4c1f commit 32cc61a
Showing 2 changed files with 73 additions and 61 deletions.
130 changes: 71 additions & 59 deletions bean/mapping/GuideEditCounter.py
@@ -190,17 +190,17 @@ def check_filter_fastq(self):
             os.path.basename(self.R2_filename).replace(".fastq", "").replace(".gz", "")
             + "_filtered.fastq.gz"
         )
-        if path.exists(self.filtered_R1_filename) and path.exists(
-            self.filtered_R2_filename
-        ):
-            warn("Using preexisting filtered file")
-        else:
-            self._check_names_filter_fastq()
+        self._check_names_filter_fastq()
+        # if (
+        #     path.exists(self.filtered_R1_filename)
+        #     and path.exists(self.filtered_R2_filename)
+        #     and self.reuse_filtered_reads
+        # ):
+        #     warn("Using preexisting filtered file")
+        # else:
+        #     self._check_names_filter_fastq()
 
     def get_counts(self):
         # infile_R1 = _get_fastq_handle(self.R1_filename)
         # infile_R2 = _get_fastq_handle(self.R2_filename)
 
         self.nomatch_R1_filename = self.R1_filename.replace(".fastq", "_nomatch.fastq")
         self.nomatch_R2_filename = self.R2_filename.replace(".fastq", "_nomatch.fastq")
         self.semimatch_R1_filename = self.R1_filename.replace(
@@ -493,7 +493,9 @@ def _get_guide_counts_bcmatch_semimatch(
         self, bcmatch_layer="X_bcmatch", semimatch_layer="X"
     ):
         self.screen.layers[semimatch_layer] = np.zeros_like((self.screen.X))
-        R1_iter, R2_iter = self._get_fastq_iterators()
+        R1_iter, R2_iter = self._get_fastq_iterators(
+            self.filtered_R1_filename, self.filtered_R2_filename
+        )
         if self.keep_intermediate:
             outfile_R1_nomatch, outfile_R2_nomatch = self._get_fastq_handle("nomatch")
             outfile_R1_semimatch, outfile_R2_semimatch = self._get_fastq_handle(
@@ -712,23 +714,23 @@ def _get_fastq_handle(
 
         return (R1_handle, R2_handle)
 
-    def _get_fastq_iterators(self):
-        R1_handle = _get_fastq_handle(self.R1_filename)
-        R2_handle = _get_fastq_handle(self.R2_filename)
+    def _get_fastq_iterators(self, R1_filename=None, R2_filename=None):
+        if R1_filename is None:
+            R1_filename = self.R1_filename
+        if R2_filename is None:
+            R2_filename = self.R2_filename
+        R1_handle = _get_fastq_handle(R1_filename)
+        R2_handle = _get_fastq_handle(R2_filename)
 
         R1_iterator = FastqPhredIterator(R1_handle)
         R2_iterator = FastqPhredIterator(R2_handle)
 
         return (R1_iterator, R2_iterator)
 
-    def _get_seq_records(self):
+    def _get_seq_handles(self):
         R1_handle = _get_fastq_handle(self.R1_filename)
         R2_handle = _get_fastq_handle(self.R2_filename)
-        R1 = list(SeqIO.parse(R1_handle, "fastq"))
-        R2 = list(SeqIO.parse(R2_handle, "fastq"))
-        R1_handle.close()
-        R2_handle.close()
-        return (R1, R2)
+        return (R1_handle, R2_handle)
 
     def _check_names_filter_fastq(self, filter_by_qual=False):
         if self.min_average_read_quality > 0 or self.min_single_bp_quality > 0:
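Note on the hunk above: `_get_fastq_iterators` now takes optional filenames, falling back to the raw `self.R1_filename`/`self.R2_filename`, which is what lets `_get_guide_counts_bcmatch_semimatch` stream the filtered files instead of re-reading the raw inputs. A minimal sketch of the call pattern, assuming `counter` stands in for a constructed GuideEditCounter instance (the loop body is a placeholder):

# Sketch only; `counter` is a hypothetical GuideEditCounter instance.
# No arguments -> iterate the raw R1/R2; filtered paths -> iterate the
# quality-filtered copies written during filtering.
R1_iter, R2_iter = counter._get_fastq_iterators(
    counter.filtered_R1_filename, counter.filtered_R2_filename
)
for R1_record, R2_record in zip(R1_iter, R2_iter):
    ...  # one read pair held in memory at a time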
@@ -741,52 +743,62 @@ def _check_names_filter_fastq(self, filter_by_qual=False):
                 f"In the filtering, bases up to position {self.qend_R1} of R1 and {self.qend_R2} of R2 are considered."
             )
 
-        R1, R2 = self._get_seq_records()
-        info("Done loading reads for quality filtering")
-        _check_readname_match(R1, R2)
-        if filter_by_qual:
-            self.n_reads_after_filtering = self._filter_read_quality(R1, R2)
-            if self.n_reads_after_filtering == 0:
-                raise NoReadsAfterQualityFiltering(
-                    "No reads in input or no reads survived the average or single bp quality filtering."
-                )
-            else:
-                info(
-                    "Number of reads in input:%d\tNumber of reads after filtering:%d\n"
-                    % (self.n_total_reads, self.n_reads_after_filtering)
-                )
-        else:
-            self.n_reads_after_filtering = self.n_total_reads
-
-    def _filter_read_quality(self, R1=None, R2=None) -> int:
-        R1_filtered = gzip.open(self.filtered_R1_filename, "w+")
-        R2_filtered = gzip.open(self.filtered_R2_filename, "w+")
-
-        if R1 is None or R2 is None:
-            R1, R2 = self._get_seq_records()
-
-        n_reads_after_filtering = 0
-        for i, R1_record in enumerate(R1):
-            R2_record = R2[i]
-
-            R1_quality_pass = _read_is_good_quality(
-                R1_record,
-                self.min_average_read_quality,
-                self.min_single_bp_quality,
-                self.qend_R1,
-            )
-            R2_quality_pass = _read_is_good_quality(
-                R2_record,
-                self.min_average_read_quality,
-                self.min_single_bp_quality,
-                self.qend_R2,
-            )
-            if R1_quality_pass and R2_quality_pass:
-                n_reads_after_filtering += 1
-                R1_filtered.write(R1.format("fastq"))
-                R2_filtered.write(R2.format("fastq"))
-        return n_reads_after_filtering
+        R1_iter, R2_iter = self._get_fastq_iterators()
+        (
+            self.n_reads_after_filtering,
+            self.n_total_reads,
+        ) = self._check_readname_match_and_filter_quality(
+            R1_iter, R2_iter, filter_by_qual
+        )
+
+        if self.n_reads_after_filtering == 0:
+            raise NoReadsAfterQualityFiltering(
+                "No reads in input or no reads survived the average or single bp quality filtering."
+            )
+        else:
+            info(
+                "Number of reads in input:%d\tNumber of reads after filtering:%d\n"
+                % (self.n_total_reads, self.n_reads_after_filtering)
+            )
+
+    def _check_readname_match_and_filter_quality(
+        self, R1_iter, R2_iter, filter_by_qual=False
+    ) -> Tuple[int, int]:
+        R1_filtered = gzip.open(self.filtered_R1_filename, "wt+")
+        R2_filtered = gzip.open(self.filtered_R2_filename, "wt+")
+
+        n_reads_after_filtering = 0
+        n_total_reads = 0
+        for R1_record, R2_record in zip(R1_iter, R2_iter):
+            n_total_reads += 1
+            if R1_record.name != R2_record.name:
+                raise InputFileError(
+                    "R1 and R2 read discordance in read {} and {}".format(
+                        R1_record.name, R2_record.name
+                    )
+                )
+            if filter_by_qual:
+                R1_quality_pass = _read_is_good_quality(
+                    R1_record,
+                    self.min_average_read_quality,
+                    self.min_single_bp_quality,
+                    self.qend_R1,
+                )
+                R2_quality_pass = _read_is_good_quality(
+                    R2_record,
+                    self.min_average_read_quality,
+                    self.min_single_bp_quality,
+                    self.qend_R2,
+                )
+            else:
+                R1_quality_pass = True
+                R2_quality_pass = True
+            if R1_quality_pass and R2_quality_pass:
+                n_reads_after_filtering += 1
+                R1_filtered.write(R1_record.format("fastq"))
+                R2_filtered.write(R2_record.format("fastq"))
+
+        return n_reads_after_filtering, n_total_reads
 
     def _write_start_log(self):
         try:
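Taken together, the GuideEditCounter.py changes replace the load-everything path (`_get_seq_records` built full `list(SeqIO.parse(...))` copies of both FASTQ files, after which name checking and quality filtering each walked those in-memory lists) with a single streaming pass that checks name concordance and quality per read pair and writes survivors immediately. A self-contained sketch of the same pattern under simplifying assumptions: plain `open` for uncompressed inputs, a hypothetical `filter_paired_fastq` helper, and a mean-quality cutoff standing in for the repository's `_read_is_good_quality`:

# Illustrative sketch, not repository code: stream paired FASTQ reads,
# check name concordance, filter on mean Phred quality, and write
# survivors as they pass so memory stays flat.
import gzip
from Bio.SeqIO.QualityIO import FastqPhredIterator

def filter_paired_fastq(r1_path, r2_path, out_r1, out_r2, min_mean_q=20):
    n_total = n_kept = 0
    with open(r1_path) as h1, open(r2_path) as h2, gzip.open(
        out_r1, "wt"
    ) as o1, gzip.open(out_r2, "wt") as o2:
        for rec1, rec2 in zip(FastqPhredIterator(h1), FastqPhredIterator(h2)):
            n_total += 1
            if rec1.name != rec2.name:
                raise ValueError(
                    f"R1/R2 read discordance: {rec1.name} vs {rec2.name}"
                )
            q1 = rec1.letter_annotations["phred_quality"]
            q2 = rec2.letter_annotations["phred_quality"]
            # max(..., 1) guards zero-length reads in this toy version
            if (sum(q1) / max(len(q1), 1) >= min_mean_q
                    and sum(q2) / max(len(q2), 1) >= min_mean_q):
                n_kept += 1
                o1.write(rec1.format("fastq"))  # written per pair, never buffered
                o2.write(rec2.format("fastq"))
    return n_kept, n_total

Peak memory is then bounded by a single read pair plus file buffering, regardless of input size, where the old code scaled with the number of reads.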
4 changes: 2 additions & 2 deletions bin/bean-count-samples
@@ -128,7 +128,7 @@ def count_sample(R1: str, R2: str, sample_id: str, args: argparse.Namespace):
         f"Done for {sample_id}. \n\
             Output written at {counter.output_dir}.h5ad"
     )
-
+    del counter
     return screen
 
 
@@ -178,7 +178,7 @@ def main():
     sample_tbl = pd.read_csv(args.input) # R1_filepath, R2_filepath, sample_name
     sample_tbl_input = sample_tbl.iloc[:, :3]
     sample_info_tbl = sample_tbl.iloc[:, 2:].set_index(sample_tbl.columns[2])
-    with Pool(processes=args.threads) as p:
+    with Pool(processes=args.threads, maxtasksperchild=1) as p:
        result = p.starmap(
            count_sample,
            [
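In bin/bean-count-samples, `del counter` drops the per-sample counter before the worker returns, and `maxtasksperchild=1` makes the pool retire each worker process after it finishes one sample, so memory a worker accumulated is returned to the OS at process exit instead of lingering in a long-lived worker. A hypothetical standalone illustration (the function name and allocation size are made up, not from the repository):

# Illustrative only: with maxtasksperchild=1, each task runs in a fresh
# worker process, so per-task allocations die with the process.
import os
from multiprocessing import Pool

def count_one_sample(sample_id):
    reads = bytearray(200_000_000)  # stand-in for one sample's read data
    return sample_id, os.getpid()   # PID differs per task under maxtasksperchild=1

if __name__ == "__main__":
    with Pool(processes=2, maxtasksperchild=1) as p:
        for sample_id, pid in p.map(count_one_sample, range(4)):
            print(f"sample {sample_id} counted in worker {pid}")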
