Skip to content

Commit

Permalink
Merge pull request #287 from lilab-bcb/batch_correct
Browse files (browse the repository at this point in the history)
Harmony data integration accepts multiple attributes as batch key
  • Loading branch information
yihming authored Jan 20, 2024
2 parents 40e6ef5 + 525af8d commit 2f4b507
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 14 deletions.
10 changes: 5 additions & 5 deletions pegasus/tools/batch_correction.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
import numpy as np
import pandas as pd
from typing import Union
from typing import Union, List
from pegasusio import UnimodalData, MultimodalData

from pegasus.tools import select_features, X_from_rep, check_batch_key
Expand All @@ -16,7 +16,7 @@
@timer(logger=logger)
def run_harmony(
data: Union[MultimodalData, UnimodalData],
batch: str = "Channel",
batch: Union[str, List[str]] = "Channel",
rep: str = "pca",
n_comps: int = None,
n_jobs: int = -1,
Expand All @@ -34,8 +34,8 @@ def run_harmony(
data: ``MultimodalData``.
Annotated data matrix with rows for cells and columns for genes.
batch: ``str``, optional, default: ``"Channel"``.
Which attribute in data.obs field represents batches, default is "Channel".
batch: ``str`` or ``List[str]``, optional, default: ``"Channel"``.
Which attribute in data.obs field represents batches, default is "Channel". If using multiple attributes, specify their names in a list.
rep: ``str``, optional, default: ``"pca"``.
Which representation to use as input of Harmony, default is PCA.
Expand All @@ -54,7 +54,7 @@ def run_harmony(
use_gpu: ``bool``, optional, default: ``False``.
If ``True``, use GPU if available. Otherwise, use CPU only.
max_iter_harmony: ``int``, optional, default: ``10``.
Maximum iterations on running Harmony if not converged.
Expand Down
23 changes: 14 additions & 9 deletions pegasus/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,19 +157,24 @@ def simulate_doublets(X: Union[csr_matrix, np.ndarray], sim_doublet_ratio: float
return results, doublet_indices


def check_batch_key(data: Union[MultimodalData, UnimodalData], batch: str, warning_msg: str) -> bool:
def check_batch_key(data: Union[MultimodalData, UnimodalData], batch: Union[str, List[str]], warning_msg: str) -> bool:
if batch is None:
return False

if batch not in data.obs:
logger.warning(f"Batch key {batch} does not exist. {warning_msg}")
return False
else:
if not is_categorical_dtype(data.obs[batch]):
data.obs[batch] = pd.Categorical(data.obs[batch].values)
if data.obs[batch].cat.categories.size == 1:
logger.warning(f"Batch key {batch} only contains one batch. {warning_msg}")
if isinstance(batch, str):
batch = [batch]

for bat in batch:
if bat not in data.obs:
logger.warning(f"Batch key {bat} does not exist. {warning_msg}")
return False
else:
if not is_categorical_dtype(data.obs[bat]):
data.obs[bat] = pd.Categorical(data.obs[bat].values)
if data.obs[bat].cat.categories.size == 1:
logger.warning(f"Batch key {bat} only contains one batch. {warning_msg}")
return False

return True


Expand Down

0 comments on commit 2f4b507

Please sign in to comment.