From d9a22c98e2ac02b2c84648628d743664e4ec572b Mon Sep 17 00:00:00 2001 From: "George G. Vega Yon" Date: Wed, 12 Jun 2024 07:40:08 -0600 Subject: [PATCH] Adding dist functions for tools (plus test) --- epiworld.hpp | 264 +++++++++++++--------- include/epiworld/entity-bones.hpp | 2 +- include/epiworld/entity-meat.hpp | 2 +- include/epiworld/epiworld.hpp | 3 +- include/epiworld/model-meat-print.hpp | 23 +- include/epiworld/tool-bones.hpp | 18 +- include/epiworld/tool-meat.hpp | 81 +------ include/epiworld/virus-bones.hpp | 7 +- include/epiworld/virus-meat.hpp | 8 +- tests/05-mixing.cpp | 8 +- tests/06-mixing.cpp | 2 +- tests/07-entitifuns.cpp | 4 +- tests/09-distribute-tools-and-viruses.cpp | 44 +++- 13 files changed, 232 insertions(+), 234 deletions(-) diff --git a/epiworld.hpp b/epiworld.hpp index c42b1901..830fd972 100644 --- a/epiworld.hpp +++ b/epiworld.hpp @@ -19,7 +19,7 @@ /* Versioning */ #define EPIWORLD_VERSION_MAJOR 0 #define EPIWORLD_VERSION_MINOR 3 -#define EPIWORLD_VERSION_PATCH 1 +#define EPIWORLD_VERSION_PATCH 2 static const int epiworld_version_major = EPIWORLD_VERSION_MAJOR; static const int epiworld_version_minor = EPIWORLD_VERSION_MINOR; @@ -8880,26 +8880,11 @@ inline const Model & Model::print(bool lite) const if (i < n_tools_model) { - if (tool->get_prevalence_as_proportion()) - { - - printf_epiworld( - " - %s (baseline prevalence: %.2f%%)\n", - tool->get_name().c_str(), - tool->get_prevalence() * 100.0 - ); - - } - else - { - - printf_epiworld( - " - %s (baseline prevalence: %i seeds)\n", - tool->get_name().c_str(), - static_cast(tool->get_prevalence()) - ); + printf_epiworld( + " - %s\n", + tool->get_name().c_str() + ); - } } else { @@ -10007,10 +9992,7 @@ class Virus { VirusToAgentFun dist_fun = nullptr; public: - Virus( - std::string name = "unknown virus", - VirusToAgentFun dist_fun = nullptr - ); + Virus(std::string name = "unknown virus"); Virus( std::string name = "unknown virus", @@ -10122,7 +10104,7 @@ class Virus { */ ///@{ void distribute(Model * model); - void set_dist_fun(VirusToAgentFun fun); + void set_distribution(VirusToAgentFun fun); ///@} @@ -10365,11 +10347,9 @@ inline VirusFun virus_fun_logit( template inline Virus::Virus( - std::string name, - VirusToAgentFun dist_fun + std::string name ) { set_name(name); - set_dist_fun(dist_fun); } template @@ -10379,7 +10359,7 @@ inline Virus::Virus( bool prevalence_as_proportion ) { set_name(name); - set_dist_fun( + set_distribution( distribute_virus_randomly( prevalence, prevalence_as_proportion @@ -11003,7 +10983,7 @@ inline void Virus::distribute(Model * model) } template -inline void Virus::set_dist_fun(VirusToAgentFun fun) +inline void Virus::set_distribution(VirusToAgentFun fun) { dist_fun = fun; } @@ -11309,17 +11289,15 @@ class Tool { void set_agent(Agent * p, size_t idx); - epiworld_double prevalence = 0.0; - bool prevalence_as_proportion = false; ToolToAgentFun dist_fun = nullptr; public: + Tool(std::string name = "unknown tool"); Tool( - std::string name = "unknown tool", - epiworld_double prevalence = 0.0, - bool prevalence_as_proportion = false, - ToolToAgentFun dist_fun = nullptr - ); + std::string name, + epiworld_double prevalence, + bool as_proportion + ); void set_sequence(TSeq d); void set_sequence(std::shared_ptr d); @@ -11375,11 +11353,7 @@ class Tool { void print() const; void distribute(Model * model); - - void set_prevalence(epiworld_double p, bool as_proportion = false); - epiworld_double get_prevalence() const; - bool get_prevalence_as_proportion() const; - void set_dist_fun(ToolToAgentFun fun); + void set_distribution(ToolToAgentFun fun); }; @@ -11393,6 +11367,127 @@ class Tool { //////////////////////////////////////////////////////////////////////////////*/ +/*////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + + Start of -include/epiworld/tool-distribute-meat.hpp- + +//////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////*/ + + +#ifndef TOOL_DISTRIBUTE_MEAT_HPP +#define TOOL_DISTRIBUTE_MEAT_HPP + +/** + * @brief Distributes a tool to a set of agents. + * + * This function takes a vector of agent IDs and returns a lambda function that + * distributes a tool to each agent in the set. + * + * The lambda function takes a reference to a Tool object and a pointer to a + * Model object as parameters. It iterates over the agent IDs and adds the tool + * to each agent using the add_tool() method of the Model object. + * + * @tparam TSeq The sequence type used in the Tool and Model objects. + * @param agents_ids A vector of agent IDs representing the set of agents to + * distribute the tool to. + * @return A lambda function that distributes the tool to the set of agents. + */ +template +inline ToolToAgentFun distribute_tool_to_set( + std::vector< size_t > agents_ids +) { + + return [agents_ids]( + Tool & tool, Model * model + ) -> void + { + // Adding action + for (auto i: agents_ids) + { + model->get_agent(i).add_tool( + tool, + const_cast * >(model) + ); + } + }; + +} + +/** + * Function template to distribute a tool randomly to agents in a model. + * + * @tparam TSeq The sequence type used in the model. + * @param prevalence The prevalence of the tool in the population. + * @param as_proportion Flag indicating whether the prevalence is given as a + * proportion or an absolute value. + * @return A lambda function that distributes the tool randomly to agents in + * the model. + */ +template +inline ToolToAgentFun distribute_tool_randomly( + epiworld_double prevalence, + bool as_proportion = true +) { + + return [prevalence, as_proportion]( + Tool & tool, Model * model + ) -> void { + + // Picking how many + int n_to_distribute; + int n = model->size(); + if (as_proportion) + { + n_to_distribute = static_cast(std::floor(prevalence * n)); + + if (n_to_distribute == n) + n_to_distribute--; + } + else + { + n_to_distribute = static_cast(prevalence); + } + + if (n_to_distribute > n) + throw std::range_error("There are only " + std::to_string(n) + + " individuals in the population. Cannot add the tool to " + std::to_string(n_to_distribute)); + + std::vector< int > idx(n); + std::iota(idx.begin(), idx.end(), 0); + auto & population = model->get_agents(); + for (int i = 0u; i < n_to_distribute; ++i) + { + int loc = static_cast( + floor(model->runif() * n--) + ); + + if ((loc > 0) && (loc == n)) + loc--; + + population[idx[loc]].add_tool( + tool, + const_cast< Model * >(model) + ); + + std::swap(idx[loc], idx[n]); + + } + + }; + +} +#endif +/*////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + + End of -include/epiworld/tool-distribute-meat.hpp- + +//////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////*/ + + /*////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// @@ -11484,17 +11579,24 @@ inline ToolFun tool_fun_logit( } +template +inline Tool::Tool(std::string name) +{ + set_name(name); +} + template inline Tool::Tool( std::string name, epiworld_double prevalence, - bool as_proportion, - ToolToAgentFun dist_fun + bool as_proportion ) { set_name(name); - set_prevalence(prevalence, as_proportion); - set_dist_fun(dist_fun); + + set_distribution( + distribute_tool_randomly(prevalence, as_proportion) + ); } // template @@ -11917,80 +12019,16 @@ inline void Tool::distribute(Model * model) dist_fun(*this, model); - } else { - - // Picking how many - int n_to_distribute; - int n = model->size(); - if (prevalence_as_proportion) - { - n_to_distribute = static_cast(std::floor(prevalence * n)); - - if (n_to_distribute == n) - n_to_distribute--; - } - else - { - n_to_distribute = static_cast(prevalence); - } - - if (n_to_distribute > n) - throw std::range_error("There are only " + std::to_string(n) + - " individuals in the population. Cannot add the tool to " + std::to_string(n_to_distribute)); - - std::vector< int > idx(n); - std::iota(idx.begin(), idx.end(), 0); - auto & population = model->get_agents(); - for (int i = 0u; i < n_to_distribute; ++i) - { - int loc = static_cast( - floor(model->runif() * n--) - ); - - if ((loc > 0) && (loc == n)) - loc--; - - population[idx[loc]].add_tool( - *this, - const_cast< Model * >(model), - state_init, queue_init - ); - - std::swap(idx[loc], idx[n]); - - } - } } template -inline void Tool::set_dist_fun(ToolToAgentFun fun) +inline void Tool::set_distribution(ToolToAgentFun fun) { dist_fun = fun; } -template -inline epiworld_double Tool::get_prevalence() const -{ - return prevalence; -} - -template -inline void Tool::set_prevalence( - epiworld_double prevalence, - bool as_proportion -) -{ - this->prevalence = prevalence; - this->prevalence_as_proportion = as_proportion; -} - -template -inline bool Tool::get_prevalence_as_proportion() const -{ - return prevalence_as_proportion; -} #endif /*////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// @@ -12133,7 +12171,7 @@ class Entity { std::vector< size_t > & get_agents(); void print() const; - void set_dist_fun(EntityToAgentFun fun); + void set_distribution(EntityToAgentFun fun); }; @@ -12592,7 +12630,7 @@ inline void Entity::print() const } template -inline void Entity::set_dist_fun(EntityToAgentFun fun) +inline void Entity::set_distribution(EntityToAgentFun fun) { dist_fun = fun; } diff --git a/include/epiworld/entity-bones.hpp b/include/epiworld/entity-bones.hpp index 79ed8d57..222b4c72 100644 --- a/include/epiworld/entity-bones.hpp +++ b/include/epiworld/entity-bones.hpp @@ -120,7 +120,7 @@ class Entity { std::vector< size_t > & get_agents(); void print() const; - void set_dist_fun(EntityToAgentFun fun); + void set_distribution(EntityToAgentFun fun); }; diff --git a/include/epiworld/entity-meat.hpp b/include/epiworld/entity-meat.hpp index 33f7e6ac..185859e3 100644 --- a/include/epiworld/entity-meat.hpp +++ b/include/epiworld/entity-meat.hpp @@ -258,7 +258,7 @@ inline void Entity::print() const } template -inline void Entity::set_dist_fun(EntityToAgentFun fun) +inline void Entity::set_distribution(EntityToAgentFun fun) { dist_fun = fun; } diff --git a/include/epiworld/epiworld.hpp b/include/epiworld/epiworld.hpp index 1916534f..7e8f6624 100644 --- a/include/epiworld/epiworld.hpp +++ b/include/epiworld/epiworld.hpp @@ -19,7 +19,7 @@ /* Versioning */ #define EPIWORLD_VERSION_MAJOR 0 #define EPIWORLD_VERSION_MINOR 3 -#define EPIWORLD_VERSION_PATCH 1 +#define EPIWORLD_VERSION_PATCH 2 static const int epiworld_version_major = EPIWORLD_VERSION_MAJOR; static const int epiworld_version_minor = EPIWORLD_VERSION_MINOR; @@ -66,6 +66,7 @@ namespace epiworld { #include "tools-bones.hpp" #include "tool-bones.hpp" + #include "tool-distribute-meat.hpp" #include "tool-meat.hpp" #include "entity-bones.hpp" diff --git a/include/epiworld/model-meat-print.hpp b/include/epiworld/model-meat-print.hpp index 1d4d6bc0..1a66a85d 100644 --- a/include/epiworld/model-meat-print.hpp +++ b/include/epiworld/model-meat-print.hpp @@ -203,26 +203,11 @@ inline const Model & Model::print(bool lite) const if (i < n_tools_model) { - if (tool->get_prevalence_as_proportion()) - { - - printf_epiworld( - " - %s (baseline prevalence: %.2f%%)\n", - tool->get_name().c_str(), - tool->get_prevalence() * 100.0 - ); - - } - else - { - - printf_epiworld( - " - %s (baseline prevalence: %i seeds)\n", - tool->get_name().c_str(), - static_cast(tool->get_prevalence()) - ); + printf_epiworld( + " - %s\n", + tool->get_name().c_str() + ); - } } else { diff --git a/include/epiworld/tool-bones.hpp b/include/epiworld/tool-bones.hpp index 98d3c319..475a7522 100644 --- a/include/epiworld/tool-bones.hpp +++ b/include/epiworld/tool-bones.hpp @@ -50,17 +50,15 @@ class Tool { void set_agent(Agent * p, size_t idx); - epiworld_double prevalence = 0.0; - bool prevalence_as_proportion = false; ToolToAgentFun dist_fun = nullptr; public: + Tool(std::string name = "unknown tool"); Tool( - std::string name = "unknown tool", - epiworld_double prevalence = 0.0, - bool prevalence_as_proportion = false, - ToolToAgentFun dist_fun = nullptr - ); + std::string name, + epiworld_double prevalence, + bool as_proportion + ); void set_sequence(TSeq d); void set_sequence(std::shared_ptr d); @@ -116,11 +114,7 @@ class Tool { void print() const; void distribute(Model * model); - - void set_prevalence(epiworld_double p, bool as_proportion = false); - epiworld_double get_prevalence() const; - bool get_prevalence_as_proportion() const; - void set_dist_fun(ToolToAgentFun fun); + void set_distribution(ToolToAgentFun fun); }; diff --git a/include/epiworld/tool-meat.hpp b/include/epiworld/tool-meat.hpp index b4ae3906..9f8e11a5 100644 --- a/include/epiworld/tool-meat.hpp +++ b/include/epiworld/tool-meat.hpp @@ -80,17 +80,24 @@ inline ToolFun tool_fun_logit( } +template +inline Tool::Tool(std::string name) +{ + set_name(name); +} + template inline Tool::Tool( std::string name, epiworld_double prevalence, - bool as_proportion, - ToolToAgentFun dist_fun + bool as_proportion ) { set_name(name); - set_prevalence(prevalence, as_proportion); - set_dist_fun(dist_fun); + + set_distribution( + distribute_tool_randomly(prevalence, as_proportion) + ); } // template @@ -513,78 +520,14 @@ inline void Tool::distribute(Model * model) dist_fun(*this, model); - } else { - - // Picking how many - int n_to_distribute; - int n = model->size(); - if (prevalence_as_proportion) - { - n_to_distribute = static_cast(std::floor(prevalence * n)); - - if (n_to_distribute == n) - n_to_distribute--; - } - else - { - n_to_distribute = static_cast(prevalence); - } - - if (n_to_distribute > n) - throw std::range_error("There are only " + std::to_string(n) + - " individuals in the population. Cannot add the tool to " + std::to_string(n_to_distribute)); - - std::vector< int > idx(n); - std::iota(idx.begin(), idx.end(), 0); - auto & population = model->get_agents(); - for (int i = 0u; i < n_to_distribute; ++i) - { - int loc = static_cast( - floor(model->runif() * n--) - ); - - if ((loc > 0) && (loc == n)) - loc--; - - population[idx[loc]].add_tool( - *this, - const_cast< Model * >(model), - state_init, queue_init - ); - - std::swap(idx[loc], idx[n]); - - } - } } template -inline void Tool::set_dist_fun(ToolToAgentFun fun) +inline void Tool::set_distribution(ToolToAgentFun fun) { dist_fun = fun; } -template -inline epiworld_double Tool::get_prevalence() const -{ - return prevalence; -} - -template -inline void Tool::set_prevalence( - epiworld_double prevalence, - bool as_proportion -) -{ - this->prevalence = prevalence; - this->prevalence_as_proportion = as_proportion; -} - -template -inline bool Tool::get_prevalence_as_proportion() const -{ - return prevalence_as_proportion; -} #endif \ No newline at end of file diff --git a/include/epiworld/virus-bones.hpp b/include/epiworld/virus-bones.hpp index 2209a861..f4b6cbf4 100644 --- a/include/epiworld/virus-bones.hpp +++ b/include/epiworld/virus-bones.hpp @@ -58,10 +58,7 @@ class Virus { VirusToAgentFun dist_fun = nullptr; public: - Virus( - std::string name = "unknown virus", - VirusToAgentFun dist_fun = nullptr - ); + Virus(std::string name = "unknown virus"); Virus( std::string name = "unknown virus", @@ -173,7 +170,7 @@ class Virus { */ ///@{ void distribute(Model * model); - void set_dist_fun(VirusToAgentFun fun); + void set_distribution(VirusToAgentFun fun); ///@} diff --git a/include/epiworld/virus-meat.hpp b/include/epiworld/virus-meat.hpp index 95068e20..872385fe 100644 --- a/include/epiworld/virus-meat.hpp +++ b/include/epiworld/virus-meat.hpp @@ -81,11 +81,9 @@ inline VirusFun virus_fun_logit( template inline Virus::Virus( - std::string name, - VirusToAgentFun dist_fun + std::string name ) { set_name(name); - set_dist_fun(dist_fun); } template @@ -95,7 +93,7 @@ inline Virus::Virus( bool prevalence_as_proportion ) { set_name(name); - set_dist_fun( + set_distribution( distribute_virus_randomly( prevalence, prevalence_as_proportion @@ -719,7 +717,7 @@ inline void Virus::distribute(Model * model) } template -inline void Virus::set_dist_fun(VirusToAgentFun fun) +inline void Virus::set_distribution(VirusToAgentFun fun) { dist_fun = fun; } diff --git a/tests/05-mixing.cpp b/tests/05-mixing.cpp index 7d56f647..1c38fc71 100644 --- a/tests/05-mixing.cpp +++ b/tests/05-mixing.cpp @@ -28,7 +28,7 @@ EPIWORLD_TEST_CASE("SEIRMixing", "[SEIR-mixing]") { // Copy the original virus Virus<> v1 = model.get_virus(0); model.rm_virus(0); - v1.set_dist_fun(dist_virus<>(0)); + v1.set_distribution(dist_virus<>(0)); model.add_virus(v1); @@ -131,9 +131,9 @@ EPIWORLD_TEST_CASE("SEIRMixing", "[SEIR-mixing]") { // If entities don't have a dist function, then it should be // OK - e1.set_dist_fun(distribute_entity_randomly<>(2000, false, true)); - e2.set_dist_fun(distribute_entity_randomly<>(2000, false, true)); - e3.set_dist_fun(distribute_entity_randomly<>(2000, false, true)); + e1.set_distribution(distribute_entity_randomly<>(2000, false, true)); + e2.set_distribution(distribute_entity_randomly<>(2000, false, true)); + e3.set_distribution(distribute_entity_randomly<>(2000, false, true)); model.rm_entity(0); model.rm_entity(1); diff --git a/tests/06-mixing.cpp b/tests/06-mixing.cpp index 14980cde..68d04d31 100644 --- a/tests/06-mixing.cpp +++ b/tests/06-mixing.cpp @@ -29,7 +29,7 @@ EPIWORLD_TEST_CASE("SIRMixing", "[SIR-mixing]") { // Copy the original virus Virus<> v1 = model.get_virus(0); model.rm_virus(0); - v1.set_dist_fun(dist_virus<>(0)); + v1.set_distribution(dist_virus<>(0)); model.add_virus(v1); diff --git a/tests/07-entitifuns.cpp b/tests/07-entitifuns.cpp index 7e09def5..8925c07b 100644 --- a/tests/07-entitifuns.cpp +++ b/tests/07-entitifuns.cpp @@ -127,8 +127,8 @@ EPIWORLD_TEST_CASE("Entity member", "[Entity]") { ); // Updating the distribution - e1.set_dist_fun(distribute_entity_to_set<>(dist[0])); - e2.set_dist_fun(distribute_entity_to_set<>(dist[1])); + e1.set_distribution(distribute_entity_to_set<>(dist[0])); + e2.set_distribution(distribute_entity_to_set<>(dist[1])); model3.add_entity(e1); model3.add_entity(e2); diff --git a/tests/09-distribute-tools-and-viruses.cpp b/tests/09-distribute-tools-and-viruses.cpp index 71ae1423..83351df9 100644 --- a/tests/09-distribute-tools-and-viruses.cpp +++ b/tests/09-distribute-tools-and-viruses.cpp @@ -12,6 +12,14 @@ EPIWORLD_TEST_CASE("Distribution funs", "[DistFuns]") { "a virus", 10000u, 0.01, 4.0, 1.0, 1.0/10000.0 ); + Tool<> tool("vax"); + tool.set_susceptibility_reduction(0.5); + tool.set_distribution( + distribute_tool_randomly(0.1, true) + ); + + model_0.add_tool(tool); + model_0.run(0, 131); // Listing agents with viruses @@ -22,13 +30,26 @@ EPIWORLD_TEST_CASE("Distribution funs", "[DistFuns]") { got_it.push_back(i); } + // Listing agents with tools + std::vector< size_t > got_tool; + for (size_t i = 0; i < model_0.size(); i++) + { + if (model_0.get_agent(i).get_n_tools() != 0) + got_tool.push_back(i); + } + epimodels::ModelSIRCONN<> model_1( "a virus", 10000u, 0.01, 4.0, 1.0, 1.0/10000.0 ); - model_1.get_virus(0).set_dist_fun( + model_1.get_virus(0).set_distribution( distribute_virus_to_set<>(got_it) ); + + model_1.add_tool(tool); + model_1.get_tool(0).set_distribution( + distribute_tool_to_set<>(got_tool) + ); model_1.run(0, 131); @@ -39,11 +60,22 @@ EPIWORLD_TEST_CASE("Distribution funs", "[DistFuns]") { if (model_1.get_agent(i).get_virus() != nullptr) got_it1.push_back(i); } + + // Listing agents with tools + std::vector< size_t > got_tool1; + for (size_t i = 0; i < model_1.size(); i++) + { + if (model_1.get_agent(i).get_n_tools() != 0) + got_tool1.push_back(i); + } // Sorting got_it asc std::sort(got_it.begin(), got_it.end()); std::sort(got_it1.begin(), got_it1.end()); + std::sort(got_tool.begin(), got_tool.end()); + std::sort(got_tool1.begin(), got_tool1.end()); + // Comparing both sets: Finding the non-matching elements std::vector< size_t > diff; std::set_difference( @@ -52,9 +84,19 @@ EPIWORLD_TEST_CASE("Distribution funs", "[DistFuns]") { std::inserter(diff, diff.begin()) ); + std::vector< size_t > diff_tool; + std::set_difference( + got_tool.begin(), got_tool.end(), + got_tool1.begin(), got_tool1.end(), + std::inserter(diff_tool, diff_tool.begin()) + ); + #ifdef CATCH_CONFIG_MAIN REQUIRE(got_it.size() > 0); REQUIRE(got_it == got_it1); + + REQUIRE(got_tool.size() > 0); + REQUIRE(got_tool == got_tool1); #endif #ifndef CATCH_CONFIG_MAIN