Skip to content

Commit

Permalink
Merge pull request #43 from EvolEcolGroup/projectPCA
Browse files Browse the repository at this point in the history
Project pca
  • Loading branch information
dramanica authored Jun 11, 2024
2 parents 1aec50d + 129bc16 commit 39129c3
Show file tree
Hide file tree
Showing 10 changed files with 246 additions and 20 deletions.
4 changes: 4 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

fbm256_prod_and_rowSumsSq <- function(BM, ind_row, ind_col, center, scale, V) {
.Call(`_tidypopgen_fbm256_prod_and_rowSumsSq`, BM, ind_row, ind_col, center, scale, V)
}

gt_grouped_alt_freq_diploid <- function(BM, rowInd, colInd, groupIds, ngroups, ncores) {
.Call(`_tidypopgen_gt_grouped_alt_freq_diploid`, BM, rowInd, colInd, groupIds, ngroups, ncores)
}
Expand Down
2 changes: 2 additions & 0 deletions R/gt_pca_autoSVD.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ gt_pca_autoSVD <- function(x, k = 10,

this_svd$method <- "autoSVD"
this_svd$call <- match.call()
# subset the loci table to have only the snps of interest
this_svd$loci <- show_loci(x)[.gt_bigsnp_cols %in% attr(x,"subset"),]
class(this_svd) <- c("gt_pca", class(this_svd))
this_svd
}
1 change: 1 addition & 0 deletions R/gt_pca_partialSVD.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ gt_pca_partialSVD <- function(x, k = 10, fun_scaling = bigsnpr::snp_scaleBinom()
rownames(this_svd$v) <- loci_names(x)
this_svd$method <- "partialSVD"
this_svd$call <- match.call()
this_svd$loci <- show_loci(x)
class(this_svd) <- c("gt_pca", class(this_svd))
this_svd
}
1 change: 1 addition & 0 deletions R/gt_pca_randomSVD.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ gt_pca_randomSVD <- function(x, k = 10,
rownames(this_svd$v) <- loci_names(x)
this_svd$method <- "randomSVD"
this_svd$call <- match.call()
this_svd$loci <- show_loci(x)
class(this_svd) <- c("gt_pca", class(this_svd))
this_svd
}
119 changes: 103 additions & 16 deletions R/predict_gt_pca.R
Original file line number Diff line number Diff line change
@@ -1,41 +1,128 @@
#' Predict scores of a PCA
#'
#' Predict the PCA scores for a [`gt_pca`], either for the original data or for new data.
#' Predict the PCA scores for a [`gt_pca`], either for the original data or
#' projecting new data.
#'
#' @param object the [`gt_pca`] object
#' @param new_data a gen_tibble if scores are requested for a new dataset
#' @param impute_to_center boolean on whether to impute missing values in
#' `new_data` to the mean values used
#' to center the pca. This option is used to e.g. project ancient data onto a PCA
#' fitted to modern data. It defaults to TRUE.
#' @param prediction_type a character vector containing "simple" and/or "OADP" (Online Augmentation, Decomposition, and Procrustes projection)
#' @param block_size number of loci read simultaneously (larger values will speed up
#' computation, but require more memory)
#' @param n_cores number of cores
#' @param ... not used (must be empty)
#' @returns a matrix of predictions, with samples as rows and components as columns. The number
#' of components depends on how many were estimated in the [`gt_pca`] object
#' of components depends on how many were estimated in the [`gt_pca`] object. If prediction
#' type is equal to c("simple","OADP"), then a list of two matrices is returned
#' @references Zhang et al (2020). Fast and robust ancestry prediction using
#' principal component analysis. Bioinformatics 36(11): 3439–3446.
#' @rdname predict_gt_pca
#' @export

# this is a modified version of bigstatsr::predict.big_SVD
predict.gt_pca <- function(object, new_data=NULL,block_size = NULL, ...){
predict.gt_pca <- function(object, new_data=NULL, impute_to_center = TRUE,
prediction_type = "simple",
block_size = NULL,
n_cores = 1, ...){
rlang::check_dots_empty()
if (!all(prediction_type %in% c("simple", "OADP"))){
stop("prediction_type can only take values 'simple' or 'OADP'")
}

if (is.null(new_data)) {
# U * D
sweep(object$u, 2, object$d, '*')
} else {
if (!inherits(new_data,"gen_tbl")){
stop ("new_data should be a gen_tibble")
}
if (gt_has_imputed(new_data) && !gt_uses_imputed(new_data)){
gt_set_imputed(new_data, set = TRUE)
on.exit(gt_set_imputed(new_data, set = FALSE))
# check the new_data have the same loci as the dataset used to build the pca
if (!all(object$loci$name %in% show_loci(new_data)$name)){
stop("loci used in object are not present in new_data")
}
if (is.null(block_size)){
block_size <- bigstatsr::block_size(nrow(new_data))
# get id of loci in new_data
loci_subset <- match(object$loci$name, show_loci(new_data)$name)
if (!all(all(show_loci(new_data)$allele_ref[loci_subset]==object$loci$allele_ref),
all(show_loci(new_data)$allele_alt[loci_subset]==object$loci$allele_alt))){
stop("ref and alt alleles differ between new_data and the data used to create the pca object")
}

if (!impute_to_center){
if (gt_has_imputed(new_data) && !gt_uses_imputed(new_data)){
gt_set_imputed(new_data, set = TRUE)
on.exit(gt_set_imputed(new_data, set = FALSE))
}

if (is.null(block_size)){
block_size <- bigstatsr::block_size(nrow(new_data))
}
# X * V
XV <- bigstatsr::big_prodMat(.gt_get_bigsnp(new_data)$genotypes,
object$v,
ind.row = .gt_bigsnp_rows(new_data),
ind.col = .gt_bigsnp_cols(new_data)[loci_subset],
block.size = block_size,
center = object$center,
scale = object$scale)
# if we use OADP, then we need to compute Xnorm
if ("OADP" %in% prediction_type){
stop ("OADP currently only implemented for when `impute_to_center = TRUE`")
}
} else {

X_norm <- bigstatsr::FBM(nrow(new_data), 1, init = 0)
XV <- bigstatsr::FBM(nrow(new_data), ncol(object$u), init = 0)

bigstatsr::big_parallelize(
.gt_get_bigsnp(new_data)$genotypes,
p.FUN = fbm256_part_prod,
ind = seq_along(loci_subset),
ncores = n_cores,
ind.row = .gt_bigsnp_rows(new_data),
ind.col = .gt_bigsnp_cols(new_data)[loci_subset], #info_snp$`_NUM_ID_`[keep],
center = object$center,
scale = object$scale,
V = object$v,
XV = XV,
X_norm = X_norm
)

if ("OADP" %in% prediction_type){
oadp_proj <- utils::getFromNamespace("OADP_proj", "bigsnpr")(XV, X_norm, object$d, ncores = n_cores)
}
}
# X * V
bigstatsr::big_prodMat(.gt_get_bigsnp(new_data)$genotypes,
object$v,
ind.row = .gt_bigsnp_rows(new_data),
ind.col = .gt_bigsnp_cols(new_data),
block.size = block_size,
center = object$center,
scale = object$scale)
if (all(c("simple","OADP") %in% prediction_type)){
return( list(
simple_proj = XV[, , drop = FALSE],
OADP_proj = oadp_proj
))
} else if ("simple" %in% prediction_type) {
return(XV[,, drop = FALSE])
} else {
return (oadp_proj)
}
return(XV[,, drop = FALSE])
}
}

#######################################################################################
# A port of bigsnpr::part_prod that works on a standard code256 FBM.
#
# Used as the `p.FUN` worker of bigstatsr::big_parallelize(): `ind` holds the
# positions of the loci handled by this call and is used to subset `ind.col`,
# `center`, `scale` and the rows of `V`. The two results computed for that
# chunk are then added to the accumulators `XV` and `X_norm`, each increment
# being made with `use_lock = TRUE`.

fbm256_part_prod <- function(X, ind, ind.row, ind.col, center, scale, V, XV, X_norm) {

  # restrict every per-locus input to the chunk of loci given by `ind`
  chunk_cols <- ind.col[ind]
  chunk_center <- center[ind]
  chunk_scale <- scale[ind]
  chunk_loadings <- V[ind, , drop = FALSE]

  # element 1: product of scaled genotypes and loadings;
  # element 2: row sums of squared scaled genotypes
  chunk_res <- fbm256_prod_and_rowSumsSq(
    BM = X,
    ind_row = ind.row,
    ind_col = chunk_cols,
    center = chunk_center,
    scale = chunk_scale,
    V = chunk_loadings
  )

  # add this chunk's contribution to the shared accumulators
  bigstatsr::big_increment(XV, chunk_res[[1]], use_lock = TRUE)
  bigstatsr::big_increment(X_norm, chunk_res[[2]], use_lock = TRUE)
}

29 changes: 26 additions & 3 deletions man/predict_gt_pca.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,22 @@ Rcpp::Rostream<true>& Rcpp::Rcout = Rcpp::Rcpp_cout_get();
Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif

// fbm256_prod_and_rowSumsSq
List fbm256_prod_and_rowSumsSq(Environment BM, const IntegerVector& ind_row, const IntegerVector& ind_col, const NumericVector& center, const NumericVector& scale, const NumericMatrix& V);
RcppExport SEXP _tidypopgen_fbm256_prod_and_rowSumsSq(SEXP BMSEXP, SEXP ind_rowSEXP, SEXP ind_colSEXP, SEXP centerSEXP, SEXP scaleSEXP, SEXP VSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< Environment >::type BM(BMSEXP);
Rcpp::traits::input_parameter< const IntegerVector& >::type ind_row(ind_rowSEXP);
Rcpp::traits::input_parameter< const IntegerVector& >::type ind_col(ind_colSEXP);
Rcpp::traits::input_parameter< const NumericVector& >::type center(centerSEXP);
Rcpp::traits::input_parameter< const NumericVector& >::type scale(scaleSEXP);
Rcpp::traits::input_parameter< const NumericMatrix& >::type V(VSEXP);
rcpp_result_gen = Rcpp::wrap(fbm256_prod_and_rowSumsSq(BM, ind_row, ind_col, center, scale, V));
return rcpp_result_gen;
END_RCPP
}
// gt_grouped_alt_freq_diploid
ListOf<NumericMatrix> gt_grouped_alt_freq_diploid(Environment BM, const IntegerVector& rowInd, const IntegerVector& colInd, const IntegerVector& groupIds, int ngroups, int ncores);
RcppExport SEXP _tidypopgen_gt_grouped_alt_freq_diploid(SEXP BMSEXP, SEXP rowIndSEXP, SEXP colIndSEXP, SEXP groupIdsSEXP, SEXP ngroupsSEXP, SEXP ncoresSEXP) {
Expand Down Expand Up @@ -171,6 +187,7 @@ END_RCPP
}

static const R_CallMethodDef CallEntries[] = {
{"_tidypopgen_fbm256_prod_and_rowSumsSq", (DL_FUNC) &_tidypopgen_fbm256_prod_and_rowSumsSq, 6},
{"_tidypopgen_gt_grouped_alt_freq_diploid", (DL_FUNC) &_tidypopgen_gt_grouped_alt_freq_diploid, 6},
{"_tidypopgen_gt_grouped_alt_freq_pseudohap", (DL_FUNC) &_tidypopgen_gt_grouped_alt_freq_pseudohap, 7},
{"_tidypopgen_gt_grouped_missingness", (DL_FUNC) &_tidypopgen_gt_grouped_missingness, 6},
Expand Down
48 changes: 48 additions & 0 deletions src/fbm_prod_and_rowSumSq.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/******************************************************************************/

#include <bigstatsr/BMCodeAcc.h>

/******************************************************************************/

// Port of prod_and_rowSumSq from bigsnpr, adapted to read from a standard FBM
// through a code256 sub-accessor.
//
// For the rows `ind_row` and columns `ind_col` of the FBM, each decoded value
// is centred and scaled with the matching entry of `center` / `scale`. A value
// that does not satisfy `> -1` is replaced by 0 instead of being scaled, so it
// adds nothing to either result; this is also the branch taken for NaN,
// because every ordered comparison with NaN is false.
//
// Returns a list with two unnamed elements:
//   1. an n x K matrix, the scaled values multiplied by `V` (one row of `V`
//      per column read, K columns);
//   2. a vector of length n with the sum of squared scaled values of each row.

// [[Rcpp::export]]
List fbm256_prod_and_rowSumsSq(Environment BM,
                               const IntegerVector& ind_row,
                               const IntegerVector& ind_col,
                               const NumericVector& center,
                               const NumericVector& scale,
                               const NumericMatrix& V) {

  XPtr<FBM> xpBM = BM["address"];
  SubBMCode256Acc macc(xpBM, ind_row, ind_col, BM["code256"], 1);

  size_t n_ind = macc.nrow();   // number of individuals
  size_t n_loci = macc.ncol();  // number of sites
  // V (from the PCA) must have exactly one row per site being read
  myassert_size(n_loci, V.rows());
  size_t n_comp = V.cols();

  NumericMatrix prod(n_ind, n_comp);
  NumericVector sum_sq(n_ind);

  // sites stay in the outer loop so that every output cell accumulates its
  // terms in increasing site order
  for (size_t locus = 0; locus < n_loci; locus++) {
    for (size_t row = 0; row < n_ind; row++) {
      const double raw = macc(row, locus);
      // centre and standardise, or use 0 when the value fails the `> -1` test
      const double z = (raw > -1) ? (raw - center[locus]) / scale[locus] : 0;
      sum_sq[row] += z * z;
      for (size_t comp = 0; comp < n_comp; comp++) {
        prod(row, comp) += z * V(locus, comp);
      }
    }
  }

  return List::create(prod, sum_sq);
}

3 changes: 3 additions & 0 deletions tests/testthat/test_gt_extract_f2.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ test_gt <- gen_tibble(x = test_genotypes,
quiet = TRUE)
test_gt <- test_gt %>% group_by(population)

# TODO I don't understand why this does not silence all messages
options("rlib_message_verbosity" = "quiet")

test_that("extract f2 correctly",{
# process the data with admixtools (note that we get some warnings)
bed_file <- gt_as_plink(test_gt, file = tempfile("test_bed"))
Expand Down
42 changes: 41 additions & 1 deletion tests/testthat/test_gt_pca.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,48 @@ test_that("fit_gt_pca_and_predict",{
"You can't have missing")
missing_gt <- gt_impute_simple(missing_gt, method = "mode")
missing_pca <- missing_gt %>% gt_pca_partialSVD()
# check that predicting on the object is the same as predicting from the full dataset
# without imputation to the center (the data are already imputed)
expect_true(all.equal(predict(missing_pca),
predict(missing_pca, new_data = missing_gt),
predict(missing_pca, new_data = missing_gt,
impute_to_center = FALSE),
check.attributes=FALSE))
# now mismatch the loci table
missing_gt_edited <- missing_gt
show_loci(missing_gt_edited)$name[3] <- "blah"
expect_error(predict(missing_pca, new_data = missing_gt_edited),
"loci used in object")
missing_gt_edited <- missing_gt
show_loci(missing_gt_edited)$allele_ref[3] <- "blah"
expect_error(predict(missing_pca, new_data = missing_gt_edited),
"ref and alt alleles differ")
# predict when new dataset has extra positions
missing_gt_sub <- missing_gt %>% select_loci(100:450)
missing_sub_pca <- missing_gt_sub %>% gt_pca_partialSVD()
expect_true(all.equal(predict(missing_sub_pca),
predict(missing_sub_pca, new_data = missing_gt,
impute_to_center = FALSE),
check.attributes=FALSE))

})

# TODO we should test gt_pca_autoSVD(), as the loci have to be subset within
# the object

test_that("fit_gt_pca_and_predict_splitted_data",{
  bed_file <- system.file("extdata", "example-missing.bed", package = "bigsnpr")
  missing_gt <- gen_tibble(bed_file, backingfile = tempfile("missing_"), quiet = TRUE)
  # create a fake ancient set by subsetting
  ancient_gt <- missing_gt[1:20, ]
  # now extract the modern data (to be imputed)
  modern_gt <- missing_gt[-c(1:20), ]

  modern_gt <- gt_impute_simple(modern_gt, method = "mode")
  modern_pca <- modern_gt %>% gt_pca_partialSVD()
  # the ancient data are left un-imputed; with the default
  # `impute_to_center = TRUE` the projection should still succeed and return
  # one row per ancient sample and one column per component (default k = 10)
  ancient_pred <- predict(modern_pca, new_data = ancient_gt)
  # compare the dimensions directly: `all(dim(x) == c(20, 10))` is TRUE when
  # `dim(x)` is NULL, so it would not catch a non-matrix return value
  expect_equal(dim(ancient_pred), c(20, 10))
  # without imputing to the center, the missing values must raise an error
  expect_error(predict(modern_pca, new_data = ancient_gt, impute_to_center = FALSE),
               "You can't have missing values in 'X'")
})

0 comments on commit 39129c3

Please sign in to comment.