From 63572ad3241b1b05eeecba4e62b371eea85093d1 Mon Sep 17 00:00:00 2001 From: Julien Chiquet Date: Thu, 29 Apr 2021 14:38:27 +0200 Subject: [PATCH] fixing prediction to match blockmodels behavior (#3) even if theire is a slight mistake for the Bernoulli, covariate case. Will be fixed later --- DESCRIPTION | 2 +- R/R6Class-BipartiteSBM.R | 7 +- R/R6Class-SimpleSBM.R | 6 +- R/estimate.R | 1 - tests/testthat/test-BipartiteSBM_fit.R | 30 ++++++++- .../test-BipartiteSBM_fit_covariates.R | 31 ++++++++- tests/testthat/test-SimpleSBM_fit.R | 60 +++++++++++++++-- .../testthat/test-SimpleSBM_fit_covariates.R | 64 ++++++++++++++++--- 8 files changed, 177 insertions(+), 24 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 5147d410..7f951a3f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: sbm Title: Stochastic Blockmodels -Version: 0.4.0-9210 +Version: 0.4.0-9220 Authors@R: c(person(given = "Julien", family = "Chiquet", role = c("aut", "cre"), email = "julien.chiquet@inrae.fr", diff --git a/R/R6Class-BipartiteSBM.R b/R/R6Class-BipartiteSBM.R index 825e6848..41a09ca6 100644 --- a/R/R6Class-BipartiteSBM.R +++ b/R/R6Class-BipartiteSBM.R @@ -70,7 +70,12 @@ BipartiteSBM <- if (length(covarList) > 0) { stopifnot(all(sapply(covarList, nrow) == self$nbNodes[1]), all(sapply(covarList, ncol) == self$nbNodes[2])) - res <- private$invlink[[1L]](private$Z[[1]] %*% private$link[[1L]]( mu ) %*% t(private$Z[[2]]) + self$covarEffect) + if (self$modelName == "bernoulli") { + res <- private$invlink[[1L]](private$Z[[1]] %*% private$link[[1L]]( mu ) %*% t(private$Z[[2]]) + self$covarEffect) + } else { + res <- private$invlink[[1L]](private$link[[1L]](private$Z[[1]] %*% mu %*% t(private$Z[[2]])) + self$covarEffect) + } + } else { res <- private$Z[[1]] %*% mu %*% t(private$Z[[2]]) } diff --git a/R/R6Class-SimpleSBM.R b/R/R6Class-SimpleSBM.R index d0bd99ac..21b1448d 100644 --- a/R/R6Class-SimpleSBM.R +++ b/R/R6Class-SimpleSBM.R @@ -66,7 +66,11 @@ SimpleSBM <- if (self$nbCovariates > 0) { stopifnot(all(sapply(covarList, nrow) == self$nbNodes, sapply(covarList, ncol) == self$nbNodes)) - res <- private$invlink[[1L]](private$Z %*% private$link[[1L]](mu) %*% t(private$Z) + self$covarEffect) + if (self$modelName == "bernoulli") { + res <- private$invlink[[1L]](private$Z %*% private$link[[1L]]( mu ) %*% t(private$Z) + self$covarEffect) + } else { + res <- private$invlink[[1L]](private$link[[1L]](private$Z %*% mu %*% t(private$Z)) + self$covarEffect) + } } else { res <- private$Z %*% mu %*% t(private$Z) } diff --git a/R/estimate.R b/R/estimate.R index eb177fc8..77a8aa82 100644 --- a/R/estimate.R +++ b/R/estimate.R @@ -98,7 +98,6 @@ estimateSimpleSBM <- function(netMat, fast = TRUE ) - ## Current options are default expect for those passed by the user currentOptions[names(estimOptions)] <- estimOptions diff --git a/tests/testthat/test-BipartiteSBM_fit.R b/tests/testthat/test-BipartiteSBM_fit.R index 95ddf773..78acb6ac 100644 --- a/tests/testthat/test-BipartiteSBM_fit.R +++ b/tests/testthat/test-BipartiteSBM_fit.R @@ -46,7 +46,7 @@ test_that("BipartiteSBM_fit 'Bernoulli' model, undirected, no covariate", { expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam) ## Estimation----------------------------------------------------------------- - mySBM$optimize(estimOptions=list(verbosity = 0)) + BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0)) mySBM$setModel(5) expect_equal(mySBM$nbConnectParam, unname(nbBlocks[1] * nbBlocks[2])) @@ -80,6 +80,14 @@ test_that("BipartiteSBM_fit 'Bernoulli' model, undirected, no covariate", { expect_lt(1 - aricode::ARI(mySBM$memberships[[1]], mySampler$memberships[[1]]), .2) expect_lt(1 - aricode::ARI(mySBM$memberships[[2]], mySampler$memberships[[2]]), .2) + ## prediction wrt BM + for (Q in 2:5) { + pred_bm <- BM_out$prediction(Q = Q) + mySBM$setModel(Q) + pred_sbm <- predict(mySBM) + expect_lt( rmse(pred_bm, pred_sbm), 1e-12) + } + }) test_that("BipartiteSBM_fit 'Poisson' model, undirected, no covariate", { @@ -121,7 +129,7 @@ test_that("BipartiteSBM_fit 'Poisson' model, undirected, no covariate", { expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam) ## Estimation----------------------------------------------------------------- - mySBM$optimize(estimOptions=list(verbosity = 0)) + BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0)) mySBM$setModel(5) ## Expectation @@ -148,6 +156,14 @@ test_that("BipartiteSBM_fit 'Poisson' model, undirected, no covariate", { expect_lt(1 - aricode::ARI(mySBM$memberships[[1]], mySampler$memberships[[1]]), 1e-1) expect_lt(1 - aricode::ARI(mySBM$memberships[[2]], mySampler$memberships[[2]]), 1e-1) + ## prediction wrt BM + for (Q in 2:5) { + pred_bm <- BM_out$prediction(Q = Q) + mySBM$setModel(Q) + pred_sbm <- predict(mySBM) + expect_lt( rmse(pred_bm, pred_sbm), 1e-12) + } + }) test_that("BipartiteSBM_fit 'Gaussian' model, undirected, no covariate", { @@ -189,7 +205,7 @@ test_that("BipartiteSBM_fit 'Gaussian' model, undirected, no covariate", { expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam) ## Estimation----------------------------------------------------------------- - mySBM$optimize(estimOptions=list(verbosity = 0)) + BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0)) mySBM$setModel(5) ## Expectation @@ -215,6 +231,14 @@ test_that("BipartiteSBM_fit 'Gaussian' model, undirected, no covariate", { expect_lt(1 - aricode::ARI(mySBM$memberships[[1]], mySampler$memberships[[1]]), 1e-1) expect_lt(1 - aricode::ARI(mySBM$memberships[[2]], mySampler$memberships[[2]]), 1e-1) + ## prediction wrt BM + for (Q in 2:5) { + pred_bm <- BM_out$prediction(Q = Q) + mySBM$setModel(Q) + pred_sbm <- predict(mySBM) + expect_lt( rmse(pred_bm, pred_sbm), 1e-12) + } + }) test_that("active bindings are working in the class", { diff --git a/tests/testthat/test-BipartiteSBM_fit_covariates.R b/tests/testthat/test-BipartiteSBM_fit_covariates.R index d9dcad0b..20498b0b 100644 --- a/tests/testthat/test-BipartiteSBM_fit_covariates.R +++ b/tests/testthat/test-BipartiteSBM_fit_covariates.R @@ -51,7 +51,7 @@ test_that("BipartiteSBM_fit 'Bernoulli' model, undirected, no covariate", { expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam) ## Estimation----------------------------------------------------------------- - mySBM$optimize(estimOptions = list(verbosity = 0, fast = TRUE)) + BM_out <- mySBM$optimize(estimOptions = list(verbosity = 0, fast = TRUE)) mySBM$setModel(5) ## Expectation @@ -75,6 +75,15 @@ test_that("BipartiteSBM_fit 'Bernoulli' model, undirected, no covariate", { expect_equal(predict(mySBM, covarList[1]), fitted(mySBM)) expect_error(predict(mySBM, covarList)) + ## prediction wrt BM + for (Q in 2:5) { + pred_bm <- BM_out$prediction(Q = Q) + mySBM$setModel(Q) + pred_sbm <- predict(mySBM) + expect_lt( rmse(pred_bm, pred_sbm), 1e-12) + } + + }) test_that("BipartiteSBM_fit 'Poisson' model, undirected, no covariate", { @@ -117,7 +126,7 @@ test_that("BipartiteSBM_fit 'Poisson' model, undirected, no covariate", { expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam) ## Estimation----------------------------------------------------------------- - mySBM$optimize(estimOptions=list(verbosity = 0)) + BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0)) mySBM$setModel(5) ## Expectation @@ -145,6 +154,14 @@ test_that("BipartiteSBM_fit 'Poisson' model, undirected, no covariate", { expect_equal(predict(mySBM, covarList), fitted(mySBM)) expect_error(predict(mySBM, covarList[1])) + ## prediction wrt BM + for (Q in 2:5) { + pred_bm <- BM_out$prediction(Q = Q) + mySBM$setModel(Q) + pred_sbm <- predict(mySBM) + expect_lt( rmse(pred_bm, pred_sbm), 1e-12) + } + }) test_that("BipartiteSBM_fit 'Gaussian' model, undirected, no covariate", { @@ -187,7 +204,7 @@ test_that("BipartiteSBM_fit 'Gaussian' model, undirected, no covariate", { expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam) ## Estimation----------------------------------------------------------------- - mySBM$optimize(estimOptions=list(verbosity = 0)) + BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0)) mySBM$setModel(5) ## Expectation @@ -213,5 +230,13 @@ test_that("BipartiteSBM_fit 'Gaussian' model, undirected, no covariate", { expect_lt(1 - aricode::ARI(mySBM$memberships[[1]], mySampler$memberships[[1]]), 2e-1) expect_lt(1 - aricode::ARI(mySBM$memberships[[2]], mySampler$memberships[[2]]), 2e-1) + ## prediction wrt BM + for (Q in 2:5) { + pred_bm <- BM_out$prediction(Q = Q) + mySBM$setModel(Q) + pred_sbm <- predict(mySBM) + expect_lt( rmse(pred_bm, pred_sbm), 1e-12) + } + }) diff --git a/tests/testthat/test-SimpleSBM_fit.R b/tests/testthat/test-SimpleSBM_fit.R index 95d2eb6f..6ab80f1c 100644 --- a/tests/testthat/test-SimpleSBM_fit.R +++ b/tests/testthat/test-SimpleSBM_fit.R @@ -52,7 +52,7 @@ test_that("SimpleSBM_fit 'Bernoulli' model, undirected, no covariate", { expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam) ## Estimation----------------------------------------------------------------- - mySBM$optimize(estimOptions=list(verbosity = 0)) + BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0)) mySBM$setModel(3) ## Field set after optimization @@ -83,6 +83,14 @@ test_that("SimpleSBM_fit 'Bernoulli' model, undirected, no covariate", { expect_lt(rmse(mySBM$connectParam$mean, means), 0.25) expect_lt(1 - aricode::ARI(mySBM$memberships, mySampler$memberships), 0.25) + ## prediction wrt BM + for (Q in 1:4) { + pred_bm <- BM_out$prediction(Q = Q) + mySBM$setModel(Q) + pred_sbm <- predict(mySBM) + expect_lt( rmse(pred_bm, pred_sbm), 1e-12) + } + }) test_that("SimpleSBM_fit 'Bernoulli' model, directed, no covariate", { @@ -130,7 +138,7 @@ test_that("SimpleSBM_fit 'Bernoulli' model, directed, no covariate", { expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam) ## Estimation----------------------------------------------------------------- - mySBM$optimize(estimOptions=list(verbosity = 0)) + BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0)) mySBM$setModel(3) expect_equal(mySBM$nbConnectParam, nbBlocks * nbBlocks) @@ -160,6 +168,14 @@ test_that("SimpleSBM_fit 'Bernoulli' model, directed, no covariate", { expect_lt(rmse(sort(mySBM$connectParam$mean), means), 0.2) expect_lt(1 - aricode::ARI(mySBM$memberships, mySampler$memberships), 0.2) + ## prediction wrt BM + for (Q in 1:4) { + pred_bm <- BM_out$prediction(Q = Q) + mySBM$setModel(Q) + pred_sbm <- predict(mySBM) + expect_lt( rmse(pred_bm, pred_sbm), 1e-12) + } + }) test_that("SimpleSBM_fit 'Poisson' model, undirected, no covariate", { @@ -207,7 +223,7 @@ test_that("SimpleSBM_fit 'Poisson' model, undirected, no covariate", { expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam) ## Estimation----------------------------------------------------------------- - mySBM$optimize(estimOptions=list(verbosity = 0)) + BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0)) mySBM$setModel(3) ## Expectation @@ -232,6 +248,14 @@ test_that("SimpleSBM_fit 'Poisson' model, undirected, no covariate", { expect_lt(rmse(mySBM$connectParam$mean, means), 1e-1) expect_lt(1 - aricode::ARI(mySBM$memberships, mySampler$memberships), 1e-1) + ## prediction wrt BM + for (Q in 1:4) { + pred_bm <- BM_out$prediction(Q = Q) + mySBM$setModel(Q) + pred_sbm <- predict(mySBM) + expect_lt( rmse(pred_bm, pred_sbm), 1e-12) + } + }) test_that("SimpleSBM_fit 'Poisson' model, directed, no covariate", { @@ -278,7 +302,7 @@ test_that("SimpleSBM_fit 'Poisson' model, directed, no covariate", { expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam) ## Estimation----------------------------------------------------------------- - mySBM$optimize(estimOptions=list(verbosity = 0)) + BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0)) mySBM$setModel(3) ## Expectation @@ -303,6 +327,14 @@ test_that("SimpleSBM_fit 'Poisson' model, directed, no covariate", { expect_lt(rmse(sort(mySBM$connectParam$mean), means), 1e-1) expect_lt(1 - aricode::ARI(mySBM$memberships, mySampler$memberships), 1e-1) + ## prediction wrt BM + for (Q in 1:4) { + pred_bm <- BM_out$prediction(Q = Q) + mySBM$setModel(Q) + pred_sbm <- predict(mySBM) + expect_lt( rmse(pred_bm, pred_sbm), 1e-12) + } + }) @@ -350,7 +382,7 @@ test_that("SimpleSBM_fit 'Gaussian' model, undirected, no covariate", { expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam) ## Estimation----------------------------------------------------------------- - mySBM$optimize(estimOptions=list(verbosity = 0)) + BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0)) mySBM$setModel(3) ## Expectation @@ -374,6 +406,14 @@ test_that("SimpleSBM_fit 'Gaussian' model, undirected, no covariate", { expect_lt(rmse(mySBM$connectParam$mean, means), 1e-1) expect_lt(1 - aricode::ARI(mySBM$memberships, mySampler$memberships), 1e-1) + ## prediction wrt BM + for (Q in 1:4) { + pred_bm <- BM_out$prediction(Q = Q) + mySBM$setModel(Q) + pred_sbm <- predict(mySBM) + expect_lt( rmse(pred_bm, pred_sbm), 1e-12) + } + }) test_that("SimpleSBM_fit 'Gaussian' model, undirected, no covariate", { @@ -421,7 +461,7 @@ test_that("SimpleSBM_fit 'Gaussian' model, undirected, no covariate", { expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam) ## Estimation----------------------------------------------------------------- - mySBM$optimize(estimOptions=list(verbosity = 0)) + BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0)) mySBM$setModel(3) ## Expectation @@ -445,6 +485,14 @@ test_that("SimpleSBM_fit 'Gaussian' model, undirected, no covariate", { expect_lt(rmse(sort(mySBM$connectParam$mean), means), 1e-1) expect_lt(1 - aricode::ARI(mySBM$memberships, mySampler$memberships), 1e-1) + ## prediction wrt BM + for (Q in 1:4) { + pred_bm <- BM_out$prediction(Q = Q) + mySBM$setModel(Q) + pred_sbm <- predict(mySBM) + expect_lt( rmse(pred_bm, pred_sbm), 1e-12) + } + }) test_that("active binding are working in the class", { diff --git a/tests/testthat/test-SimpleSBM_fit_covariates.R b/tests/testthat/test-SimpleSBM_fit_covariates.R index e88f6170..9e02f667 100644 --- a/tests/testthat/test-SimpleSBM_fit_covariates.R +++ b/tests/testthat/test-SimpleSBM_fit_covariates.R @@ -62,7 +62,7 @@ test_that("SimpleSBM_fit 'Bernoulli' model, undirected, one covariate", { expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam) ## Estimation----------------------------------------------------------------- - mySBM$optimize(estimOptions = list(verbosity = 0, fast = TRUE)) + BM_out <- mySBM$optimize(estimOptions = list(verbosity = 0, fast = TRUE)) mySBM$setModel(2) ## Expectation @@ -86,6 +86,14 @@ test_that("SimpleSBM_fit 'Bernoulli' model, undirected, one covariate", { expect_equal(predict(mySBM, covarList[1]), fitted(mySBM)) expect_error(predict(mySBM, covarList)) + ## prediction wrt BM + for (Q in 1:3) { + pred_bm <- BM_out$prediction(Q = Q) + mySBM$setModel(Q) + pred_sbm <- predict(mySBM) + expect_lt( rmse(pred_bm, pred_sbm), 1e-12) + } + }) test_that("SimpleSBM_fit 'Bernoulli' model, directed, one covariate", { @@ -133,7 +141,7 @@ test_that("SimpleSBM_fit 'Bernoulli' model, directed, one covariate", { expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam) ## Estimation----------------------------------------------------------------- - mySBM$optimize(estimOptions = list(verbosity = 0, fast = TRUE)) + BM_out <- mySBM$optimize(estimOptions = list(verbosity = 0, fast = TRUE)) mySBM$setModel(2) ## Expectation @@ -157,6 +165,14 @@ test_that("SimpleSBM_fit 'Bernoulli' model, directed, one covariate", { expect_equal(predict(mySBM, covarList[1]), fitted(mySBM)) + ## prediction wrt BM + for (Q in 1:3) { + pred_bm <- BM_out$prediction(Q = Q) + mySBM$setModel(Q) + pred_sbm <- predict(mySBM) + expect_lt( rmse(pred_bm, pred_sbm), 1e-12) + } + }) test_that("SimpleSBM_fit 'Poisson' model, undirected, two covariates", { @@ -193,7 +209,7 @@ test_that("SimpleSBM_fit 'Poisson' model, undirected, two covariates", { expect_true(is.matrix(mySBM$connectParam$mean)) ## covariates - expect_true(all(dim(mySBM$covarEffect) == c(nbNodes, nbNodes))) + expect_true(all(dim(mySBM$covarEffect) == c(nbNodes, nbNodes))) expect_equal(mySBM$nbCovariates, 1) expect_equal(mySBM$covarList, covarList[1]) expect_equal(mySBM$covarParam, c(0)) @@ -204,7 +220,7 @@ test_that("SimpleSBM_fit 'Poisson' model, undirected, two covariates", { expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam) ## Estimation----------------------------------------------------------------- - mySBM$optimize(estimOptions=list(verbosity = 0)) + BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0)) mySBM$setModel(2) ## Expectation @@ -227,6 +243,14 @@ test_that("SimpleSBM_fit 'Poisson' model, undirected, two covariates", { expect_equal(predict(mySBM, covarList[1]), fitted(mySBM)) + ## prediction wrt BM + for (Q in 1:3) { + pred_bm <- BM_out$prediction(Q = Q) + mySBM$setModel(Q) + pred_sbm <- predict(mySBM) + expect_lt( rmse(pred_bm, pred_sbm), 1e-12) + } + }) test_that("SimpleSBM_fit 'Poisson' model, directed, two covariates", { @@ -274,7 +298,7 @@ test_that("SimpleSBM_fit 'Poisson' model, directed, two covariates", { expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam) ## Estimation----------------------------------------------------------------- - mySBM$optimize(estimOptions=list(verbosity = 0)) + BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0)) mySBM$setModel(2) ## Expectation @@ -297,6 +321,14 @@ test_that("SimpleSBM_fit 'Poisson' model, directed, two covariates", { expect_equal(predict(mySBM, covarList[1]), fitted(mySBM)) + ## prediction wrt BM + for (Q in 1:3) { + pred_bm <- BM_out$prediction(Q = Q) + mySBM$setModel(Q) + pred_sbm <- predict(mySBM) + expect_lt( rmse(pred_bm, pred_sbm), 1e-12) + } + }) @@ -334,7 +366,7 @@ test_that("SimpleSBM_fit 'Gaussian' model, undirected, two covariates", { expect_true(is.matrix(mySBM$connectParam$mean)) ## covariates - expect_true(all(dim(mySBM$covarEffect) == c(nbNodes, nbNodes))) + expect_true(all(dim(mySBM$covarEffect) == c(nbNodes, nbNodes))) expect_equal(mySBM$nbCovariates, 2) expect_equal(mySBM$covarList, covarList) expect_equal(mySBM$covarParam, c(0,0)) @@ -345,7 +377,7 @@ test_that("SimpleSBM_fit 'Gaussian' model, undirected, two covariates", { expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam) ## Estimation----------------------------------------------------------------- - mySBM$optimize(estimOptions=list(verbosity = 0)) + BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0)) mySBM$setModel(2) ## Expectation @@ -372,6 +404,14 @@ test_that("SimpleSBM_fit 'Gaussian' model, undirected, two covariates", { expect_equal(predict(mySBM, covarList), fitted(mySBM)) + ## prediction wrt BM + for (Q in 1:3) { + pred_bm <- BM_out$prediction(Q = Q) + mySBM$setModel(Q) + pred_sbm <- predict(mySBM) + expect_lt( rmse(pred_bm, pred_sbm), 1e-12) + } + }) @@ -420,7 +460,7 @@ test_that("SimpleSBM_fit 'Gaussian' model, undirected, two covariates", { expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam) ## Estimation----------------------------------------------------------------- - mySBM$optimize(estimOptions=list(verbosity = 0)) + BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0)) mySBM$setModel(2) ## Expectation @@ -447,4 +487,12 @@ test_that("SimpleSBM_fit 'Gaussian' model, undirected, two covariates", { expect_equal(predict(mySBM, covarList), fitted(mySBM)) + ## prediction wrt BM + for (Q in 1:3) { + pred_bm <- BM_out$prediction(Q = Q) + mySBM$setModel(Q) + pred_sbm <- predict(mySBM) + expect_lt( rmse(pred_bm, pred_sbm), 1e-12) + } + })