From c22f77bebeea03cfc25148190b50552ca567e877 Mon Sep 17 00:00:00 2001 From: "George G. Vega Yon" Date: Tue, 11 Jun 2024 14:35:36 -0600 Subject: [PATCH] Reset entities properly --- epiworld.hpp | 244 +++++++++++++++++++++++-------- include/epiworld/agent-bones.hpp | 2 + include/epiworld/agent-meat.hpp | 34 +++++ include/epiworld/config.hpp | 6 +- include/epiworld/entity-meat.hpp | 19 ++- include/epiworld/epiworld.hpp | 2 +- include/epiworld/model-bones.hpp | 2 +- include/epiworld/model-meat.hpp | 4 +- tests/05-mixing.cpp | 32 ++-- 9 files changed, 257 insertions(+), 88 deletions(-) diff --git a/epiworld.hpp b/epiworld.hpp index 4f28fb54..df306c86 100644 --- a/epiworld.hpp +++ b/epiworld.hpp @@ -19,7 +19,7 @@ /* Versioning */ #define EPIWORLD_VERSION_MAJOR 0 #define EPIWORLD_VERSION_MINOR 3 -#define EPIWORLD_VERSION_PATCH 0 +#define EPIWORLD_VERSION_PATCH 1 static const int epiworld_version_major = EPIWORLD_VERSION_MAJOR; static const int epiworld_version_minor = EPIWORLD_VERSION_MINOR; @@ -129,7 +129,7 @@ template struct Event; template -using ActionFun = std::function&,Model*)>; +using EventFun = std::function&,Model*)>; /** * @brief Decides how to distribute viruses at initialization @@ -162,7 +162,7 @@ struct Event { Entity * entity; epiworld_fast_int new_state; epiworld_fast_int queue; - ActionFun call; + EventFun call; int idx_agent; int idx_object; public: @@ -189,7 +189,7 @@ struct Event { Entity * entity_, epiworld_fast_int new_state_, epiworld_fast_int queue_, - ActionFun call_, + EventFun call_, int idx_agent_, int idx_object_ ) : agent(agent_), virus(virus_), tool(tool_), entity(entity_), @@ -445,6 +445,10 @@ struct Event { #define EPI_NEW_ENTITYTOAGENTFUN(funname,tseq) inline void \ (funname)(epiworld::Entity & e, epiworld::Model * m) +#define EPI_NEW_ENTITYTOAGENTFUN_LAMBDA(funname,tseq) \ + epiworld::EntityToAgentFun funname = \ + [](epiworld::Entity & e, epiworld::Model * m) -> void + #endif /*////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// @@ -6201,7 +6205,7 @@ class Model { Entity * entity_, epiworld_fast_int new_state_, epiworld_fast_int queue_, - ActionFun call_, + EventFun call_, int idx_agent_, int idx_object_ ); @@ -6892,7 +6896,7 @@ inline void Model::events_add( Entity * entity_, epiworld_fast_int new_state_, epiworld_fast_int queue_, - ActionFun call_, + EventFun call_, int idx_agent_, int idx_object_ ) { @@ -6901,7 +6905,7 @@ inline void Model::events_add( #ifdef EPI_DEBUG if (nactions == 0) - throw std::logic_error("Actions cannot be zero!!"); + throw std::logic_error("Events cannot be zero!!"); #endif if (nactions > events.size()) @@ -12030,12 +12034,13 @@ class Entity { epiworld_fast_int queue_init = 0; ///< Change of state when added to agent. epiworld_fast_int queue_post = 0; ///< Change of state when removed from agent. -public: - epiworld_double prevalence = 0.0; bool prevalence_as_proportion = false; EntityToAgentFun dist_fun = nullptr; +public: + + /** * @brief Constructs an Entity object. * @@ -12099,6 +12104,11 @@ class Entity { void print() const; + void set_prevalence(epiworld_double p, bool as_proportion); + epiworld_double get_prevalence() const noexcept; + bool get_prevalence_as_proportion() const noexcept; + void set_dist_fun(EntityToAgentFun fun); + }; @@ -12125,50 +12135,101 @@ class Entity { #define EPIWORLD_ENTITY_MEAT_HPP template -EPI_NEW_ENTITYTOAGENTFUN(entity_to_unassigned_agents, TSeq) +inline EntityToAgentFun entity_to_unassigned_agents() { - // Preparing the sampling space - std::vector< size_t > idx; - for (const auto & a: m->get_agents()) - if (a.get_n_entities() == 0) - idx.push_back(a.get_id()); - size_t n = idx.size(); + return [](Entity & e, Model * m) -> void { - // Figuring out how many to sample - int n_to_sample; - if (e.prevalence_as_proportion) - { - n_to_sample = static_cast(std::floor(e.prevalence * n)); - if (n_to_sample > n) - --n_to_sample; + + // Preparing the sampling space + std::vector< size_t > idx; + for (const auto & a: m->get_agents()) + if (a.get_n_entities() == 0) + idx.push_back(a.get_id()); + size_t n = idx.size(); - } else - { - n_to_sample = static_cast(e.prevalence); - if (n_to_sample > n) - throw std::range_error("There are only " + std::to_string(n) + - " individuals in the population. Cannot add the entity to " + - std::to_string(n_to_sample)); - } + // Figuring out how many to sample + int n_to_sample; + if (e.prevalence_as_proportion) + { + n_to_sample = static_cast(std::floor(e.prevalence * n)); + if (n_to_sample > static_cast(n)) + --n_to_sample; + + } else + { + n_to_sample = static_cast(e.prevalence); + if (n_to_sample > static_cast(n)) + throw std::range_error("There are only " + std::to_string(n) + + " individuals in the population. Cannot add the entity to " + + std::to_string(n_to_sample)); + } + + int n_left = n; + for (size_t i = 0u; i < n_to_sample; ++i) + { + int loc = static_cast( + floor(m->runif() * n_left--) + ); + + // Correcting for possible overflow + if ((loc > 0) && (loc >= n_left)) + loc = n_left - 1; + + m->get_agent(idx[loc]).add_entity(e, m); - int n_left = n; - for (size_t i = 0u; i < n_to_sample; ++i) + std::swap(idx[loc], idx[n_left]); + + } + + }; + +} + +template +inline EntityToAgentFun entity_to_agent_range( + int from, + int to, + bool to_unassigned = false + ) { + + if (to_unassigned) { - int loc = static_cast( - floor(m->runif() * n_left--) - ); - // Correcting for possible overflow - if ((loc > 0) && (loc >= n_left)) - loc = n_left - 1; + return [from, to](Entity & e, Model * m) -> void { - m->get_agent(idx[loc]).add_entity(e, m); + auto & agents = m->get_agents(); + for (size_t i = from; i < to; ++i) + { + if (agents[i].get_n_entities() == 0) + e.add_agent(&agents[i], m); + else + throw std::logic_error( + "Agent " + std::to_string(i) + " already has an entity." + ); + } + + return; - std::swap(idx[loc], idx[n_left]); + }; } + else + { + + return [from, to](Entity & e, Model * m) -> void { + + auto & agents = m->get_agents(); + for (size_t i = from; i < to; ++i) + { + e.add_agent(&agents[i], m); + } + + return; + }; + + } } template @@ -12338,9 +12399,9 @@ inline void Entity::reset() sampled_agents_left.clear(); sampled_agents_left_n = 0u; - // Removing agents from entities - for (size_t i = 0u; i < n_agents; ++i) - this->rm_agent(i); + this->agents.clear(); + this->n_agents = 0u; + this->agents_location.clear(); return; @@ -12414,34 +12475,39 @@ inline void Entity::distribute() { // Picking how many - int nsampled; + int n_to_assign; if (prevalence_as_proportion) { - nsampled = static_cast(std::floor(prevalence * size())); + n_to_assign = static_cast(std::floor(prevalence * size())); } else { - nsampled = static_cast(prevalence); + n_to_assign = static_cast(prevalence); } - if (nsampled > static_cast(model->size())) + if (n_to_assign > static_cast(model->size())) throw std::range_error("There are only " + std::to_string(model->size()) + - " individuals in the population. Cannot add the entity to " + std::to_string(nsampled)); + " individuals in the population. Cannot add the entity to " + std::to_string(n_to_assign)); int n_left = n; std::iota(idx.begin(), idx.end(), 0); - while (nsampled > 0) + while (n_to_assign > 0) { int loc = static_cast( floor(model->runif() * n_left--) ); + + // Correcting for possible overflow + if ((loc > 0) && (loc >= n_left)) + loc = n_left - 1; - model->get_agent(idx[loc]).add_entity( - *this, this->model, this->state_init, this->queue_init - ); - - nsampled--; + auto & agent = model->get_agent(idx[loc]); + if (!agent.has_entity(id)) + agent.add_entity( + *this, this->model, this->state_init, this->queue_init + ); + std::swap(idx[loc], idx[n_left]); } @@ -12468,6 +12534,34 @@ inline void Entity::print() const ); } +template +inline void Entity::set_prevalence( + epiworld_double p, + bool as_proportion +) +{ + prevalence = p; + prevalence_as_proportion = as_proportion; +} + +template +inline epiworld_double Entity::get_prevalence() const noexcept +{ + return prevalence; +} + +template +inline bool Entity::get_prevalence_as_proportion() const noexcept +{ + return prevalence_as_proportion; +} + +template +inline void Entity::set_dist_fun(EntityToAgentFun fun) +{ + dist_fun = fun; +} + #endif /*////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// @@ -13529,6 +13623,8 @@ class Agent { bool has_virus(epiworld_fast_uint t) const; bool has_virus(std::string name) const; bool has_virus(const Virus & v) const; + bool has_entity(epiworld_fast_uint t) const; + bool has_entity(std::string name) const; void print(Model * model, bool compressed = false) const; @@ -14543,6 +14639,10 @@ inline void Agent::reset() this->tools.clear(); n_tools = 0u; + this->entities.clear(); + this->entities_locations.clear(); + this->n_entities = 0u; + this->state = 0u; this->state_prev = 0u; @@ -14610,6 +14710,30 @@ inline bool Agent::has_virus(const Virus & virus) const } +template +inline bool Agent::has_entity(epiworld_fast_uint t) const +{ + + for (auto & entity : entities) + if (entity == t) + return true; + + return false; + +} + +template +inline bool Agent::has_entity(std::string name) const +{ + + for (auto & entity : entities) + if (model->get_entity(entity).get_name() == name) + return true; + + return false; + +} + template inline void Agent::print( Model * model, @@ -14721,6 +14845,9 @@ inline const Entities_const Agent::get_entities() const template inline const Entity & Agent::get_entity(size_t i) const { + if (n_entities == 0) + throw std::range_error("Agent id " + std::to_string(id) + " has no entities."); + if (i >= n_entities) throw std::range_error("Trying to get to an agent's entity outside of the range."); @@ -14730,6 +14857,9 @@ inline const Entity & Agent::get_entity(size_t i) const template inline Entity & Agent::get_entity(size_t i) { + if (n_entities == 0) + throw std::range_error("Agent id " + std::to_string(id) + " has no entities."); + if (i >= n_entities) throw std::range_error("Trying to get to an agent's entity outside of the range."); @@ -16769,7 +16899,7 @@ inline ModelSURV::ModelSURV( model.add_param(prob_noreinfect, "Prob. no reinfect"); // Virus ------------------------------------------------------------------ - epiworld::Virus covid("Covid19", prevalence, true); + epiworld::Virus covid("Covid19", prevalence, false); covid.set_state(LATENT, RECOVERED, REMOVED); covid.set_post_immunity(&model("Prob. no reinfect")); covid.set_prob_death(&model("Prob. death")); @@ -19804,7 +19934,7 @@ inline ModelSEIRMixing & ModelSEIRMixing::run( int seed ) { - + Model::run(ndays, seed); return *this; diff --git a/include/epiworld/agent-bones.hpp b/include/epiworld/agent-bones.hpp index c3d2fe26..c4e75dea 100644 --- a/include/epiworld/agent-bones.hpp +++ b/include/epiworld/agent-bones.hpp @@ -264,6 +264,8 @@ class Agent { bool has_virus(epiworld_fast_uint t) const; bool has_virus(std::string name) const; bool has_virus(const Virus & v) const; + bool has_entity(epiworld_fast_uint t) const; + bool has_entity(std::string name) const; void print(Model * model, bool compressed = false) const; diff --git a/include/epiworld/agent-meat.hpp b/include/epiworld/agent-meat.hpp index aadb8f83..27d4785e 100644 --- a/include/epiworld/agent-meat.hpp +++ b/include/epiworld/agent-meat.hpp @@ -635,6 +635,10 @@ inline void Agent::reset() this->tools.clear(); n_tools = 0u; + this->entities.clear(); + this->entities_locations.clear(); + this->n_entities = 0u; + this->state = 0u; this->state_prev = 0u; @@ -702,6 +706,30 @@ inline bool Agent::has_virus(const Virus & virus) const } +template +inline bool Agent::has_entity(epiworld_fast_uint t) const +{ + + for (auto & entity : entities) + if (entity == t) + return true; + + return false; + +} + +template +inline bool Agent::has_entity(std::string name) const +{ + + for (auto & entity : entities) + if (model->get_entity(entity).get_name() == name) + return true; + + return false; + +} + template inline void Agent::print( Model * model, @@ -813,6 +841,9 @@ inline const Entities_const Agent::get_entities() const template inline const Entity & Agent::get_entity(size_t i) const { + if (n_entities == 0) + throw std::range_error("Agent id " + std::to_string(id) + " has no entities."); + if (i >= n_entities) throw std::range_error("Trying to get to an agent's entity outside of the range."); @@ -822,6 +853,9 @@ inline const Entity & Agent::get_entity(size_t i) const template inline Entity & Agent::get_entity(size_t i) { + if (n_entities == 0) + throw std::range_error("Agent id " + std::to_string(id) + " has no entities."); + if (i >= n_entities) throw std::range_error("Trying to get to an agent's entity outside of the range."); diff --git a/include/epiworld/config.hpp b/include/epiworld/config.hpp index 852243d9..49aa6b46 100644 --- a/include/epiworld/config.hpp +++ b/include/epiworld/config.hpp @@ -91,7 +91,7 @@ template struct Event; template -using ActionFun = std::function&,Model*)>; +using EventFun = std::function&,Model*)>; /** * @brief Decides how to distribute viruses at initialization @@ -124,7 +124,7 @@ struct Event { Entity * entity; epiworld_fast_int new_state; epiworld_fast_int queue; - ActionFun call; + EventFun call; int idx_agent; int idx_object; public: @@ -151,7 +151,7 @@ struct Event { Entity * entity_, epiworld_fast_int new_state_, epiworld_fast_int queue_, - ActionFun call_, + EventFun call_, int idx_agent_, int idx_object_ ) : agent(agent_), virus(virus_), tool(tool_), entity(entity_), diff --git a/include/epiworld/entity-meat.hpp b/include/epiworld/entity-meat.hpp index 94820303..bb52454e 100644 --- a/include/epiworld/entity-meat.hpp +++ b/include/epiworld/entity-meat.hpp @@ -266,9 +266,9 @@ inline void Entity::reset() sampled_agents_left.clear(); sampled_agents_left_n = 0u; - // Removing agents from entities - for (size_t i = 0u; i < n_agents; ++i) - this->rm_agent(i); + this->agents.clear(); + this->n_agents = 0u; + this->agents_location.clear(); return; @@ -358,7 +358,7 @@ inline void Entity::distribute() int n_left = n; std::iota(idx.begin(), idx.end(), 0); - for (int i = 0; i < n_to_assign; ++i) + while (n_to_assign > 0) { int loc = static_cast( floor(model->runif() * n_left--) @@ -368,10 +368,13 @@ inline void Entity::distribute() if ((loc > 0) && (loc >= n_left)) loc = n_left - 1; - model->get_agent(idx[loc]).add_entity( - *this, this->model, this->state_init, this->queue_init - ); - + auto & agent = model->get_agent(idx[loc]); + + if (!agent.has_entity(id)) + agent.add_entity( + *this, this->model, this->state_init, this->queue_init + ); + std::swap(idx[loc], idx[n_left]); } diff --git a/include/epiworld/epiworld.hpp b/include/epiworld/epiworld.hpp index d6d229df..cb080683 100644 --- a/include/epiworld/epiworld.hpp +++ b/include/epiworld/epiworld.hpp @@ -19,7 +19,7 @@ /* Versioning */ #define EPIWORLD_VERSION_MAJOR 0 #define EPIWORLD_VERSION_MINOR 3 -#define EPIWORLD_VERSION_PATCH 0 +#define EPIWORLD_VERSION_PATCH 1 static const int epiworld_version_major = EPIWORLD_VERSION_MAJOR; static const int epiworld_version_minor = EPIWORLD_VERSION_MINOR; diff --git a/include/epiworld/model-bones.hpp b/include/epiworld/model-bones.hpp index f57db90d..fc0d0c7e 100644 --- a/include/epiworld/model-bones.hpp +++ b/include/epiworld/model-bones.hpp @@ -220,7 +220,7 @@ class Model { Entity * entity_, epiworld_fast_int new_state_, epiworld_fast_int queue_, - ActionFun call_, + EventFun call_, int idx_agent_, int idx_object_ ); diff --git a/include/epiworld/model-meat.hpp b/include/epiworld/model-meat.hpp index 75331639..95bc32c4 100644 --- a/include/epiworld/model-meat.hpp +++ b/include/epiworld/model-meat.hpp @@ -157,7 +157,7 @@ inline void Model::events_add( Entity * entity_, epiworld_fast_int new_state_, epiworld_fast_int queue_, - ActionFun call_, + EventFun call_, int idx_agent_, int idx_object_ ) { @@ -166,7 +166,7 @@ inline void Model::events_add( #ifdef EPI_DEBUG if (nactions == 0) - throw std::logic_error("Actions cannot be zero!!"); + throw std::logic_error("Events cannot be zero!!"); #endif if (nactions > events.size()) diff --git a/tests/05-mixing.cpp b/tests/05-mixing.cpp index 9dcb1f20..441ed414 100644 --- a/tests/05-mixing.cpp +++ b/tests/05-mixing.cpp @@ -129,22 +129,22 @@ EPIWORLD_TEST_CASE("SEIRMixing", "[SEIR-mixing]") { REQUIRE_THAT(totals, Catch::Equals(expected_totals)); #endif - // // If entities don't have a dist function, then it should be - // // OK - // e1.set_dist_fun(nullptr); - // e2.set_dist_fun(nullptr); - // e3.set_dist_fun(nullptr); - - // model.rm_entity(0); - // model.rm_entity(1); - // model.rm_entity(2); - - // model.add_entity(e1); - // model.add_entity(e2); - // model.add_entity(e3); - - // // Running and checking the results - // model.run(50, 123); + // If entities don't have a dist function, then it should be + // OK + e1.set_dist_fun(nullptr); + e2.set_dist_fun(nullptr); + e3.set_dist_fun(nullptr); + + model.rm_entity(0); + model.rm_entity(1); + model.rm_entity(2); + + model.add_entity(e1); + model.add_entity(e2); + model.add_entity(e3); + + // Running and checking the results + model.run(50, 123); #ifndef CATCH_CONFIG_MAIN