From 85dcde378a0ba639d21e044a9aa8ce6d30ff4197 Mon Sep 17 00:00:00 2001 From: Jayoung Ryu Date: Thu, 19 Oct 2023 14:36:55 -0400 Subject: [PATCH] troubleshoot filtering; faster df concatenation --- bean/annotate/filter_alleles.py | 11 +++++++---- bean/annotate/translate_allele.py | 19 +++++++++++++++---- bean/annotate/utils.py | 20 +++++++++++++++++++- bean/framework/Edit.py | 2 ++ bean/framework/ReporterScreen.py | 11 ++++++++--- bean/plotting/editing_patterns.py | 8 ++++++-- bin/bean-filter | 1 + notebooks/sample_quality_report.ipynb | 2 +- 8 files changed, 59 insertions(+), 15 deletions(-) diff --git a/bean/annotate/filter_alleles.py b/bean/annotate/filter_alleles.py index 2f4410c..99c93fe 100644 --- a/bean/annotate/filter_alleles.py +++ b/bean/annotate/filter_alleles.py @@ -9,6 +9,7 @@ from ..framework.Edit import Allele from ..framework.AminoAcidEdit import CodingNoncodingAllele from ._supporting_fn import map_alleles_to_filtered +from .utils import fast_concat def sum_column_groups(mat, column_index_list): @@ -517,10 +518,10 @@ def _map_alleles_to_filtered( .rename(columns={"allele_mapped": allele_col}) .groupby(["guide", allele_col]) .sum() - ) + ).reset_index() mapped_allele_counts.append(guide_raw_counts) - res = pd.concat(mapped_allele_counts).reset_index() + res = fast_concat(mapped_allele_counts).reset_index(drop=True) res = res.loc[res[allele_col].map(bool), :] return res @@ -579,8 +580,10 @@ def _distribute_alleles_to_filtered( res, index=guide_filtered_counts.index, columns=guide_filtered_counts.columns, - ) + ).reset_index() mapped_allele_counts.append(added_counts) - res = pd.concat(mapped_allele_counts).reset_index() + # @TODO: these lines are taking very long? 
+ print("Concatenating allele counts...") + res = fast_concat(mapped_allele_counts).reset_index(drop=True) res = res.loc[res[allele_col].map(bool), :] return res diff --git a/bean/annotate/translate_allele.py b/bean/annotate/translate_allele.py index 308ad7c..dc174c2 100644 --- a/bean/annotate/translate_allele.py +++ b/bean/annotate/translate_allele.py @@ -346,6 +346,8 @@ def get_aa_change( self, allele: Allele, include_synonymous: bool = True ) -> CodingNoncodingAllele: # sourcery skip: use-named-expression """Finds overlapping CDS and call the same function for the CDS, else return CodingNonCodingAllele with no translated allele.""" + if len(allele.edits) == 0: + return CodingNoncodingAllele.from_alleles(nt_allele=allele) chrom, start, end = allele.get_range() overlapping_cds = find_overlap(chrom, start, end, self.cds_ranges) if overlapping_cds: @@ -520,6 +522,18 @@ def filter_nt_alleles(cn_allele_df: pd.DataFrame, pos_include: Iterable[int]): return alleles +def strsplit_edit(edit_str): + if len(edit_str.split(":")) == 3: + chrom, pos, transition = edit_str.split(":") + elif len(edit_str.split(":")) == 2: + pos, transition = edit_str.split(":") + chrom = None + else: + raise ValueError(f"{edit_str} is not in the correct format.") + ref, alt = transition.split(">") + return chrom, pos, ref, alt + + def annotate_edit( edit_info: pd.DataFrame, edit_col="edit", @@ -538,10 +552,7 @@ def annotate_edit( edit_info["group"] = "" edit_info["int_pos"] = -1 if "pos" not in edit_info.columns: - edit_info["pos"], transition = zip(*(edit_info[edit_col].str.split(":"))) - edit_info["ref"], edit_info["alt"] = zip( - *(pd.Series(transition).str.split(">")) - ) + edit_info["chrom"], edit_info["pos"], edit_info["ref"], edit_info["alt"] = zip(*(edit_info[edit_col].map(strsplit_edit))) edit_info.loc[edit_info.pos.map(lambda s: s.startswith("A")), "coding"] = "coding" edit_info.loc[ edit_info.pos.map(lambda s: not s.startswith("A")), "coding" diff --git a/bean/annotate/utils.py 
b/bean/annotate/utils.py index dd8bc8b..c02cf90 100644 --- a/bean/annotate/utils.py +++ b/bean/annotate/utils.py @@ -1,9 +1,10 @@ import os import sys import requests -from typing import Optional +from typing import Optional, List import argparse import pandas as pd +from itertools import chain import logging logging.basicConfig( @@ -19,6 +20,23 @@ info = logging.info +def fast_flatten(input_list): + return list(chain.from_iterable(input_list)) + + +def fast_concat(df_list: List[pd.DataFrame]): + """Faster concatenation of many dataframes from + https://gist.github.com/TariqAHassan/fc77c00efef4897241f49e61ddbede9e + """ + colnames = df_list[0].columns + df_dict = dict.fromkeys(colnames, []) + for col in colnames: + extracted = (df[col] for df in df_list) + # Flatten and save to df_dict + df_dict[col] = fast_flatten(extracted) + return pd.DataFrame.from_dict(df_dict)[colnames] + + def find_overlap( chrom: str, start: int, end: int, range_df: pd.DataFrame ) -> Optional[str]: diff --git a/bean/framework/Edit.py b/bean/framework/Edit.py index 94dba7f..0a51fb5 100644 --- a/bean/framework/Edit.py +++ b/bean/framework/Edit.py @@ -168,6 +168,8 @@ def match_str(cls, allele_str): def get_range(self): """Returns genomic range of the edits in the allele""" + if len(self.edits) == 0: + return None return ( self.chrom, min(edit.pos for edit in self.edits), diff --git a/bean/framework/ReporterScreen.py b/bean/framework/ReporterScreen.py index 2bd7e0b..05eeb3d 100644 --- a/bean/framework/ReporterScreen.py +++ b/bean/framework/ReporterScreen.py @@ -850,9 +850,14 @@ def write(self, out_path): adata.uns[k]["edit"].iloc[0], Edit ): adata.uns[k].edit = adata.uns[k].edit.map(str) - for c in [colname for colname in v.columns if "allele" in colname]: - if isinstance(v[c].iloc[0], (Allele, CodingNoncodingAllele)): - adata.uns[k].loc[:, c] = adata.uns[k][c].map(str) + try: + for c in [ + colname for colname in v.columns if "allele" in str(colname) + ]: + if isinstance(v[c].iloc[0], (Allele, 
CodingNoncodingAllele)): + adata.uns[k].loc[:, c] = adata.uns[k][c].map(str) + except TypeError as e: + raise TypeError(f"error with {e}: {k, v} cannot be written") super(ReporterScreen, adata).write(out_path) diff --git a/bean/plotting/editing_patterns.py b/bean/plotting/editing_patterns.py index 7805f6c..e621fcd 100644 --- a/bean/plotting/editing_patterns.py +++ b/bean/plotting/editing_patterns.py @@ -60,7 +60,11 @@ def _add_absent_edits( def get_edit_rates( - bdata, edit_count_key="edit_counts", add_absent=True, adjust_spacer_pos: bool = True + bdata, + edit_count_key="edit_counts", + add_absent=True, + adjust_spacer_pos: bool = True, + reporter_column: str = "reporter", ): """ Obtain position- and context-wise editing rate (context: base preceding the target base position). @@ -113,7 +117,7 @@ def get_edit_rates( ) edit_rates_agg.rel_pos = edit_rates_agg.rel_pos.astype(int) edit_rates_agg["context"] = edit_rates_agg.apply( - lambda row: bdata.guides.loc[row.guide, "Reporter"][ + lambda row: bdata.guides.loc[row.guide, reporter_column][ row.rel_pos - 1 : row.rel_pos + 1 ], axis=1, diff --git a/bin/bean-filter b/bin/bean-filter index e31eeec..79ef604 100644 --- a/bin/bean-filter +++ b/bin/bean-filter @@ -70,6 +70,7 @@ if __name__ == "__main__": f"Filtered down to {len(bdata.uns[f'{allele_df_keys[-1]}_spacer'])} alleles." 
) allele_df_keys.append(f"{allele_df_keys[-1]}_spacer") + bdata.write(f"{args.output_prefix}.tmp.h5ad") if len(bdata.uns[allele_df_keys[-1]]) > 0 and args.filter_window: info( diff --git a/notebooks/sample_quality_report.ipynb b/notebooks/sample_quality_report.ipynb index 3c2b9dd..ca91049 100644 --- a/notebooks/sample_quality_report.ipynb +++ b/notebooks/sample_quality_report.ipynb @@ -116,7 +116,7 @@ "metadata": {}, "outputs": [], "source": [ - "bdata.samples[[\"rep\", condition_label]] = bdata.samples.index.to_series().str.split(\"_\", expand=True)" + "#bdata.samples[[replicate_label, condition_label]] = bdata.samples.index.to_series().str.split(\"_\", expand=True)" ] }, {