Skip to content

Commit

Permalink
fixing prediction to match blockmodels behavior (#3) even if theire i…
Browse files Browse the repository at this point in the history
…s a slight mistake for the Bernoulli, covariate case. Will be fixed later
  • Loading branch information
jchiquet committed Apr 29, 2021
1 parent 8c50c3c commit 63572ad
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 24 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: sbm
Title: Stochastic Blockmodels
Version: 0.4.0-9210
Version: 0.4.0-9220
Authors@R:
c(person(given = "Julien", family = "Chiquet", role = c("aut", "cre"),
email = "[email protected]",
Expand Down
7 changes: 6 additions & 1 deletion R/R6Class-BipartiteSBM.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,12 @@ BipartiteSBM <-
if (length(covarList) > 0) {
stopifnot(all(sapply(covarList, nrow) == self$nbNodes[1]),
all(sapply(covarList, ncol) == self$nbNodes[2]))
res <- private$invlink[[1L]](private$Z[[1]] %*% private$link[[1L]]( mu ) %*% t(private$Z[[2]]) + self$covarEffect)
if (self$modelName == "bernoulli") {
res <- private$invlink[[1L]](private$Z[[1]] %*% private$link[[1L]]( mu ) %*% t(private$Z[[2]]) + self$covarEffect)
} else {
res <- private$invlink[[1L]](private$link[[1L]](private$Z[[1]] %*% mu %*% t(private$Z[[2]])) + self$covarEffect)
}

} else {
res <- private$Z[[1]] %*% mu %*% t(private$Z[[2]])
}
Expand Down
6 changes: 5 additions & 1 deletion R/R6Class-SimpleSBM.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ SimpleSBM <-
if (self$nbCovariates > 0) {
stopifnot(all(sapply(covarList, nrow) == self$nbNodes,
sapply(covarList, ncol) == self$nbNodes))
res <- private$invlink[[1L]](private$Z %*% private$link[[1L]](mu) %*% t(private$Z) + self$covarEffect)
if (self$modelName == "bernoulli") {
res <- private$invlink[[1L]](private$Z %*% private$link[[1L]]( mu ) %*% t(private$Z) + self$covarEffect)
} else {
res <- private$invlink[[1L]](private$link[[1L]](private$Z %*% mu %*% t(private$Z)) + self$covarEffect)
}
} else {
res <- private$Z %*% mu %*% t(private$Z)
}
Expand Down
1 change: 0 additions & 1 deletion R/estimate.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ estimateSimpleSBM <- function(netMat,
fast = TRUE
)


## Current options are default expect for those passed by the user
currentOptions[names(estimOptions)] <- estimOptions

Expand Down
30 changes: 27 additions & 3 deletions tests/testthat/test-BipartiteSBM_fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ test_that("BipartiteSBM_fit 'Bernoulli' model, undirected, no covariate", {
expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam)

## Estimation-----------------------------------------------------------------
mySBM$optimize(estimOptions=list(verbosity = 0))
BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0))
mySBM$setModel(5)

expect_equal(mySBM$nbConnectParam, unname(nbBlocks[1] * nbBlocks[2]))
Expand Down Expand Up @@ -80,6 +80,14 @@ test_that("BipartiteSBM_fit 'Bernoulli' model, undirected, no covariate", {
expect_lt(1 - aricode::ARI(mySBM$memberships[[1]], mySampler$memberships[[1]]), .2)
expect_lt(1 - aricode::ARI(mySBM$memberships[[2]], mySampler$memberships[[2]]), .2)

## prediction wrt BM
for (Q in 2:5) {
pred_bm <- BM_out$prediction(Q = Q)
mySBM$setModel(Q)
pred_sbm <- predict(mySBM)
expect_lt( rmse(pred_bm, pred_sbm), 1e-12)
}

})

test_that("BipartiteSBM_fit 'Poisson' model, undirected, no covariate", {
Expand Down Expand Up @@ -121,7 +129,7 @@ test_that("BipartiteSBM_fit 'Poisson' model, undirected, no covariate", {
expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam)

## Estimation-----------------------------------------------------------------
mySBM$optimize(estimOptions=list(verbosity = 0))
BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0))
mySBM$setModel(5)

## Expectation
Expand All @@ -148,6 +156,14 @@ test_that("BipartiteSBM_fit 'Poisson' model, undirected, no covariate", {
expect_lt(1 - aricode::ARI(mySBM$memberships[[1]], mySampler$memberships[[1]]), 1e-1)
expect_lt(1 - aricode::ARI(mySBM$memberships[[2]], mySampler$memberships[[2]]), 1e-1)

## prediction wrt BM
for (Q in 2:5) {
pred_bm <- BM_out$prediction(Q = Q)
mySBM$setModel(Q)
pred_sbm <- predict(mySBM)
expect_lt( rmse(pred_bm, pred_sbm), 1e-12)
}

})

test_that("BipartiteSBM_fit 'Gaussian' model, undirected, no covariate", {
Expand Down Expand Up @@ -189,7 +205,7 @@ test_that("BipartiteSBM_fit 'Gaussian' model, undirected, no covariate", {
expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam)

## Estimation-----------------------------------------------------------------
mySBM$optimize(estimOptions=list(verbosity = 0))
BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0))
mySBM$setModel(5)

## Expectation
Expand All @@ -215,6 +231,14 @@ test_that("BipartiteSBM_fit 'Gaussian' model, undirected, no covariate", {
expect_lt(1 - aricode::ARI(mySBM$memberships[[1]], mySampler$memberships[[1]]), 1e-1)
expect_lt(1 - aricode::ARI(mySBM$memberships[[2]], mySampler$memberships[[2]]), 1e-1)

## prediction wrt BM
for (Q in 2:5) {
pred_bm <- BM_out$prediction(Q = Q)
mySBM$setModel(Q)
pred_sbm <- predict(mySBM)
expect_lt( rmse(pred_bm, pred_sbm), 1e-12)
}

})

test_that("active bindings are working in the class", {
Expand Down
31 changes: 28 additions & 3 deletions tests/testthat/test-BipartiteSBM_fit_covariates.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ test_that("BipartiteSBM_fit 'Bernoulli' model, undirected, no covariate", {
expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam)

## Estimation-----------------------------------------------------------------
mySBM$optimize(estimOptions = list(verbosity = 0, fast = TRUE))
BM_out <- mySBM$optimize(estimOptions = list(verbosity = 0, fast = TRUE))
mySBM$setModel(5)

## Expectation
Expand All @@ -75,6 +75,15 @@ test_that("BipartiteSBM_fit 'Bernoulli' model, undirected, no covariate", {
expect_equal(predict(mySBM, covarList[1]), fitted(mySBM))
expect_error(predict(mySBM, covarList))

## prediction wrt BM
for (Q in 2:5) {
pred_bm <- BM_out$prediction(Q = Q)
mySBM$setModel(Q)
pred_sbm <- predict(mySBM)
expect_lt( rmse(pred_bm, pred_sbm), 1e-12)
}


})

test_that("BipartiteSBM_fit 'Poisson' model, undirected, no covariate", {
Expand Down Expand Up @@ -117,7 +126,7 @@ test_that("BipartiteSBM_fit 'Poisson' model, undirected, no covariate", {
expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam)

## Estimation-----------------------------------------------------------------
mySBM$optimize(estimOptions=list(verbosity = 0))
BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0))
mySBM$setModel(5)

## Expectation
Expand Down Expand Up @@ -145,6 +154,14 @@ test_that("BipartiteSBM_fit 'Poisson' model, undirected, no covariate", {
expect_equal(predict(mySBM, covarList), fitted(mySBM))
expect_error(predict(mySBM, covarList[1]))

## prediction wrt BM
for (Q in 2:5) {
pred_bm <- BM_out$prediction(Q = Q)
mySBM$setModel(Q)
pred_sbm <- predict(mySBM)
expect_lt( rmse(pred_bm, pred_sbm), 1e-12)
}

})

test_that("BipartiteSBM_fit 'Gaussian' model, undirected, no covariate", {
Expand Down Expand Up @@ -187,7 +204,7 @@ test_that("BipartiteSBM_fit 'Gaussian' model, undirected, no covariate", {
expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam)

## Estimation-----------------------------------------------------------------
mySBM$optimize(estimOptions=list(verbosity = 0))
BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0))
mySBM$setModel(5)

## Expectation
Expand All @@ -213,5 +230,13 @@ test_that("BipartiteSBM_fit 'Gaussian' model, undirected, no covariate", {
expect_lt(1 - aricode::ARI(mySBM$memberships[[1]], mySampler$memberships[[1]]), 2e-1)
expect_lt(1 - aricode::ARI(mySBM$memberships[[2]], mySampler$memberships[[2]]), 2e-1)

## prediction wrt BM
for (Q in 2:5) {
pred_bm <- BM_out$prediction(Q = Q)
mySBM$setModel(Q)
pred_sbm <- predict(mySBM)
expect_lt( rmse(pred_bm, pred_sbm), 1e-12)
}

})

60 changes: 54 additions & 6 deletions tests/testthat/test-SimpleSBM_fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ test_that("SimpleSBM_fit 'Bernoulli' model, undirected, no covariate", {
expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam)

## Estimation-----------------------------------------------------------------
mySBM$optimize(estimOptions=list(verbosity = 0))
BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0))
mySBM$setModel(3)

## Field set after optimization
Expand Down Expand Up @@ -83,6 +83,14 @@ test_that("SimpleSBM_fit 'Bernoulli' model, undirected, no covariate", {
expect_lt(rmse(mySBM$connectParam$mean, means), 0.25)
expect_lt(1 - aricode::ARI(mySBM$memberships, mySampler$memberships), 0.25)

## prediction wrt BM
for (Q in 1:4) {
pred_bm <- BM_out$prediction(Q = Q)
mySBM$setModel(Q)
pred_sbm <- predict(mySBM)
expect_lt( rmse(pred_bm, pred_sbm), 1e-12)
}

})

test_that("SimpleSBM_fit 'Bernoulli' model, directed, no covariate", {
Expand Down Expand Up @@ -130,7 +138,7 @@ test_that("SimpleSBM_fit 'Bernoulli' model, directed, no covariate", {
expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam)

## Estimation-----------------------------------------------------------------
mySBM$optimize(estimOptions=list(verbosity = 0))
BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0))
mySBM$setModel(3)

expect_equal(mySBM$nbConnectParam, nbBlocks * nbBlocks)
Expand Down Expand Up @@ -160,6 +168,14 @@ test_that("SimpleSBM_fit 'Bernoulli' model, directed, no covariate", {
expect_lt(rmse(sort(mySBM$connectParam$mean), means), 0.2)
expect_lt(1 - aricode::ARI(mySBM$memberships, mySampler$memberships), 0.2)

## prediction wrt BM
for (Q in 1:4) {
pred_bm <- BM_out$prediction(Q = Q)
mySBM$setModel(Q)
pred_sbm <- predict(mySBM)
expect_lt( rmse(pred_bm, pred_sbm), 1e-12)
}

})

test_that("SimpleSBM_fit 'Poisson' model, undirected, no covariate", {
Expand Down Expand Up @@ -207,7 +223,7 @@ test_that("SimpleSBM_fit 'Poisson' model, undirected, no covariate", {
expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam)

## Estimation-----------------------------------------------------------------
mySBM$optimize(estimOptions=list(verbosity = 0))
BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0))
mySBM$setModel(3)

## Expectation
Expand All @@ -232,6 +248,14 @@ test_that("SimpleSBM_fit 'Poisson' model, undirected, no covariate", {
expect_lt(rmse(mySBM$connectParam$mean, means), 1e-1)
expect_lt(1 - aricode::ARI(mySBM$memberships, mySampler$memberships), 1e-1)

## prediction wrt BM
for (Q in 1:4) {
pred_bm <- BM_out$prediction(Q = Q)
mySBM$setModel(Q)
pred_sbm <- predict(mySBM)
expect_lt( rmse(pred_bm, pred_sbm), 1e-12)
}

})

test_that("SimpleSBM_fit 'Poisson' model, directed, no covariate", {
Expand Down Expand Up @@ -278,7 +302,7 @@ test_that("SimpleSBM_fit 'Poisson' model, directed, no covariate", {
expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam)

## Estimation-----------------------------------------------------------------
mySBM$optimize(estimOptions=list(verbosity = 0))
BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0))
mySBM$setModel(3)

## Expectation
Expand All @@ -303,6 +327,14 @@ test_that("SimpleSBM_fit 'Poisson' model, directed, no covariate", {
expect_lt(rmse(sort(mySBM$connectParam$mean), means), 1e-1)
expect_lt(1 - aricode::ARI(mySBM$memberships, mySampler$memberships), 1e-1)

## prediction wrt BM
for (Q in 1:4) {
pred_bm <- BM_out$prediction(Q = Q)
mySBM$setModel(Q)
pred_sbm <- predict(mySBM)
expect_lt( rmse(pred_bm, pred_sbm), 1e-12)
}

})


Expand Down Expand Up @@ -350,7 +382,7 @@ test_that("SimpleSBM_fit 'Gaussian' model, undirected, no covariate", {
expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam)

## Estimation-----------------------------------------------------------------
mySBM$optimize(estimOptions=list(verbosity = 0))
BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0))
mySBM$setModel(3)

## Expectation
Expand All @@ -374,6 +406,14 @@ test_that("SimpleSBM_fit 'Gaussian' model, undirected, no covariate", {
expect_lt(rmse(mySBM$connectParam$mean, means), 1e-1)
expect_lt(1 - aricode::ARI(mySBM$memberships, mySampler$memberships), 1e-1)

## prediction wrt BM
for (Q in 1:4) {
pred_bm <- BM_out$prediction(Q = Q)
mySBM$setModel(Q)
pred_sbm <- predict(mySBM)
expect_lt( rmse(pred_bm, pred_sbm), 1e-12)
}

})

test_that("SimpleSBM_fit 'Gaussian' model, undirected, no covariate", {
Expand Down Expand Up @@ -421,7 +461,7 @@ test_that("SimpleSBM_fit 'Gaussian' model, undirected, no covariate", {
expect_equal(coef(mySBM, 'covariates') , mySBM$covarParam)

## Estimation-----------------------------------------------------------------
mySBM$optimize(estimOptions=list(verbosity = 0))
BM_out <- mySBM$optimize(estimOptions=list(verbosity = 0))
mySBM$setModel(3)

## Expectation
Expand All @@ -445,6 +485,14 @@ test_that("SimpleSBM_fit 'Gaussian' model, undirected, no covariate", {
expect_lt(rmse(sort(mySBM$connectParam$mean), means), 1e-1)
expect_lt(1 - aricode::ARI(mySBM$memberships, mySampler$memberships), 1e-1)

## prediction wrt BM
for (Q in 1:4) {
pred_bm <- BM_out$prediction(Q = Q)
mySBM$setModel(Q)
pred_sbm <- predict(mySBM)
expect_lt( rmse(pred_bm, pred_sbm), 1e-12)
}

})

test_that("active binding are working in the class", {
Expand Down
Loading

0 comments on commit 63572ad

Please sign in to comment.