diff --git a/pegasus/tools/batch_correction.py b/pegasus/tools/batch_correction.py index b4521c51..da06e5ae 100644 --- a/pegasus/tools/batch_correction.py +++ b/pegasus/tools/batch_correction.py @@ -1,7 +1,7 @@ import time import numpy as np import pandas as pd -from typing import Union +from typing import Union, List from pegasusio import UnimodalData, MultimodalData from pegasus.tools import select_features, X_from_rep, check_batch_key @@ -16,7 +16,7 @@ @timer(logger=logger) def run_harmony( data: Union[MultimodalData, UnimodalData], - batch: str = "Channel", + batch: Union[str, List[str]] = "Channel", rep: str = "pca", n_comps: int = None, n_jobs: int = -1, @@ -34,8 +34,8 @@ def run_harmony( data: ``MultimodalData``. Annotated data matrix with rows for cells and columns for genes. - batch: ``str``, optional, default: ``"Channel"``. - Which attribute in data.obs field represents batches, default is "Channel". + batch: ``str`` or ``List[str]``, optional, default: ``"Channel"``. + Which attribute in data.obs field represents batches, default is "Channel". If using multiple attributes, specify their names in a list. rep: ``str``, optional, default: ``"pca"``. Which representation to use as input of Harmony, default is PCA. @@ -54,7 +54,7 @@ def run_harmony( use_gpu: ``bool``, optional, default: ``False``. If ``True``, use GPU if available. Otherwise, use CPU only. - + max_iter_harmony: ``int``, optional, default: ``10``. Maximum iterations on running Harmony if not converged. diff --git a/pegasus/tools/utils.py b/pegasus/tools/utils.py index ebaf11bc..d3970847 100644 --- a/pegasus/tools/utils.py +++ b/pegasus/tools/utils.py @@ -157,19 +157,24 @@ def simulate_doublets(X: Union[csr_matrix, np.ndarray], sim_doublet_ratio: float return results, doublet_indices -def check_batch_key(data: Union[MultimodalData, UnimodalData], batch: str, warning_msg: str) -> bool: +def check_batch_key(data: Union[MultimodalData, UnimodalData], batch: Union[str, List[str]], warning_msg: str) -> bool: if batch is None: return False - if batch not in data.obs: - logger.warning(f"Batch key {batch} does not exist. {warning_msg}") - return False - else: - if not is_categorical_dtype(data.obs[batch]): - data.obs[batch] = pd.Categorical(data.obs[batch].values) - if data.obs[batch].cat.categories.size == 1: - logger.warning(f"Batch key {batch} only contains one batch. {warning_msg}") + if isinstance(batch, str): + batch = [batch] + + for bat in batch: + if bat not in data.obs: + logger.warning(f"Batch key {bat} does not exist. {warning_msg}") return False + else: + if not is_categorical_dtype(data.obs[bat]): + data.obs[bat] = pd.Categorical(data.obs[bat].values) + if data.obs[bat].cat.categories.size == 1: + logger.warning(f"Batch key {bat} only contains one batch. {warning_msg}") + return False + return True