Skip to content

Commit

Permalink
Virus dist params detach from class
Browse files Browse the repository at this point in the history
  • Loading branch information
gvegayon committed Jun 12, 2024
1 parent cb65638 commit 70ba8e4
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 285 deletions.
184 changes: 42 additions & 142 deletions epiworld.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8834,26 +8834,10 @@ inline const Model<TSeq> & Model<TSeq>::print(bool lite) const
if (i < n_viruses_model)
{

if (virus->get_prevalence_as_proportion())
{

printf_epiworld(
" - %s (baseline prevalence: %.2f%%)\n",
virus->get_name().c_str(),
virus->get_prevalence() * 100.00
);

}
else
{

printf_epiworld(
" - %s (baseline prevalence: %i seeds)\n",
virus->get_name().c_str(),
static_cast<int>(virus->get_prevalence())
);

}
printf_epiworld(
" - %s\n",
virus->get_name().c_str()
);

} else {

Expand Down Expand Up @@ -10020,18 +10004,20 @@ class Virus {
epiworld_fast_int queue_removed = -99; ///< Change of state when agent is removed

// Information about how distribution works
epiworld_double prevalence = 0.0;
bool prevalence_as_proportion = false;
VirusToAgentFun<TSeq> dist_fun = nullptr;

public:
Virus(
std::string name = "unknown virus",
epiworld_double prevalence = 0.0,
bool prevalence_as_proportion = false,
VirusToAgentFun<TSeq> dist_fun = nullptr
);

Virus(
std::string name = "unknown virus",
epiworld_double prevalence = 0.0,
bool as_proportion = true
);

void mutate(Model<TSeq> * model);
void set_mutation(MutFun<TSeq> fun);

Expand Down Expand Up @@ -10135,9 +10121,6 @@ class Virus {
* @brief Get information about the prevalence of the virus
*/
///@{
epiworld_double get_prevalence() const;
void set_prevalence(epiworld_double prevalence, bool as_proportion);
bool get_prevalence_as_proportion() const;
void distribute(Model<TSeq> * model);
void set_dist_fun(VirusToAgentFun<TSeq> fun);
///@}
Expand Down Expand Up @@ -10383,16 +10366,27 @@ inline VirusFun<TSeq> virus_fun_logit(
template<typename TSeq>
inline Virus<TSeq>::Virus(
std::string name,
epiworld_double prevalence,
bool prevalence_as_proportion,
VirusToAgentFun<TSeq> dist_fun
) {
set_name(name);

set_prevalence(prevalence, prevalence_as_proportion);
set_dist_fun(dist_fun);
}

template<typename TSeq>
inline Virus<TSeq>::Virus(
std::string name,
epiworld_double prevalence,
bool prevalence_as_proportion
) {
set_name(name);
set_dist_fun(
distribute_virus_randomly<TSeq>(
prevalence,
prevalence_as_proportion
)
);
}

template<typename TSeq>
inline void Virus<TSeq>::mutate(
Model<TSeq> * model
Expand Down Expand Up @@ -10995,38 +10989,6 @@ inline void Virus<TSeq>::print() const

}

template<typename TSeq>
inline epiworld_double Virus<TSeq>::get_prevalence() const
{
return prevalence;
}

template<typename TSeq>
inline bool Virus<TSeq>::get_prevalence_as_proportion() const
{
return prevalence_as_proportion;
}

template<typename TSeq>
inline void Virus<TSeq>::set_prevalence(
epiworld_double preval,
bool as_proportion
)
{

if (as_proportion) {

if ((preval < 0.0) || (preval > 1.0))
throw std::range_error(
"The prevalence should be between 0 and 1. " +
std::string("Got ") + std::to_string(preval)
);
}

prevalence = preval;
prevalence_as_proportion = as_proportion;
}

template<typename TSeq>
inline void Virus<TSeq>::distribute(Model<TSeq> * model)
{
Expand All @@ -11036,65 +10998,6 @@ inline void Virus<TSeq>::distribute(Model<TSeq> * model)

dist_fun(*this, model);

} else {

// Figuring out how what agents are available
std::vector< size_t > idx;
for (const auto & agent: model->get_agents())
if (agent.get_virus() == nullptr)
idx.push_back(agent.get_id());

// Picking how many
size_t n = model->size();
int n_available = static_cast<int>(idx.size());
int n_to_sample;
if (prevalence_as_proportion)
{
n_to_sample = static_cast<int>(std::floor(prevalence * n));

if (n_to_sample == static_cast<int>(n))
n_to_sample--;
}
else
{
n_to_sample = static_cast<int>(prevalence);
}

if (n_to_sample > n_available)
throw std::range_error(
"There are only " + std::to_string(n_available) +
" individuals with no virus in the population. " +
"Cannot add the virus to " +
std::to_string(n_to_sample)
);

auto & population = model->get_agents();
for (int i = 0; i < n_to_sample; ++i)
{

int loc = static_cast<epiworld_fast_uint>(
floor(model->runif() * (n_available--))
);

// Correcting for possible overflow
if ((loc > 0) && (loc >= n_available))
loc = n_available - 1;

Agent<TSeq> & agent = population[idx[loc]];

// Adding action
agent.set_virus(
*this,
const_cast<Model<TSeq> * >(model),
this->state_init,
this->queue_init
);

// Adjusting sample
std::swap(idx[loc], idx[n_available]);

}

}

}
Expand Down Expand Up @@ -15821,13 +15724,12 @@ inline std::function<void(epiworld::Model<TSeq>*)> create_init_function_sir(
// Figuring out information about the viruses
double tot = 0.0;
double n = static_cast<double>(model->size());
for (const auto & virus: model->get_viruses())
for (const auto & agent: model->get_agents())
{
if (virus->get_prevalence_as_proportion())
tot += virus->get_prevalence();
else
tot += virus->get_prevalence() / n;
if (agent.get_virus() != nullptr)
tot += 1.0;
}
tot /= n;

// Putting the total into context
double tot_left = 1.0 - tot;
Expand Down Expand Up @@ -15895,13 +15797,12 @@ inline std::function<void(epiworld::Model<TSeq>*)> create_init_function_sird(
// Figuring out information about the viruses
double tot = 0.0;
double n = static_cast<double>(model->size());
for (const auto & virus: model->get_viruses())
for (const auto & agent: model->get_agents())
{
if (virus->get_prevalence_as_proportion())
tot += virus->get_prevalence();
else
tot += virus->get_prevalence() / n;
if (agent.get_virus() != nullptr)
tot += 1.0;
}
tot /= n;

// Putting the total into context
double tot_left = 1.0 - tot;
Expand Down Expand Up @@ -15973,13 +15874,12 @@ inline std::function<void(epiworld::Model<TSeq>*)> create_init_function_seir(
// Figuring out information about the viruses
double tot = 0.0;
double n = static_cast<double>(model->size());
for (const auto & virus: model->get_viruses())
for (const auto & agent: model->get_agents())
{
if (virus->get_prevalence_as_proportion())
tot += virus->get_prevalence();
else
tot += virus->get_prevalence() / n;
if (agent.get_virus() != nullptr)
tot += 1.0;
}
tot /= n;

// Putting the total into context
double tot_left = 1.0 - tot;
Expand Down Expand Up @@ -16055,13 +15955,13 @@ inline std::function<void(epiworld::Model<TSeq>*)> create_init_function_seird(
// Figuring out information about the viruses
double tot = 0.0;
double n = static_cast<double>(model->size());
for (const auto & virus: model->get_viruses())

for (const auto & agent: model->get_agents())
{
if (virus->get_prevalence_as_proportion())
tot += virus->get_prevalence();
else
tot += virus->get_prevalence() / n;
if (agent.get_virus() != nullptr)
tot += 1.0;
}
tot /= n;

// Putting the total into context
double tot_left = 1.0 - tot;
Expand Down
6 changes: 5 additions & 1 deletion examples/00-hello-world/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ int main()
model.add_state("Removed");

// Adding the tool and virus
epiworld::Virus<int> virus("covid 19", 50, false);
epiworld::Virus<int> virus(
"covid 19",
distribute_virus_randomly<int>(50, false)
);

virus.set_post_immunity(1.0);
virus.set_state(1,2,3);
virus.set_prob_death(.01);
Expand Down
24 changes: 4 additions & 20 deletions include/epiworld/model-meat-print.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,26 +157,10 @@ inline const Model<TSeq> & Model<TSeq>::print(bool lite) const
if (i < n_viruses_model)
{

if (virus->get_prevalence_as_proportion())
{

printf_epiworld(
" - %s (baseline prevalence: %.2f%%)\n",
virus->get_name().c_str(),
virus->get_prevalence() * 100.00
);

}
else
{

printf_epiworld(
" - %s (baseline prevalence: %i seeds)\n",
virus->get_name().c_str(),
static_cast<int>(virus->get_prevalence())
);

}
printf_epiworld(
" - %s\n",
virus->get_name().c_str()
);

} else {

Expand Down
37 changes: 17 additions & 20 deletions include/epiworld/models/init-functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,12 @@ inline std::function<void(epiworld::Model<TSeq>*)> create_init_function_sir(
// Figuring out information about the viruses
double tot = 0.0;
double n = static_cast<double>(model->size());
for (const auto & virus: model->get_viruses())
for (const auto & agent: model->get_agents())
{
if (virus->get_prevalence_as_proportion())
tot += virus->get_prevalence();
else
tot += virus->get_prevalence() / n;
if (agent.get_virus() != nullptr)
tot += 1.0;
}
tot /= n;

// Putting the total into context
double tot_left = 1.0 - tot;
Expand Down Expand Up @@ -104,13 +103,12 @@ inline std::function<void(epiworld::Model<TSeq>*)> create_init_function_sird(
// Figuring out information about the viruses
double tot = 0.0;
double n = static_cast<double>(model->size());
for (const auto & virus: model->get_viruses())
for (const auto & agent: model->get_agents())
{
if (virus->get_prevalence_as_proportion())
tot += virus->get_prevalence();
else
tot += virus->get_prevalence() / n;
if (agent.get_virus() != nullptr)
tot += 1.0;
}
tot /= n;

// Putting the total into context
double tot_left = 1.0 - tot;
Expand Down Expand Up @@ -182,13 +180,12 @@ inline std::function<void(epiworld::Model<TSeq>*)> create_init_function_seir(
// Figuring out information about the viruses
double tot = 0.0;
double n = static_cast<double>(model->size());
for (const auto & virus: model->get_viruses())
for (const auto & agent: model->get_agents())
{
if (virus->get_prevalence_as_proportion())
tot += virus->get_prevalence();
else
tot += virus->get_prevalence() / n;
if (agent.get_virus() != nullptr)
tot += 1.0;
}
tot /= n;

// Putting the total into context
double tot_left = 1.0 - tot;
Expand Down Expand Up @@ -264,13 +261,13 @@ inline std::function<void(epiworld::Model<TSeq>*)> create_init_function_seird(
// Figuring out information about the viruses
double tot = 0.0;
double n = static_cast<double>(model->size());
for (const auto & virus: model->get_viruses())

for (const auto & agent: model->get_agents())
{
if (virus->get_prevalence_as_proportion())
tot += virus->get_prevalence();
else
tot += virus->get_prevalence() / n;
if (agent.get_virus() != nullptr)
tot += 1.0;
}
tot /= n;

// Putting the total into context
double tot_left = 1.0 - tot;
Expand Down
Loading

0 comments on commit 70ba8e4

Please sign in to comment.