Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing bug in mixing models and implements new entity-related methods #18

Merged
merged 4 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 91 additions & 16 deletions epiworld.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6351,6 +6351,12 @@ class Model {
const std::vector<int> & entities_ids
);

void load_agents_entities_ties(
const int * agents_id,
const int * entities_id,
size_t n
);

/**
* @name Accessing population of the model
*
Expand Down Expand Up @@ -7955,7 +7961,7 @@ inline void Model<TSeq>::rm_entity(size_t entity_id)
entity.reset();

// How should
if (entity_pos != (entities.size() - 1))
if (entity_pos != (static_cast<int>(entities.size()) - 1))
std::swap(entities[entity_pos], entities[entities.size() - 1]);

entities.pop_back();
Expand Down Expand Up @@ -8101,46 +8107,87 @@ inline void Model<TSeq>::load_agents_entities_ties(
const std::vector< int > & entities_ids
) {

// Checking the size
if (agents_ids.size() != entities_ids.size())
throw std::length_error(
std::string("agents_ids (") +
std::string("The size of agents_ids (") +
std::to_string(agents_ids.size()) +
std::string(") and entities_ids (") +
std::to_string(entities_ids.size()) +
std::string(") should match.")
std::string(") must be the same.")
);

return this->load_agents_entities_ties(
agents_ids.data(),
entities_ids.data(),
agents_ids.size()
);

}

template<typename TSeq>
inline void Model<TSeq>::load_agents_entities_ties(
const int * agents_ids,
const int * entities_ids,
size_t n
) {

auto get_agent = [agents_ids](int i) -> int {
return *(agents_ids + i);
};

size_t n_entries = agents_ids.size();
for (size_t i = 0u; i < n_entries; ++i)
auto get_entity = [entities_ids](int i) -> int {
return *(entities_ids + i);
};

for (size_t i = 0u; i < n; ++i)
{

if (agents_ids[i] >= this->population.size())
if (get_agent(i) < 0)
throw std::length_error(
std::string("agents_ids[") +
std::to_string(i) +
std::string("] = ") +
std::to_string(get_agent(i)) +
std::string(" is negative.")
);

if (get_entity(i) < 0)
throw std::length_error(
std::string("entities_ids[") +
std::to_string(i) +
std::string("] = ") +
std::to_string(get_entity(i)) +
std::string(" is negative.")
);

int pop_size = static_cast<int>(this->population.size());
if (get_agent(i) >= pop_size)
throw std::length_error(
std::string("agents_ids[") +
std::to_string(i) +
std::string("] = ") +
std::to_string(agents_ids[i]) +
std::to_string(get_agent(i)) +
std::string(" is out of range (population size: ") +
std::to_string(this->population.size()) +
std::to_string(pop_size) +
std::string(").")
);


if (entities_ids[i] >= this->entities.size())
int ent_size = static_cast<int>(this->entities.size());
if (get_entity(i) >= ent_size)
throw std::length_error(
std::string("entities_ids[") +
std::to_string(i) +
std::string("] = ") +
std::to_string(entities_ids[i]) +
std::to_string(get_entity(i)) +
std::string(" is out of range (entities size: ") +
std::to_string(this->entities.size()) +
std::to_string(ent_size) +
std::string(").")
);

// Adding the entity to the agent
this->population[agents_ids[i]].add_entity(
this->entities[entities_ids[i]],
this->population[get_agent(i)].add_entity(
this->entities[get_entity(i)],
nullptr /* Immediately add it to the agent */
);

Expand Down Expand Up @@ -12088,6 +12135,10 @@ class Entity {

void distribute();

std::vector< size_t > & get_agents();

void print() const;

};


Expand Down Expand Up @@ -12392,6 +12443,24 @@ inline void Entity<TSeq>::distribute()

}

template<typename TSeq>
inline std::vector< size_t > & Entity<TSeq>::get_agents()
{
return agents;
}

template<typename TSeq>
inline void Entity<TSeq>::print() const
{

printf_epiworld(
"Entity '%s' (id %i) with %i agents.\n",
this->entity_name.c_str(),
static_cast<int>(id),
static_cast<int>(n_agents)
);
}

#endif
/*//////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -19657,7 +19726,10 @@ inline void ModelSEIRMixing<TSeq>::update_infected()
{

if (a.get_state() == ModelSEIRMixing<TSeq>::INFECTED)
infected[a.get_entity(0u).get_id()].push_back(&a);
{
if (a.get_n_entities() > 0u)
infected[a.get_entity(0u).get_id()].push_back(&a);
}

}

Expand Down Expand Up @@ -20190,7 +20262,10 @@ inline void ModelSIRMixing<TSeq>::update_infected_list()
{

if (a.get_state() == ModelSIRMixing<TSeq>::INFECTED)
infected[a.get_entity(0u).get_id()].push_back(&a);
{
if (a.get_n_entities() > 0u)
infected[a.get_entity(0u).get_id()].push_back(&a);
}

}

Expand Down
4 changes: 4 additions & 0 deletions include/epiworld/entity-bones.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ class Entity {

void distribute();

std::vector< size_t > & get_agents();

void print() const;

};


Expand Down
18 changes: 18 additions & 0 deletions include/epiworld/entity-meat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,4 +280,22 @@ inline void Entity<TSeq>::distribute()

}

template<typename TSeq>
inline std::vector< size_t > & Entity<TSeq>::get_agents()
{
return agents;
}

template<typename TSeq>
inline void Entity<TSeq>::print() const
{

printf_epiworld(
"Entity '%s' (id %i) with %i agents.\n",
this->entity_name.c_str(),
static_cast<int>(id),
static_cast<int>(n_agents)
);
}

#endif
6 changes: 6 additions & 0 deletions include/epiworld/model-bones.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,12 @@ class Model {
const std::vector<int> & entities_ids
);

void load_agents_entities_ties(
const int * agents_id,
const int * entities_id,
size_t n
);

/**
* @name Accessing population of the model
*
Expand Down
69 changes: 55 additions & 14 deletions include/epiworld/model-meat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1213,7 +1213,7 @@ inline void Model<TSeq>::rm_entity(size_t entity_id)
entity.reset();

// How should
if (entity_pos != (entities.size() - 1))
if (entity_pos != (static_cast<int>(entities.size()) - 1))
std::swap(entities[entity_pos], entities[entities.size() - 1]);

entities.pop_back();
Expand Down Expand Up @@ -1359,46 +1359,87 @@ inline void Model<TSeq>::load_agents_entities_ties(
const std::vector< int > & entities_ids
) {

// Checking the size
if (agents_ids.size() != entities_ids.size())
throw std::length_error(
std::string("agents_ids (") +
std::string("The size of agents_ids (") +
std::to_string(agents_ids.size()) +
std::string(") and entities_ids (") +
std::to_string(entities_ids.size()) +
std::string(") should match.")
std::string(") must be the same.")
);

return this->load_agents_entities_ties(
agents_ids.data(),
entities_ids.data(),
agents_ids.size()
);

}

size_t n_entries = agents_ids.size();
for (size_t i = 0u; i < n_entries; ++i)
template<typename TSeq>
inline void Model<TSeq>::load_agents_entities_ties(
const int * agents_ids,
const int * entities_ids,
size_t n
) {

auto get_agent = [agents_ids](int i) -> int {
return *(agents_ids + i);
};

auto get_entity = [entities_ids](int i) -> int {
return *(entities_ids + i);
};

for (size_t i = 0u; i < n; ++i)
{

if (agents_ids[i] >= this->population.size())
if (get_agent(i) < 0)
throw std::length_error(
std::string("agents_ids[") +
std::to_string(i) +
std::string("] = ") +
std::to_string(get_agent(i)) +
std::string(" is negative.")
);

if (get_entity(i) < 0)
throw std::length_error(
std::string("entities_ids[") +
std::to_string(i) +
std::string("] = ") +
std::to_string(get_entity(i)) +
std::string(" is negative.")
);

int pop_size = static_cast<int>(this->population.size());
if (get_agent(i) >= pop_size)
throw std::length_error(
std::string("agents_ids[") +
std::to_string(i) +
std::string("] = ") +
std::to_string(agents_ids[i]) +
std::to_string(get_agent(i)) +
std::string(" is out of range (population size: ") +
std::to_string(this->population.size()) +
std::to_string(pop_size) +
std::string(").")
);


if (entities_ids[i] >= this->entities.size())
int ent_size = static_cast<int>(this->entities.size());
if (get_entity(i) >= ent_size)
throw std::length_error(
std::string("entities_ids[") +
std::to_string(i) +
std::string("] = ") +
std::to_string(entities_ids[i]) +
std::to_string(get_entity(i)) +
std::string(" is out of range (entities size: ") +
std::to_string(this->entities.size()) +
std::to_string(ent_size) +
std::string(").")
);

// Adding the entity to the agent
this->population[agents_ids[i]].add_entity(
this->entities[entities_ids[i]],
this->population[get_agent(i)].add_entity(
this->entities[get_entity(i)],
nullptr /* Immediately add it to the agent */
);

Expand Down
5 changes: 4 additions & 1 deletion include/epiworld/models/seirmixing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,10 @@ inline void ModelSEIRMixing<TSeq>::update_infected()
{

if (a.get_state() == ModelSEIRMixing<TSeq>::INFECTED)
infected[a.get_entity(0u).get_id()].push_back(&a);
{
if (a.get_n_entities() > 0u)
infected[a.get_entity(0u).get_id()].push_back(&a);
}

}

Expand Down
5 changes: 4 additions & 1 deletion include/epiworld/models/sirmixing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ inline void ModelSIRMixing<TSeq>::update_infected_list()
{

if (a.get_state() == ModelSIRMixing<TSeq>::INFECTED)
infected[a.get_entity(0u).get_id()].push_back(&a);
{
if (a.get_n_entities() > 0u)
infected[a.get_entity(0u).get_id()].push_back(&a);
}

}

Expand Down
Loading