diff --git a/epiworld.hpp b/epiworld.hpp index df306c86..e69c343c 100644 --- a/epiworld.hpp +++ b/epiworld.hpp @@ -12135,7 +12135,7 @@ class Entity { #define EPIWORLD_ENTITY_MEAT_HPP template <typename TSeq = EPI_DEFAULT_TSEQ> -inline EntityToAgentFun<TSeq> entity_to_unassigned_agents() +inline EntityToAgentFun<TSeq> distribute_entity_to_unassigned() { return [](Entity<TSeq> & e, Model<TSeq> * m) -> void { @@ -12187,7 +12187,7 @@ inline EntityToAgentFun<TSeq> entity_to_unassigned_agents() } template<typename TSeq = int> -inline EntityToAgentFun<TSeq> entity_to_agent_range( +inline EntityToAgentFun<TSeq> distribute_entity_to_range( int from, int to, bool to_unassigned = false diff --git a/include/epiworld/entity-meat.hpp b/include/epiworld/entity-meat.hpp index d0482059..7c6ac237 100644 --- a/include/epiworld/entity-meat.hpp +++ b/include/epiworld/entity-meat.hpp @@ -2,7 +2,7 @@ #define EPIWORLD_ENTITY_MEAT_HPP template <typename TSeq = EPI_DEFAULT_TSEQ> -inline EntityToAgentFun<TSeq> entity_to_unassigned_agents() +inline EntityToAgentFun<TSeq> distribute_entity_to_unassigned() { return [](Entity<TSeq> & e, Model<TSeq> * m) -> void { @@ -53,8 +53,8 @@ inline EntityToAgentFun<TSeq> entity_to_unassigned_agents() } -template<typename TSeq = int> -inline EntityToAgentFun<TSeq> entity_to_agent_range( +template<typename TSeq = EPI_DEFAULT_TSEQ> +inline EntityToAgentFun<TSeq> distribute_entity_to_range( int from, int to, bool to_unassigned = false @@ -99,6 +99,23 @@ inline EntityToAgentFun<TSeq> entity_to_agent_range( } } + +template<typename TSeq = EPI_DEFAULT_TSEQ> +inline EntityToAgentFun<TSeq> distribute_entity_to_set( + std::vector< size_t > & idx + ) { + + return [idx](Entity<TSeq> & e, Model<TSeq> * m) -> void { + + for (const auto & i: idx) + { + e.add_agent(&m->get_agent(i), m); + } + + }; + +} + template<typename TSeq> inline void Entity<TSeq>::add_agent( Agent<TSeq> & p, diff --git a/include/epiworld/models/seirmixing.hpp b/include/epiworld/models/seirmixing.hpp index 362d33cb..85fc2900 100644 --- a/include/epiworld/models/seirmixing.hpp +++ b/include/epiworld/models/seirmixing.hpp @@ -43,7 +43,8 @@ class ModelSEIRMixing : public epiworld::Model<TSeq> * @param transmission_rate The transmission rate of the disease in the model. * @param avg_incubation_days The average incubation period of the disease in the model. * @param recovery_rate The recovery rate of the disease in the model. - * @param contact_matrix The contact matrix between entities in the model. + * @param contact_matrix The contact matrix between entities in the model. Specified in + * column-major order. */ ModelSEIRMixing( ModelSEIRMixing<TSeq> & model, diff --git a/include/epiworld/virus-meat.hpp b/include/epiworld/virus-meat.hpp index 1270a6f0..f2931129 100644 --- a/include/epiworld/virus-meat.hpp +++ b/include/epiworld/virus-meat.hpp @@ -1,6 +1,28 @@ #ifndef EPIWORLD_VIRUS_MEAT_HPP #define EPIWORLD_VIRUS_MEAT_HPP +// Factory functions to distribute the virus +template<typename TSeq = EPI_DEFAULT_TSEQ> +inline VirusToAgentFun<TSeq> distribute_virus_to_set( + std::vector< size_t > agents_ids +) { + + return [agents_ids]( + Virus<TSeq> & virus, Model<TSeq> * model + ) -> void + { + // Adding action + for (auto i: agents_ids) + { + model->get_agent(i).set_virus( + virus, + const_cast<Model<TSeq> * >(model) + ); + } + }; + +} + /** * @brief Factory function of VirusFun base on logit * diff --git a/tests/05-mixing.cpp b/tests/05-mixing.cpp index 94ca006a..467c1a0f 100644 --- a/tests/05-mixing.cpp +++ b/tests/05-mixing.cpp @@ -177,7 +177,6 @@ EPIWORLD_TEST_CASE("SEIRMixing", "[SEIR-mixing]") { n3++; } - #ifndef CATCH_CONFIG_MAIN return 0; #endif diff --git a/tests/07-entitifuns.cpp b/tests/07-entitifuns.cpp index f6ef081f..06728dc9 100644 --- a/tests/07-entitifuns.cpp +++ b/tests/07-entitifuns.cpp @@ -19,7 +19,7 @@ EPIWORLD_TEST_CASE("Entity member", "[Entity]") { ); // Generating two entities, 100 distribution - auto dfun = epiworld::entity_to_unassigned_agents<>(); + auto dfun = epiworld::distribute_entity_to_unassigned<>(); Entity<> e1("Entity 1", 5000, false, dfun); Entity<> e2("Entity 2", 5000, false, dfun); @@ -110,6 +110,42 @@ EPIWORLD_TEST_CASE("Entity member", "[Entity]") { REQUIRE(std::fabs(n2 - std::pow(p, 2.0) * N) < 100); #endif + // Checking distribution via sets + std::vector< std::vector< size_t > > dist = { + model.get_entity(0).get_agents(), + model.get_entity(1).get_agents() + }; + + // Creating a copy of the model + epimodels::ModelSIRCONN<> model3( + "Flu", // std::string vname, + N, // epiworld_fast_uint n, + 0.01, // epiworld_double prevalence, + 10.0, // epiworld_double contact_rate, + 1.0, // epiworld_double transmission_rate, + 1.0/2.0// epiworld_double recovery_rate + ); + + // Updating the distribution + e1.set_dist_fun(distribute_entity_to_set<>(dist[0])); + e2.set_dist_fun(distribute_entity_to_set<>(dist[1])); + + model3.add_entity(e1); + model3.add_entity(e2); + + model3.run(50, 123); + + // Should match the results! + std::vector< std::vector< size_t > > dist2 = { + model3.get_entity(0).get_agents(), + model3.get_entity(1).get_agents() + }; + + #ifdef CATCH_CONFIG_MAIN + REQUIRE(dist[0] == dist2[0]); + REQUIRE(dist[1] == dist2[1]); + #endif + #ifndef CATCH_CONFIG_MAIN return 0; #endif diff --git a/tests/08-mixing-entities.cpp b/tests/08-mixing-entities.cpp new file mode 100644 index 00000000..2f819a1f --- /dev/null +++ b/tests/08-mixing-entities.cpp @@ -0,0 +1,56 @@ +#include "tests.hpp" + +using namespace epiworld; + +int main() +{ + + std::vector< double > contact_matrix = { + 0.9, 0.1, 0.1, + 0.05, 0.8, 0.2, + 0.05, 0.1, 0.7 + }; + + epimodels::ModelSEIRMixing<> model( + "Flu", // std::string vname, + 9000, // epiworld_fast_uint n, + 1.0/9000.0, // epiworld_double prevalence, + 20.0, // epiworld_double contact_rate, + 0.1, // epiworld_double transmission_rate, + 7.0, // epiworld_double avg_incubation_days, + 1.0/7.0,// epiworld_double recovery_rate, + contact_matrix + ); + + // Creating three groups + Entity<> e1("Entity 1", 3000, false); + Entity<> e2("Entity 2", 3000, false); + Entity<> e3("Entity 3", 3000, false); + + model.add_entity(e1); + model.add_entity(e2); + model.add_entity(e3); + + // Running and checking the results + model.run(100, 13); + model.print(); + + // Checking entity assignment + auto agents1 = model.get_entity(0).get_agents(); + auto agents2 = model.get_entity(1).get_agents(); + auto agents3 = model.get_entity(2).get_agents(); + + // How many agents have 0, 1, 2, or 3 entities? + std::vector< int > nentities(4, 0); + for (const auto & a: agents1) + nentities[model.get_agent(a).get_n_entities()]++; + + for (const auto & a: agents2) + nentities[model.get_agent(a).get_n_entities()]++; + + for (const auto & a: agents3) + nentities[model.get_agent(a).get_n_entities()]++; + + return 0; + +} \ No newline at end of file diff --git a/tests/main.cpp b/tests/main.cpp index cbfdd968..1fe02680 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -18,3 +18,4 @@ #include "05-mixing.cpp" #include "06-mixing.cpp" #include "07-entitifuns.cpp" +#include "09-distribute-tools-and-viruses.cpp"