Skip to content

Commit

Permalink
Merge pull request #180 from satijalab/fix/scVIIntegration
Browse files Browse the repository at this point in the history
Update scVIIntegration to take `StdAssay`/`SCTAssay` as input
  • Loading branch information
dcollins15 authored Jan 29, 2024
2 parents 17b8d5a + 90d2bc1 commit d9594f6
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 68 deletions.
13 changes: 7 additions & 6 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
Package: SeuratWrappers
Title: Community-Provided Methods and Extensions for the Seurat Object
Version: 0.3.3
Date: 2024-01-19
Version: 0.3.4
Date: 2024-01-23
Authors@R: c(
person(given = 'Rahul', family = 'Satija', email = '[email protected]', role = 'aut', comment = c(ORCID = '0000-0001-9448-8833')),
person(given = 'Andrew', family = 'Butler', email = '[email protected]', role = 'aut', comment = c(ORCID = '0000-0003-3608-0463')),
person(given = 'Paul', family = 'Hoffman', email = 'nygcSatijalab@nygenome.org', role = c('aut', 'cre'), comment = c(ORCID = '0000-0002-7693-8957')),
person(given = 'Tim', family = 'Stuart', email = 'tstuart@nygenome.org', role = 'aut', comment = c(ORCID = '0000-0002-3044-0897')),
person(given = "Saket", family = "Choudhary", email = "schoudhary@nygenome.org", role = "ctb", comment = c(ORCID = "0000-0001-5202-7633")),
person(given = 'David', family = 'Collins', email = 'dcollins@nygenome.org', role = 'ctb', comment = c(ORCID = '0000-0001-9243-7821')),
person(given = "Yuhan", family = "Hao", email = "[email protected]", role = "ctb", comment = c(ORCID = "0000-0002-1810-0822")),
person(given = "Austin", family = "Hartman", email = "[email protected]", role = "ctb", comment = c(ORCID = "0000-0001-7278-1852")),
person(given = 'Paul', family = 'Hoffman', email = '[email protected]', role = c('aut', 'cre'), comment = c(ORCID = '0000-0002-7693-8957')),
person(given = "Gesmira", family = "Molla", email = '[email protected]', role = 'ctb', comment = c(ORCID = '0000-0002-8628-5056')),
person(given = "Saket", family = "Choudhary", email = "[email protected]", role = "ctb", comment = c(ORCID = "0000-0001-5202-7633"))
person(given = 'Rahul', family = 'Satija', email = '[email protected]', role = 'aut', comment = c(ORCID = '0000-0001-9448-8833')),
person(given = 'Tim', family = 'Stuart', email = '[email protected]', role = 'aut', comment = c(ORCID = '0000-0002-3044-0897'))
)
Description: SeuratWrappers is a collection of community-provided methods and
extensions for Seurat, curated by the Satija Lab at NYGC. These methods
Expand Down
177 changes: 136 additions & 41 deletions R/scVI.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,27 @@
#'
NULL

#' Run scVI in seurat5
#' @param object A merged Seurat object
#' @param groups Name of the metadata column to be used as the 'batch_key'
#' @param features features to use
#' @param layers Layers to use
#' scVI Integration
#' @param object A \code{StdAssay} or \code{STDAssay} instance containing
#' merged data
#' @param features Features to integrate
#' @param layers Layers to integrate
#' @param conda_env conda environment to run scVI
#' @param new.reduction Name to store resulting DimReduc object as
#' @param ... Arguments passed to other methods
#' @param new.reduction Name under which to store resulting DimReduc object
#' @param ndims Dimensionality of the latent space
#' @param nlayers Number of hidden layers used for encoder and decoder NNs
#' @param gene_likelihood Distribution to use for modelling expression
#' data: {"zinb", "nb", "poisson"}
#' @param max_epochs Number of passes through the dataset taken while
#' training the model
#' @param ... Unused - currently just capturing parameters passed in from
#' \code{Seurat::IntegrateLayers} intended for other integration methods
#'
#' @export
#'
#' @note This function requires the
#' \href{https://docs.scvi-tools.org/en/stable/installation.html}{\pkg{scvi-tools}} package
#' to be installed
#' \href{https://docs.scvi-tools.org/en/stable/installation.html}{\pkg{scvi-tools}}
#' package to be installed
#'
#' @examples
#' \dontrun{
Expand All @@ -28,63 +35,151 @@ NULL
#' obj <- RunPCA(obj)
#'
#' # After preprocessing, we integrate layers, specifying a conda environment
#' obj <- IntegrateLayers(object = obj, method = scVIIntegration, new.reduction = 'integrated.scvi',
#' conda_env = '../miniconda3/envs/scvi-env', verbose = FALSE)
#' }
#' obj <- IntegrateLayers(
#' object = obj,
#' method = scVIIntegration,
#' new.reduction = "integrated.scvi",
#' conda_env = "../miniconda3/envs/scvi-env",
#' verbose = FALSE
#' )
#'
#' # Alternatively, we can integrate SCTransformed data
#' obj <- SCTransform(object = obj)
#' obj <- IntegrateLayers(object = obj, method = scVIIntegration,
#' orig.reduction = "pca", new.reduction = 'integrated.scvi',
#' assay = "SCT", conda_env = '../miniconda3/envs/scvi-env', verbose = FALSE)
#' obj <- IntegrateLayers(
#' object = obj, method = scVIIntegration,
#' orig.reduction = "pca", new.reduction = "integrated.scvi",
#' assay = "SCT", conda_env = "../miniconda3/envs/scvi-env", verbose = FALSE
#' )
#' }
#'
#' @seealso \href{https://docs.scvi-tools.org/en/stable/tutorials/notebooks/scvi_in_R.html}{scVI}
#'
#' @return A Seurat object with embeddings and loadings

#' @return A single-element named list \code{DimReduc} elements containing
#' the integrated data
scVIIntegration <- function(
object,
groups = NULL,
features = NULL,
layers = 'counts',
layers = "counts",
conda_env = NULL,
new.reduction = 'integrated.dr',
new.reduction = "integrated.dr",
ndims = 30,
nlayers = 2,
gene_likelihood = "nb",
max_epochs = NULL,
...){
...) {

# import python methods from specified conda env
reticulate::use_condaenv(conda_env, required = TRUE)
sc <- reticulate::import('scanpy', convert = FALSE)
scvi <- reticulate::import('scvi', convert = FALSE)
anndata <- reticulate::import('anndata', convert = FALSE)
scipy <- reticulate::import('scipy', convert = FALSE)
object <- JoinLayers(object = object, layers = 'counts')
adata <- sc$AnnData(
X = scipy$sparse$csr_matrix(Matrix::t(LayerData(object, layer = 'counts')[features ,]) ), #scVI requires raw counts
obs = object[[]],
var = object[[DefaultAssay(object)]][[]][features,]
)
scvi$model$SCVI$setup_anndata(adata, batch_key = groups)
model = scvi$model$SCVI(adata = adata,
n_latent = as.integer(x = ndims),
n_layers = as.integer(x = nlayers),
gene_likelihood = gene_likelihood)
sc <- reticulate::import("scanpy", convert = FALSE)
scvi <- reticulate::import("scvi", convert = FALSE)
anndata <- reticulate::import("anndata", convert = FALSE)
scipy <- reticulate::import("scipy", convert = FALSE)

# if `max_epochs` is not set
if (is.null(max_epochs)) {
max_epochs <- reticulate::r_to_py(x = max_epochs)
# convert `NULL` to python's `None`
max_epochs <- reticulate::r_to_py(max_epochs)
} else {
max_epochs <- as.integer(x = max_epochs)
# otherwise make sure it's an int
max_epochs <- as.integer(max_epochs)
}

# build a meta.data-style data.frame indicating the batch for each cell
batches <- .FindBatches(object, layers = layers)
# scVI expects a single counts matrix so we'll join our layers together
# it also expects the raw counts matrix
# TODO: avoid hardcoding this - users can rename their layers arbitrarily
# so there's no gauruntee that the usual naming conventions will be followed
object <- JoinLayers(object = object, layers = "counts")
# setup an `AnnData` python instance
adata <- sc$AnnData(
X = scipy$sparse$csr_matrix(
# TODO: avoid hardcoding per comment above
Matrix::t(LayerData(object, layer = "counts")[features, ])
),
obs = batches,
var = object[[]][features, ]
)
scvi$model$SCVI$setup_anndata(adata, batch_key = "batch")

# initialize and train the model
model <- scvi$model$SCVI(
adata = adata,
n_latent = as.integer(x = ndims),
n_layers = as.integer(x = nlayers),
gene_likelihood = gene_likelihood
)
model$train(max_epochs = max_epochs)
latent = model$get_latent_representation()

# extract the latent representation of the merged data
latent <- model$get_latent_representation()
latent <- as.matrix(latent)
# pull the cell identifiers back out of the `AnnData` instance
# in case anything was sorted under the hood
rownames(latent) <- reticulate::py_to_r(adata$obs$index$values)
# prepend the latent space dimensions with `new.reduction` to
# give the features more readable names
colnames(latent) <- paste0(new.reduction, "_", 1:ncol(latent))
suppressWarnings(latent.dr <- CreateDimReducObject(embeddings = latent, key = new.reduction))

# build a `DimReduc` instance
suppressWarnings(
latent.dr <- CreateDimReducObject(
embeddings = latent,
key = new.reduction
)
)
# to make it easier to add the reduction into a `Seurat` instance
# we'll wrap it up in a named list
output.list <- list(latent.dr)
names(output.list) <- new.reduction

return(output.list)
}

attr(x = scVIIntegration, which = 'Seurat.method') <- 'integration'
attr(x = scVIIntegration, which = "Seurat.method") <- "integration"


#' Builds a data.frame with batch identifiers to use when integrating
#' \code{object}. For \code{SCTAssay}s, batches are split using their
#' model identifiers. For \code{StdAssays}, batches are split by layer.
#'
#' Internal - essentially the same as \code{Seurat:::CreateIntegrationGroups}
#' except that it does not take in a `scale.layer` param.
#'
#' @noRd
#'
#' @param object A \code{SCTAssay} or \code{StdAssays} instance.
#' @param layers Layers in \code{object} to integrate.
#'
#' @return A dataframe indexed on the cell identifiers from \code{object} -
#' the dataframe contains a single column, "batch", indicating the ...
.FindBatches <- function(object, layers) {
# if an `SCTAssay` is passed in it's expected that the transformation
# was run on each batch individually and then merged so we can use
# the model identifier to split our batches
if (inherits(object, what = "SCTAssay")) {
# build an empty data.frame indexed
# on the cell identifiers from `object`
batch.df <- SeuratObject::EmptyDF(n = ncol(object))
row.names(batch.df) <- Cells(object)
# for each
for (sct.model in levels(object)) {
cell.identifiers <- Cells(object, layer = sct.model)
batch.df[cell.identifiers, "batch"] <- sct.model
}
# otherwise batches can be split using `object`'s layers
} else {
# build a LogMap indicating which layer each cell is from
layer.masks <- slot(object, name = "cells")[, layers]
# get a named vector mapping each cell to its respective layer
layer.map <- labels(
layer.masks,
values = Cells(object, layer = layers)
)
# wrap the vector up in a data.frame
batch.df <- as.data.frame(layer.map)
names(batch.df) <- "batch"
}

return(batch.df)
}
5 changes: 3 additions & 2 deletions man/SeuratWrappers-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

55 changes: 36 additions & 19 deletions man/scVIIntegration.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit d9594f6

Please sign in to comment.