Skip to content

Commit

Permalink
Try to optimize competing hazards
Browse files Browse the repository at this point in the history
  • Loading branch information
plietar committed Jul 10, 2024
1 parent e2242b2 commit 7a118d1
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 30 deletions.
57 changes: 40 additions & 17 deletions R/competing_hazards.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,62 +13,85 @@ CompetingOutcome <- R6::R6Class(
stop("size must be positive integer")
}
private$targeted_process <- targeted_process
self$rates <- rep(0, size)

self$target <- individual::Bitset$new(size)
self$rates <- NULL
},
set_rates = function(rates){
set_rates = function(target, rates){
stopifnot(target$size() == length(rates))

# TODO: add an assign method to Bitset
self$target$or(target)
self$rates <- rates
},
execute = function(t, target){
private$targeted_process(t, target)
self$rates <- rep(0, length(self$rates))
},
reset = function() {
self$target$clear()
self$rates <- NULL
},
target = NULL,
rates = NULL
)
)

CompetingHazard <- R6::R6Class(
"CompetingHazard",
private = list(
outcomes = list(),
size = NULL,
outcomes = list(),
# RNG is passed in because mockery is not able to stub runif
# TODO: change when fixed
rng = NULL
),
public = list(
initialize = function(outcomes, rng = runif){
initialize = function(size, outcomes, rng = runif){
if (length(outcomes) == 0){
stop("At least one outcome must be provided")
}
if (!all(sapply(outcomes, function(x) inherits(x, "CompetingOutcome")))){
stop("All outcomes must be of class CompetingOutcome")
}
private$size <- size
private$outcomes <- outcomes
private$size <- length(outcomes[[1]]$rates)
private$rng <- rng
},
resolve = function(t){
event_rates <- do.call(
'cbind',
lapply(private$outcomes, function(x) x$rates)
)
candidates <- individual::Bitset$new(private$size)
for (o in private$outcomes) {
candidates$or(o$target)
}
targets.vector <- candidates$to_vector()

rates <- matrix(ncol = length(private$outcomes), nrow = candidates$size(), 0)
for (i in seq_along(private$outcomes)) {
idx <- match(
private$outcomes[[i]]$target$to_vector(),
targets.vector)

total_rates <- rowSums(event_rates)
probs <- rate_to_prob(total_rates) * (event_rates / total_rates)
rates[idx, i] <- private$outcomes[[i]]$rates
}

total_rates <- rowSums(rates)
probs <- rate_to_prob(total_rates) * (rates / total_rates)
probs[is.na(probs)] <- 0

rng <- private$rng(private$size)
rng <- private$rng(candidates$size())

cumulative <- rep(0, candidates$size())

cumulative <- rep(0, private$size)
for (o in seq_along(private$outcomes)) {
next_cumulative <- cumulative + probs[,o]
selected <- which((rng > cumulative) & (rng <= next_cumulative))
selected <- (rng > cumulative) & (rng <= next_cumulative)
cumulative <- next_cumulative

target <- individual::Bitset$new(private$size)$insert(selected)
if (target$size() > 0){
# TODO: change bitset_at to accept logical array
target <- bitset_at(candidates, which(selected))
if (target$size() > 0) {
private$outcomes[[o]]$execute(t, target)
}
private$outcomes[[o]]$reset()
}
}
)
Expand Down
5 changes: 4 additions & 1 deletion R/disease_progression.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ create_recovery_rates_process <- function(
recovery_outcome
) {
function(timestep){
recovery_outcome$set_rates(variables$recovery_rates$get_values())
target <- variables$state$get_index_of(c("U", "Tr"))
recovery_outcome$set_rates(
target,
variables$recovery_rates$get_values(target))
}
}

Expand Down
6 changes: 3 additions & 3 deletions R/human_infection.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ calculate_infections <- function(
)

## capture infection rates to resolve in competing hazards
infection_rates <- rep(0, length = parameters$human_population)
infection_rates[source_vector] <- prob_to_rate(prob)
infection_outcome$set_rates(infection_rates)
infection_outcome$set_rates(
source_humans,
prob_to_rate(prob))
}

#' @title Assigns infections to appropriate human states
Expand Down
3 changes: 2 additions & 1 deletion R/processes.R
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ create_processes <- function(

# Resolve competing hazards of infection with disease progression
CompetingHazard$new(
outcomes = list(infection_outcome, recovery_outcome)
outcomes = list(infection_outcome, recovery_outcome),
size = parameters$human_population
)$resolve
)

Expand Down
21 changes: 15 additions & 6 deletions tests/testthat/test-competing-hazards.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

test_that("hazard resolves two disjoint outcomes", {
size <- 4
population <- individual::Bitset$new(size)$not()

outcome_1_process <- mockery::mock()
outcome_1 <- CompetingOutcome$new(
targeted_process = outcome_1_process,
Expand All @@ -17,12 +19,13 @@ test_that("hazard resolves two disjoint outcomes", {
)

hazard <- CompetingHazard$new(
size = size,
outcomes = list(outcome_1, outcome_2),
rng = mockery::mock(c(.05, .3, .2, .5))
)

outcome_1$set_rates(c(10, 0, 10, 0))
outcome_2$set_rates(c(0, 10, 0, 10))
outcome_1$set_rates(population, c(10, 0, 10, 0))
outcome_2$set_rates(population, c(0, 10, 0, 10))

hazard$resolve(0)

Expand All @@ -42,6 +45,8 @@ test_that("hazard resolves two disjoint outcomes", {

test_that("hazard resolves two competing outcomes", {
size <- 4
population <- individual::Bitset$new(size)$not()

outcome_1_process <- mockery::mock()
outcome_1 <- CompetingOutcome$new(
targeted_process = outcome_1_process,
Expand All @@ -54,12 +59,13 @@ test_that("hazard resolves two competing outcomes", {
)

hazard <- CompetingHazard$new(
size = size,
outcomes = list(outcome_1, outcome_2),
rng = mockery::mock(c(.7, .3, .2, .6))
)

outcome_1$set_rates(c(5, 5, 5, 5))
outcome_2$set_rates(c(5, 5, 5, 5))
outcome_1$set_rates(population, c(5, 5, 5, 5))
outcome_2$set_rates(population, c(5, 5, 5, 5))

hazard$resolve(0)

Expand All @@ -79,6 +85,8 @@ test_that("hazard resolves two competing outcomes", {

test_that("hazard resolves partial outcomes", {
size <- 4
population <- individual::Bitset$new(size)$not()

outcome_1_process <- mockery::mock()
outcome_1 <- CompetingOutcome$new(
targeted_process = outcome_1_process,
Expand All @@ -91,12 +99,13 @@ test_that("hazard resolves partial outcomes", {
)

hazard <- CompetingHazard$new(
size = size,
outcomes = list(outcome_1, outcome_2),
rng = mockery::mock(c(.8, .4, .2, .6))
)

outcome_1$set_rates(prob_to_rate(rep(0.5, size)))
outcome_2$set_rates(prob_to_rate(rep(0.5, size)))
outcome_1$set_rates(population, prob_to_rate(rep(0.5, size)))
outcome_2$set_rates(population, prob_to_rate(rep(0.5, size)))

hazard$resolve(0)

Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-infection-integration.R
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ test_that('prophylaxis is considered for medicated humans', {
targeted_process = function(timestep, target){
infection_outcome_process(timestep, target, variables, renderer, parameters)
},
size = parameters$human_population
size = 4
)

infection_rates <- calculate_infections(
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-pev.R
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ test_that('Infection considers pev efficacy', {
targeted_process = function(timestep, target){
infection_process_resolved_hazard(timestep, target, variables, renderer, parameters)
},
size = parameters$human_population
size = 4
)

# remove randomness from infection sampling
Expand Down

0 comments on commit 7a118d1

Please sign in to comment.