diff --git a/README.md b/README.md index a18ad4a..5069fa7 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ File should contain following columns with header. * `R1_filepath`: Path to read 1 `.fastq[.gz]` file * `R2_filepath`: Path to read 1 `.fastq[.gz]` file * `sample_id`: ID of sequencing sample -* `rep [Optional]`: Replicate # of this sample +* `rep [Optional]`: Replicate # of this sample (Should NOT contain `.`) * `bin [Optional]`: Name of the sorting bin * `upper_quantile [Optional]`: FACS sorting upper quantile * `lower_quantile [Optional]`: FACS sorting lower quantile diff --git a/bean/framework/ReporterScreen.py b/bean/framework/ReporterScreen.py index 05eeb3d..5d012c6 100644 --- a/bean/framework/ReporterScreen.py +++ b/bean/framework/ReporterScreen.py @@ -323,8 +323,20 @@ def __getitem__(self, index): new_uns = deepcopy(self.uns) for k, df in adata.uns.items(): if k.startswith("repguide_mask"): - new_uns[k] = df.loc[guides_include, adata.var.rep.unique()] + if "sample_covariates" in adata.uns: + adata.var["_rc"] = adata.var[ + ["rep"] + adata.uns["sample_covariates"] + ].values.tolist() + adata.var["_rc"] = adata.var["_rc"].map( + lambda slist: ".".join(slist) + ) + new_uns[k] = df.loc[guides_include, adata.var._rc.unique()] + adata.var.pop("_rc") + else: + new_uns[k] = df.loc[guides_include, adata.var.rep.unique()] if not isinstance(df, pd.DataFrame): + if k == "sample_covariates": + new_uns[k] = df continue if "guide" in df.columns: if "allele" in df.columns: @@ -892,7 +904,7 @@ def concat(screens: Collection[ReporterScreen], *args, axis=1, **kwargs): if axis == 0: for k in keys: - if k in ["target_base_change", "tiling"]: + if k in ["target_base_change", "tiling", "sample_covariates"]: adata.uns[k] = screens[0].uns[k] continue elif "edit" not in k and "allele" not in k: @@ -902,7 +914,7 @@ def concat(screens: Collection[ReporterScreen], *args, axis=1, **kwargs): if axis == 1: # If combining multiple samples, edit/allele tables should be merged. for k in keys: - if k in ["target_base_change", "tiling"]: + if k in ["target_base_change", "tiling", "sample_covariates"]: adata.uns[k] = screens[0].uns[k] continue elif "edit" not in k and "allele" not in k: diff --git a/bean/preprocessing/data_class.py b/bean/preprocessing/data_class.py index e3c518d..0a16039 100644 --- a/bean/preprocessing/data_class.py +++ b/bean/preprocessing/data_class.py @@ -2,7 +2,7 @@ import abc import logging from dataclasses import dataclass -from typing import Dict, Tuple +from typing import Dict, Tuple, List from xmlrpc.client import Boolean from copy import deepcopy import torch @@ -40,6 +40,7 @@ def __init__( sample_mask_column: str = None, shrink_alpha: bool = False, condition_column: str = "sort", + sample_covariate_column: List[str] = [], control_condition: str = "bulk", accessibility_col: str = None, accessibility_bw_path: str = None, @@ -51,6 +52,7 @@ def __init__( ): """ Args + condition_column: By default, a single condition column, but you can optionally inlcude sample covariate column control_can_be_selected: If True, screen.samples[condition_column] == control_condition can also be included in effect size inference if its condition column is not NA (Currently only suppoted for prolifertion screens). """ # TODO: remove replicate with too small number of (ex. only 1) sorting bin @@ -821,6 +823,7 @@ def __init__( sample_mask_column: str = None, shrink_alpha: bool = False, condition_column: str = "sort", + sample_covariate_column: List[str] = [], control_condition: str = "bulk", lower_quantile_column: str = "lower_quantile", upper_quantile_column: str = "upper_quantile", diff --git a/bean/qc/guide_qc.py b/bean/qc/guide_qc.py index 309a775..1bbf379 100644 --- a/bean/qc/guide_qc.py +++ b/bean/qc/guide_qc.py @@ -22,15 +22,27 @@ def get_outlier_guides_and_mask( abs_RPM_thres: RPM threshold value that will be used to define outlier guides. """ outlier_guides = get_outlier_guides(bdata, condit_col, mad_z_thres, abs_RPM_thres) - outlier_guides[replicate_col] = bdata.samples.loc[ - outlier_guides["sample"], replicate_col - ].values - mask = pd.DataFrame( - index=bdata.guides.index, columns=bdata.samples[replicate_col].unique() - ).fillna(1) + if not isinstance(replicate_col, str): + outlier_guides["_rc"] = bdata.samples.loc[ + outlier_guides["sample"], replicate_col + ].values.tolist() + outlier_guides["_rc"] = outlier_guides["_rc"].map(lambda slist: ".".join(slist)) + else: + outlier_guides[replicate_col] = bdata.samples.loc[ + outlier_guides["sample"], replicate_col + ].values + if isinstance(replicate_col, str): + reps = bdata.samples[replicate_col].unique() + else: + reps = bdata.samples[replicate_col].drop_duplicates().to_records(index=False) + reps = [".".join(slist) for slist in reps] + mask = pd.DataFrame(index=bdata.guides.index, columns=reps).fillna(1) print(outlier_guides) for _, row in outlier_guides.iterrows(): - mask.loc[row["name"], row[replicate_col]] = 0 + mask.loc[ + row["name"], row[replicate_col if isinstance(replicate_col, str) else "_rc"] + ] = 0 + return outlier_guides, mask diff --git a/bean/qc/utils.py b/bean/qc/utils.py index 48945d0..f584783 100644 --- a/bean/qc/utils.py +++ b/bean/qc/utils.py @@ -1,4 +1,5 @@ import distutils +from typing import Union, List import numpy as np import pandas as pd from copy import deepcopy @@ -52,12 +53,23 @@ def parse_args(): type=str, default="rep", ) + parser.add_argument( + "--sample-covariates", + help="Comma-separated list of column names in `bdata.samples` that describes non-selective experimental condition. (drug treatment, etc.)", + type=str, + default=None, + ) parser.add_argument( "--condition-label", help="Label of column in `bdata.samples` that describes experimental condition. (sorting bin, time, etc.)", type=str, default="bin", ) + parser.add_argument( + "--no-editing", + help="Ignore QC about editing. Can be used for QC of other editing modalities.", + action="store_true", + ) parser.add_argument( "--target-pos-col", help="Target position column in `bdata.guides` specifying target edit position in reporter", @@ -143,21 +155,30 @@ def check_args(args): ) args.lfc_cond1 = lfc_conds[0] args.lfc_cond2 = lfc_conds[1] + if args.sample_covariates is not None: + if "," in args.sample_covariates: + args.sample_covariates = args.sample_covariates.split(",") + args.replicate_label = [args.replicate_label] + args.sample_covariates + else: + args.replicate_label = [args.replicate_label, args.sample_covariates] + if args.no_editing: + args.base_edit_data = False + else: + args.base_edit_data = True return args -def _add_dummy_sample(bdata, rep, cond, condition_label: str, replicate_label: str): +def _add_dummy_sample( + bdata, rep, cond, condition_label: str, replicate_label: Union[str, List[str]] +): sample_id = f"{rep}_{cond}" cond_df = deepcopy(bdata.samples) - cond_df[replicate_label] = np.nan - cond_df = cond_df.drop_duplicates() + # cond_df = cond_df.drop_duplicates() cond_row = cond_df.loc[cond_df[condition_label] == cond, :] - if not len(cond_row) == 1: - raise ValueError( - f"Non-unique condition specification in ReporterScreen.samples: {cond_row}" - ) + if len(cond_row) != 1: + cond_row = cond_row.iloc[[0], :] cond_row.index = [sample_id] - cond_row.loc[:, replicate_label] = rep + cond_row[replicate_label] = rep dummy_sample_bdata = ReporterScreen( X=np.zeros((bdata.n_obs, 1)), X_bcmatch=np.zeros((bdata.n_obs, 1)), @@ -175,27 +196,40 @@ def _add_dummy_sample(bdata, rep, cond, condition_label: str, replicate_label: s return bdata -def fill_in_missing_samples(bdata, condition_label: str, replicate_label: str): +def fill_in_missing_samples( + bdata, condition_label: str, replicate_label: Union[str, List[str]] +): """If not all condition exists for every replicate in bdata, fill in fake sample""" added_dummy = False - for rep in bdata.samples[replicate_label].unique(): + if isinstance(replicate_label, str): + rep_list = bdata.samples[replicate_label].unique() + else: + rep_list = ( + bdata.samples[replicate_label].drop_duplicates().to_records(index=False) + ) + # print(rep_list) + for rep in rep_list: for cond in bdata.samples[condition_label].unique(): + if isinstance(replicate_label, str): + rep_samples = bdata.samples[replicate_label] == rep + else: + rep = list(rep) + rep_samples = (bdata.samples[replicate_label] == rep).all(axis=1) if ( - len( - np.where( - (bdata.samples[replicate_label] == rep) - & (bdata.samples[condition_label] == cond) - )[0] - ) + len(np.where(rep_samples & (bdata.samples[condition_label] == cond))[0]) != 1 ): + print(f"Adding dummy samples for {rep}, {cond}") bdata = _add_dummy_sample( bdata, rep, cond, condition_label, replicate_label ) if not added_dummy: added_dummy = True if added_dummy: - bdata = bdata[ - :, bdata.samples.sort_values([replicate_label, condition_label]).index - ] + if isinstance(replicate_label, str): + sort_labels = [replicate_label, condition_label] + else: + sort_labels = replicate_label + [condition_label] + bdata = bdata[:, bdata.samples.sort_values(sort_labels).index] + return bdata diff --git a/bin/bean-qc b/bin/bean-qc index c3eb72e..c83d6f2 100644 --- a/bin/bean-qc +++ b/bin/bean-qc @@ -34,6 +34,7 @@ def main(): ctrl_cond=args.ctrl_cond, exp_id=args.out_report_prefix, recalculate_edits=args.recalculate_edits, + base_edit_data=args.base_edit_data, ), kernel_name="bean_python3", ) diff --git a/notebooks/sample_quality_report.ipynb b/notebooks/sample_quality_report.ipynb index 70b7c44..205d3a3 100644 --- a/notebooks/sample_quality_report.ipynb +++ b/notebooks/sample_quality_report.ipynb @@ -56,7 +56,8 @@ "comp_cond2 = \"bot\"\n", "ctrl_cond = \"bulk\"\n", "recalculate_edits = False\n", - "tiling = None" + "tiling = None\n", + "base_edit_data = True" ] }, { @@ -75,7 +76,9 @@ "outputs": [], "source": [ "if tiling is not None:\n", - " bdata.uns['tiling'] = tiling" + " bdata.uns['tiling'] = tiling\n", + "if not isinstance(replicate_label, str):\n", + " bdata.uns['sample_covariates'] = replicate_label[1:]" ] }, { @@ -208,12 +211,7 @@ "outputs": [], "source": [ "selected_guides = bdata.guides[posctrl_col] == posctrl_val if posctrl_col else ~bdata.guides.index.isnull()\n", - "ax=pt.qc.plot_lfc_correlation(bdata, selected_guides, method=\"Spearman\", cond1=comp_cond1, cond2=comp_cond2, rep_col=replicate_label, compare_col=condition_label, figsize=(10,10))\n", - "\n", - "ax.set_title(\"top/bot LFC correlation, Spearman\")\n", - "plt.yticks(rotation=0) \n", - "plt.xticks(rotation=90) \n", - "plt.show()" + "print(f\"Calculating LFC correlation of {sum(selected_guides)} {'positive control' if posctrl_col else 'all'} guides.\")" ] }, { @@ -221,7 +219,23 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "ax = pt.qc.plot_lfc_correlation(\n", + " bdata,\n", + " selected_guides,\n", + " method=\"Spearman\",\n", + " cond1=comp_cond1,\n", + " cond2=comp_cond2,\n", + " rep_col=replicate_label,\n", + " compare_col=condition_label,\n", + " figsize=(10, 10),\n", + ")\n", + "\n", + "ax.set_title(\"top/bot LFC correlation, Spearman\")\n", + "plt.yticks(rotation=0)\n", + "plt.xticks(rotation=90)\n", + "plt.show()" + ] }, { "cell_type": "code", @@ -243,7 +257,12 @@ "metadata": {}, "outputs": [], "source": [ - "if recalculate_edits or \"edits\" not in bdata.layers.keys() or bdata.layers['edits'].max() == 0:\n", + "if \"target_base_change\" not in bdata.uns or not base_edit_data:\n", + " bdata.uns[\"target_base_change\"] = \"\"\n", + " base_edit_data = False\n", + " print(\"Not a base editing data or target base change not provided. Passing editing-related QC\")\n", + " edit_rate_threshold = -0.1\n", + "elif recalculate_edits or \"edits\" not in bdata.layers.keys() or bdata.layers['edits'].max() == 0:\n", " if 'allele_counts' in bdata.uns.keys():\n", " bdata.uns['allele_counts'] = bdata.uns['allele_counts'].loc[bdata.uns['allele_counts'].allele.map(str) != \"\"]\n", " bdata.get_edit_from_allele()\n", @@ -261,12 +280,22 @@ "metadata": {}, "outputs": [], "source": [ - "if \"edits\" in bdata.layers.keys():\n", + "if \"target_base_change\" not in bdata.uns or not base_edit_data:\n", + " print(\n", + " \"Not a base editing data or target base change not provided. Passing editing-related QC\"\n", + " )\n", + "elif \"edits\" in bdata.layers.keys():\n", + "\n", " bdata.get_guide_edit_rate(\n", + "\n", " editable_base_start=edit_quantification_start_pos,\n", + "\n", " editable_base_end=edit_quantification_end_pos,\n", + "\n", " unsorted_condition_label=ctrl_cond,\n", + "\n", " )\n", + "\n", " be.qc.plot_guide_edit_rates(bdata)" ] }, @@ -276,11 +305,19 @@ "metadata": {}, "outputs": [], "source": [ - "if \"edits\" in bdata.layers.keys():\n", + "if \"target_base_change\" not in bdata.uns or not base_edit_data:\n", + " print(\n", + " \"Not a base editing data or target base change not provided. Passing editing-related QC\"\n", + " )\n", + "elif \"edits\" in bdata.layers.keys():\n", + "\n", " bdata.get_edit_rate(\n", - " editable_base_start = edit_quantification_start_pos, \n", - " editable_base_end=edit_quantification_end_pos\n", + " editable_base_start=edit_quantification_start_pos,\n", + "\n", + " editable_base_end=edit_quantification_end_pos,\n", + "\n", " )\n", + "\n", " be.qc.plot_sample_edit_rates(bdata)" ] }, @@ -329,12 +366,24 @@ "outputs": [], "source": [ "# leave replicate with more than 1 sorting bin data\n", - "rep_n_samples = bdata_filtered.samples.groupby(replicate_label)['mask'].sum()\n", + "rep_n_samples = bdata_filtered.samples.groupby(replicate_label)[\"mask\"].sum()\n", "print(rep_n_samples)\n", "rep_has_too_small_sample = rep_n_samples.loc[rep_n_samples < 2].index.tolist()\n", "rep_has_too_small_sample\n", - "print(f\"Excluding reps {rep_has_too_small_sample} that has less than 2 samples per replicate.\")\n", - "bdata_filtered = bdata_filtered[:, ~bdata_filtered.samples[replicate_label].isin(rep_has_too_small_sample)]" + "print(\n", + " f\"Excluding reps {rep_has_too_small_sample} that has less than 2 samples per replicate.\"\n", + ")\n", + "if isinstance(replicate_label, str):\n", + " samples_include = ~bdata_filtered.samples[replicate_label].isin(\n", + " rep_has_too_small_sample\n", + " )\n", + "else:\n", + " bdata_filtered.samples[\"_rc\"] = bdata_filtered.samples[\n", + " replicate_label\n", + " ].values.tolist()\n", + " samples_include = ~bdata_filtered.samples[\"_rc\"].isin(rep_has_too_small_sample)\n", + "bdata_filtered = bdata_filtered[:, samples_include]\n", + "bdata_filtered.samples.pop(\"_rc\")" ] }, { diff --git a/setup.py b/setup.py index 4b2157f..f41178c 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( name="crispr-bean", - version="0.2.9", + version="0.3.0", python_requires=">=3.8.0", author="Jayoung Ryu", author_email="jayoung_ryu@g.harvard.edu", @@ -36,7 +36,7 @@ "numpy", "pandas", "scipy", - "perturb-tools>=0.2.8", + "perturb-tools>=0.3.0", "matplotlib", "seaborn>=0.13.0", "tqdm",