Skip to content

Commit

Permalink
Adding dist funs factories + testing
Browse files Browse the repository at this point in the history
  • Loading branch information
gvegayon committed Jun 11, 2024
1 parent 2050a55 commit 6e4105d
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 8 deletions.
4 changes: 2 additions & 2 deletions epiworld.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12135,7 +12135,7 @@ class Entity {
#define EPIWORLD_ENTITY_MEAT_HPP

template <typename TSeq = EPI_DEFAULT_TSEQ>
inline EntityToAgentFun<TSeq> entity_to_unassigned_agents()
inline EntityToAgentFun<TSeq> distribute_entity_to_unassigned()
{

return [](Entity<TSeq> & e, Model<TSeq> * m) -> void {
Expand Down Expand Up @@ -12187,7 +12187,7 @@ inline EntityToAgentFun<TSeq> entity_to_unassigned_agents()
}

template<typename TSeq = int>
inline EntityToAgentFun<TSeq> entity_to_agent_range(
inline EntityToAgentFun<TSeq> distribute_entity_to_range(
int from,
int to,
bool to_unassigned = false
Expand Down
23 changes: 20 additions & 3 deletions include/epiworld/entity-meat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#define EPIWORLD_ENTITY_MEAT_HPP

template <typename TSeq = EPI_DEFAULT_TSEQ>
inline EntityToAgentFun<TSeq> entity_to_unassigned_agents()
inline EntityToAgentFun<TSeq> distribute_entity_to_unassigned()
{

return [](Entity<TSeq> & e, Model<TSeq> * m) -> void {
Expand Down Expand Up @@ -53,8 +53,8 @@ inline EntityToAgentFun<TSeq> entity_to_unassigned_agents()

}

template<typename TSeq = int>
inline EntityToAgentFun<TSeq> entity_to_agent_range(
template<typename TSeq = EPI_DEFAULT_TSEQ>
inline EntityToAgentFun<TSeq> distribute_entity_to_range(
int from,
int to,
bool to_unassigned = false
Expand Down Expand Up @@ -99,6 +99,23 @@ inline EntityToAgentFun<TSeq> entity_to_agent_range(
}
}


template<typename TSeq = EPI_DEFAULT_TSEQ>
inline EntityToAgentFun<TSeq> distribute_entity_to_set(
std::vector< size_t > & idx
) {

return [idx](Entity<TSeq> & e, Model<TSeq> * m) -> void {

for (const auto & i: idx)
{
e.add_agent(&m->get_agent(i), m);
}

};

}

template<typename TSeq>
inline void Entity<TSeq>::add_agent(
Agent<TSeq> & p,
Expand Down
3 changes: 2 additions & 1 deletion include/epiworld/models/seirmixing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ class ModelSEIRMixing : public epiworld::Model<TSeq>
* @param transmission_rate The transmission rate of the disease in the model.
* @param avg_incubation_days The average incubation period of the disease in the model.
* @param recovery_rate The recovery rate of the disease in the model.
* @param contact_matrix The contact matrix between entities in the model.
* @param contact_matrix The contact matrix between entities in the model. Specified in
* column-major order.
*/
ModelSEIRMixing(
ModelSEIRMixing<TSeq> & model,
Expand Down
22 changes: 22 additions & 0 deletions include/epiworld/virus-meat.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,28 @@
#ifndef EPIWORLD_VIRUS_MEAT_HPP
#define EPIWORLD_VIRUS_MEAT_HPP

// Factory functions to distribute the virus
template<typename TSeq = EPI_DEFAULT_TSEQ>
inline VirusToAgentFun<TSeq> distribute_virus_to_set(
std::vector< size_t > agents_ids
) {

return [agents_ids](
Virus<TSeq> & virus, Model<TSeq> * model
) -> void
{
// Adding action
for (auto i: agents_ids)
{
model->get_agent(i).set_virus(
virus,
const_cast<Model<TSeq> * >(model)
);
}
};

}

/**
* @brief Factory function of VirusFun base on logit
*
Expand Down
1 change: 0 additions & 1 deletion tests/05-mixing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ EPIWORLD_TEST_CASE("SEIRMixing", "[SEIR-mixing]") {
n3++;
}


#ifndef CATCH_CONFIG_MAIN
return 0;
#endif
Expand Down
38 changes: 37 additions & 1 deletion tests/07-entitifuns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ EPIWORLD_TEST_CASE("Entity member", "[Entity]") {
);

// Generating two entities, 100 distribution
auto dfun = epiworld::entity_to_unassigned_agents<>();
auto dfun = epiworld::distribute_entity_to_unassigned<>();
Entity<> e1("Entity 1", 5000, false, dfun);
Entity<> e2("Entity 2", 5000, false, dfun);

Expand Down Expand Up @@ -110,6 +110,42 @@ EPIWORLD_TEST_CASE("Entity member", "[Entity]") {
REQUIRE(std::fabs(n2 - std::pow(p, 2.0) * N) < 100);
#endif

// Checking distribution via sets
std::vector< std::vector< size_t > > dist = {
model.get_entity(0).get_agents(),
model.get_entity(1).get_agents()
};

// Creating a copy of the model
epimodels::ModelSIRCONN<> model3(
"Flu", // std::string vname,
N, // epiworld_fast_uint n,
0.01, // epiworld_double prevalence,
10.0, // epiworld_double contact_rate,
1.0, // epiworld_double transmission_rate,
1.0/2.0// epiworld_double recovery_rate
);

// Updating the distribution
e1.set_dist_fun(distribute_entity_to_set<>(dist[0]));
e2.set_dist_fun(distribute_entity_to_set<>(dist[1]));

model3.add_entity(e1);
model3.add_entity(e2);

model3.run(50, 123);

// Should match the results!
std::vector< std::vector< size_t > > dist2 = {
model3.get_entity(0).get_agents(),
model3.get_entity(1).get_agents()
};

#ifdef CATCH_CONFIG_MAIN
REQUIRE(dist[0] == dist2[0]);
REQUIRE(dist[1] == dist2[1]);
#endif

#ifndef CATCH_CONFIG_MAIN
return 0;
#endif
Expand Down
56 changes: 56 additions & 0 deletions tests/08-mixing-entities.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#include "tests.hpp"

using namespace epiworld;

int main()
{

std::vector< double > contact_matrix = {
0.9, 0.1, 0.1,
0.05, 0.8, 0.2,
0.05, 0.1, 0.7
};

epimodels::ModelSEIRMixing<> model(
"Flu", // std::string vname,
9000, // epiworld_fast_uint n,
1.0/9000.0, // epiworld_double prevalence,
20.0, // epiworld_double contact_rate,
0.1, // epiworld_double transmission_rate,
7.0, // epiworld_double avg_incubation_days,
1.0/7.0,// epiworld_double recovery_rate,
contact_matrix
);

// Creating three groups
Entity<> e1("Entity 1", 3000, false);
Entity<> e2("Entity 2", 3000, false);
Entity<> e3("Entity 3", 3000, false);

model.add_entity(e1);
model.add_entity(e2);
model.add_entity(e3);

// Running and checking the results
model.run(100, 13);
model.print();

// Checking entity assignment
auto agents1 = model.get_entity(0).get_agents();
auto agents2 = model.get_entity(1).get_agents();
auto agents3 = model.get_entity(2).get_agents();

// How many agents have 0, 1, 2, or 3 entities?
std::vector< int > nentities(4, 0);
for (const auto & a: agents1)
nentities[model.get_agent(a).get_n_entities()]++;

for (const auto & a: agents2)
nentities[model.get_agent(a).get_n_entities()]++;

for (const auto & a: agents3)
nentities[model.get_agent(a).get_n_entities()]++;

return 0;

}
1 change: 1 addition & 0 deletions tests/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
#include "05-mixing.cpp"
#include "06-mixing.cpp"
#include "07-entitifuns.cpp"
#include "09-distribute-tools-and-viruses.cpp"

0 comments on commit 6e4105d

Please sign in to comment.