diff --git a/epiworld.hpp b/epiworld.hpp index 8c2e778e..bf42c046 100644 --- a/epiworld.hpp +++ b/epiworld.hpp @@ -6351,6 +6351,12 @@ class Model { const std::vector & entities_ids ); + void load_agents_entities_ties( + const int * agents_id, + const int * entities_id, + size_t n + ); + /** * @name Accessing population of the model * @@ -7955,7 +7961,7 @@ inline void Model::rm_entity(size_t entity_id) entity.reset(); // How should - if (entity_pos != (entities.size() - 1)) + if (entity_pos != (static_cast(entities.size()) - 1)) std::swap(entities[entity_pos], entities[entities.size() - 1]); entities.pop_back(); @@ -8101,46 +8107,87 @@ inline void Model::load_agents_entities_ties( const std::vector< int > & entities_ids ) { + // Checking the size if (agents_ids.size() != entities_ids.size()) throw std::length_error( - std::string("agents_ids (") + + std::string("The size of agents_ids (") + std::to_string(agents_ids.size()) + std::string(") and entities_ids (") + std::to_string(entities_ids.size()) + - std::string(") should match.") + std::string(") must be the same.") ); + return this->load_agents_entities_ties( + agents_ids.data(), + entities_ids.data(), + agents_ids.size() + ); + +} + +template +inline void Model::load_agents_entities_ties( + const int * agents_ids, + const int * entities_ids, + size_t n +) { + + auto get_agent = [agents_ids](int i) -> int { + return *(agents_ids + i); + }; - size_t n_entries = agents_ids.size(); - for (size_t i = 0u; i < n_entries; ++i) + auto get_entity = [entities_ids](int i) -> int { + return *(entities_ids + i); + }; + + for (size_t i = 0u; i < n; ++i) { - if (agents_ids[i] >= this->population.size()) + if (get_agent(i) < 0) + throw std::length_error( + std::string("agents_ids[") + + std::to_string(i) + + std::string("] = ") + + std::to_string(get_agent(i)) + + std::string(" is negative.") + ); + + if (get_entity(i) < 0) + throw std::length_error( + std::string("entities_ids[") + + std::to_string(i) + + std::string("] = ") + + std::to_string(get_entity(i)) + + std::string(" is negative.") + ); + + int pop_size = static_cast(this->population.size()); + if (get_agent(i) >= pop_size) throw std::length_error( std::string("agents_ids[") + std::to_string(i) + std::string("] = ") + - std::to_string(agents_ids[i]) + + std::to_string(get_agent(i)) + std::string(" is out of range (population size: ") + - std::to_string(this->population.size()) + + std::to_string(pop_size) + std::string(").") ); - - if (entities_ids[i] >= this->entities.size()) + int ent_size = static_cast(this->entities.size()); + if (get_entity(i) >= ent_size) throw std::length_error( std::string("entities_ids[") + std::to_string(i) + std::string("] = ") + - std::to_string(entities_ids[i]) + + std::to_string(get_entity(i)) + std::string(" is out of range (entities size: ") + - std::to_string(this->entities.size()) + + std::to_string(ent_size) + std::string(").") ); // Adding the entity to the agent - this->population[agents_ids[i]].add_entity( - this->entities[entities_ids[i]], + this->population[get_agent(i)].add_entity( + this->entities[get_entity(i)], nullptr /* Immediately add it to the agent */ ); @@ -12088,6 +12135,10 @@ class Entity { void distribute(); + std::vector< size_t > & get_agents(); + + void print() const; + }; @@ -12392,6 +12443,24 @@ inline void Entity::distribute() } +template +inline std::vector< size_t > & Entity::get_agents() +{ + return agents; +} + +template +inline void Entity::print() const +{ + + printf_epiworld( + "Entity '%s' (id %i) with %i agents.\n", + this->entity_name.c_str(), + static_cast(id), + static_cast(n_agents) + ); +} + #endif /*////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// @@ -19657,7 +19726,10 @@ inline void ModelSEIRMixing::update_infected() { if (a.get_state() == ModelSEIRMixing::INFECTED) - infected[a.get_entity(0u).get_id()].push_back(&a); + { + if (a.get_n_entities() > 0u) + infected[a.get_entity(0u).get_id()].push_back(&a); + } } @@ -20190,7 +20262,10 @@ inline void ModelSIRMixing::update_infected_list() { if (a.get_state() == ModelSIRMixing::INFECTED) - infected[a.get_entity(0u).get_id()].push_back(&a); + { + if (a.get_n_entities() > 0u) + infected[a.get_entity(0u).get_id()].push_back(&a); + } }