Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CZI-tiledb SCVI Custom Dataloader additions #1339

Open
wants to merge 2 commits into
base: ebezzi/census-scvi-datamodule
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""An API to facilitate use of PyTorch ML training with data from the CZI Science CELLxGENE Census."""

from .pytorch import Encoder, ExperimentDataPipe, Stats, experiment_dataloader
from .datamodule import CensusSCVIDataModule

__all__ = [
"Stats",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

# type: ignore
import functools
from typing import List

import numpy.typing as npt
import pandas as pd
import numpy as np
import torch
from lightning.pytorch import LightningDataModule

Expand Down Expand Up @@ -53,6 +56,37 @@ def classes_(self) -> List[str]:
return self._encoder.classes_


class LabelEncoder(BatchEncoder):
    """Encoder that joins several obs columns into a single label column.

    Joined values that the fitted encoder does not know are replaced by the
    ``unlabeled_category`` string, which :meth:`fit` always registers as a class.
    """

    def __init__(self, cols: list[str], unlabeled_category: str = "Unknown", name: str = "label"):
        """Create the encoder.

        Args:
            cols: obs column names, forwarded to ``BatchEncoder``.
            unlabeled_category: label substituted for values unknown to the fitted encoder.
            name: encoder name, forwarded to ``BatchEncoder``.
        """
        super().__init__(cols, name)
        self._unlabeled_category = unlabeled_category

    @property
    def unlabeled_category(self) -> str:
        """Name of the unlabeled_category."""
        return self._unlabeled_category

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Encode the joined obs columns, mapping unseen values to the unlabeled category.

        Args:
            df: obs DataFrame holding the columns this encoder was created with.

        Returns:
            Whatever ``self._encoder.transform`` returns for the cleaned values.
            NOTE(review): the annotation says ``pd.DataFrame``, but the underlying
            encoder presumably returns an array of integer codes — confirm against
            the base class before tightening the annotation.
        """
        joined = self._join_cols(df)
        # A set gives O(1) membership tests, so every value is inspected exactly once
        # instead of rebuilding the whole list for each unseen unique value.
        known = set(self._encoder.classes_)
        fallback = self.unlabeled_category
        arr = [value if value in known else fallback for value in joined]
        return self._encoder.transform(arr)  # type: ignore

    def inverse_transform(self, encoded_values: npt.ArrayLike) -> npt.ArrayLike:
        """Inverse transform the encoded values back to the original values."""
        return self._encoder.inverse_transform(encoded_values)  # type: ignore

    def fit(self, obs: pd.DataFrame) -> None:
        """Fit the encoder on the unique joined obs values plus the unlabeled category.

        Args:
            obs: obs DataFrame holding the columns this encoder was created with.
        """
        arr = self._join_cols(obs)
        # Registering the unlabeled category up front guarantees that `transform`
        # can always fall back to it for values not seen here.
        self._encoder.fit(np.append(arr.unique(), self.unlabeled_category))


class CensusSCVIDataModule(LightningDataModule):
"""Lightning data module for training an scVI model using the ExperimentDataPipe.

Expand All @@ -63,6 +97,10 @@ class CensusSCVIDataModule(LightningDataModule):
:class:`~cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`.
batch_keys
List of obs column names concatenated to form the batch column.
label_keys
List of obs column names concatenated to form the label column.
unlabeled_category
Value used for unlabeled cells in the label column (built from `label_keys`) when setting up the CZI datamodule with scvi.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Value used for unlabeled cells in `labels_key` used to set up CZI datamodule with scvi.
Value used for unlabeled cells in `labels_key` used to set up `CensusSCVIDataModule`.

train_size
Fraction of data to use for training.
split_seed
Expand All @@ -83,6 +121,8 @@ def __init__(
self,
*args,
batch_keys: list[str] | None = None,
label_keys: list[str] | None = None,
unlabeled_category: str | None = "Unknown",
train_size: float | None = None,
split_seed: int | None = None,
dataloader_kwargs: dict[str, any] | None = None,
Expand All @@ -92,10 +132,38 @@ def __init__(
self.datapipe_args = args
self.datapipe_kwargs = kwargs
self.batch_keys = batch_keys
self.label_keys = label_keys
self.unlabeled_category = unlabeled_category
self.train_size = train_size
self.split_seed = split_seed
self.dataloader_kwargs = dataloader_kwargs or {}

@property
def unlabeled_category(self) -> str:
    """Category string given to cells without a label.

    Raises:
        AttributeError: if the value has not been assigned yet.
    """
    try:
        return self._unlabeled_category
    except AttributeError:
        raise AttributeError("`unlabeled_category` not set.") from None

@unlabeled_category.setter
def unlabeled_category(self, value: str | None):
    """Store ``value`` after checking that it is a string or ``None``.

    Raises:
        ValueError: if ``value`` is neither a string nor ``None``.
    """
    if value is not None and not isinstance(value, str):
        raise ValueError("`unlabeled_category` must be a string or None.")
    self._unlabeled_category = value

@property
def label_keys(self) -> list[str]:
    """Obs column names that are joined to build the label column.

    Raises:
        AttributeError: if the value has not been assigned yet.
    """
    try:
        return self._label_keys
    except AttributeError:
        raise AttributeError("`label_keys` not set.") from None

@label_keys.setter
def label_keys(self, value: list[str] | None):
    """Store ``value`` after checking that it is a list or ``None``.

    Only the container type is checked; the elements are not inspected.

    Raises:
        ValueError: if ``value`` is neither a list nor ``None``.
    """
    if value is not None and not isinstance(value, list):
        raise ValueError("`label_keys` must be a list of strings or None.")
    self._label_keys = value

@property
def batch_keys(self) -> list[str]:
"""List of obs column names concatenated to form the batch column."""
Expand All @@ -122,6 +190,19 @@ def obs_column_names(self) -> list[str]:
self._obs_column_names = obs_column_names
return self._obs_column_names

@property
def obs_label_names(self) -> list[str]:
    """Passed to :class:`~cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`.

    The list is built from ``label_keys`` on first access (empty when
    ``label_keys`` is ``None``) and the same cached list is returned afterwards.
    """
    try:
        return self._obs_label_names
    except AttributeError:
        pass  # not computed yet, build it below

    keys = self.label_keys
    self._obs_label_names = [] if keys is None else list(keys)
    return self._obs_label_names

@property
def split_seed(self) -> int:
"""Seed for data split."""
Expand Down Expand Up @@ -170,10 +251,14 @@ def weights(self) -> dict[str, float]:
def datapipe(self) -> ExperimentDataPipe:
    """Experiment data pipe, created on first access and cached on the instance.

    A ``BatchEncoder`` over ``obs_column_names`` is always supplied. When
    ``label_keys`` is set, a ``LabelEncoder`` over ``obs_label_names`` using
    ``unlabeled_category`` is supplied as well.
    """
    if not hasattr(self, "_datapipe"):
        encoders_list = [BatchEncoder(self.obs_column_names)]
        if self.label_keys is not None:
            encoders_list.append(LabelEncoder(self.obs_label_names, self.unlabeled_category))
        # `encoders` is passed exactly once; the stale `encoders=[encoder]` line
        # duplicated the keyword argument, which is a syntax error.
        self._datapipe = ExperimentDataPipe(
            *self.datapipe_args,
            encoders=encoders_list,
            **self.datapipe_kwargs,
        )
    return self._datapipe
Expand Down Expand Up @@ -218,6 +303,13 @@ def n_batch(self) -> int:
"""
return self.get_n_classes("batch")

@property
def n_label(self) -> int:
    """Number of unique labels (after concatenation of ``label_keys``).

    Necessary in scvi-tools so that the model knows how to one-hot encode labels.
    """
    label_count = self.get_n_classes("label")
    return label_count

def get_n_classes(self, key: str) -> int:
    """Return how many classes the obs encoder registered under ``key`` holds."""
    encoder = self.datapipe.obs_encoders[key]
    return len(encoder.classes_)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import abc
import gc
import itertools
Expand Down Expand Up @@ -577,7 +579,7 @@ def __init__(
experimental
"""
self.exp_uri = experiment.uri
self.aws_region = experiment.context.tiledb_ctx.config().get("vfs.s3.region")
self.aws_region = experiment.context.tiledb_ctx#.config().get("vfs.s3.region")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems a bit suspicious - can you remind me what exactly required this change?

self.measurement_name = measurement_name
self.layer_name = X_name
self.obs_query = obs_query
Expand Down
Loading