Skip to content

Commit

Permalink
15 implement missing modelentity methods (#17)
Browse files Browse the repository at this point in the history
* Adding rm_entity, get_entity, and moving entity paramters out of model

* Updating single header
  • Loading branch information
gvegayon authored May 2, 2024
1 parent cf9c8eb commit d89f393
Show file tree
Hide file tree
Showing 7 changed files with 238 additions and 136 deletions.
187 changes: 119 additions & 68 deletions epiworld.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6124,9 +6124,6 @@ class Model {
std::vector< ToolToAgentFun<TSeq> > tools_dist_funs = {};

std::vector< Entity<TSeq> > entities = {};
std::vector< epiworld_double > prevalence_entity = {};
std::vector< bool > prevalence_entity_as_proportion = {};
std::vector< EntityToAgentFun<TSeq> > entities_dist_funs = {};
std::vector< Entity<TSeq> > entities_backup = {};

std::mt19937 engine;
Expand Down Expand Up @@ -6331,7 +6328,7 @@ class Model {
void add_entity_fun(Entity<TSeq> e, EntityToAgentFun<TSeq> fun);
void rm_virus(size_t virus_pos);
void rm_tool(size_t tool_pos);
void rm_entity(size_t entity_pos);
void rm_entity(size_t entity_id);
///@}

/**
Expand Down Expand Up @@ -6384,6 +6381,8 @@ class Model {

std::vector< Agent<TSeq> > & get_agents(); ///< Returns a reference to the vector of agents.

Agent<TSeq> & get_agent(size_t i);

std::vector< epiworld_fast_uint > get_agents_states() const; ///< Returns a vector with the states of the agents.

std::vector< Viruses_const<TSeq> > get_agents_viruses() const; ///< Returns a const vector with the viruses of the agents.
Expand All @@ -6392,6 +6391,8 @@ class Model {

std::vector< Entity<TSeq> > & get_entities();

Entity<TSeq> & get_entity(size_t entity_id, int * entity_pos = nullptr);

Model<TSeq> & agents_smallworld(
epiworld_fast_uint n = 1000,
epiworld_fast_uint k = 5,
Expand Down Expand Up @@ -7129,9 +7130,6 @@ inline Model<TSeq>::Model(const Model<TSeq> & model) :
prevalence_tool_as_proportion(model.prevalence_tool_as_proportion),
tools_dist_funs(model.tools_dist_funs),
entities(model.entities),
prevalence_entity(model.prevalence_entity),
prevalence_entity_as_proportion(model.prevalence_entity_as_proportion),
entities_dist_funs(model.entities_dist_funs),
entities_backup(model.entities_backup),
rewire_fun(model.rewire_fun),
rewire_prop(model.rewire_prop),
Expand Down Expand Up @@ -7206,9 +7204,6 @@ inline Model<TSeq>::Model(Model<TSeq> && model) :
tools_dist_funs(std::move(model.tools_dist_funs)),
// Entities
entities(std::move(model.entities)),
prevalence_entity(std::move(model.prevalence_entity)),
prevalence_entity_as_proportion(std::move(model.prevalence_entity_as_proportion)),
entities_dist_funs(std::move(model.entities_dist_funs)),
entities_backup(std::move(model.entities_backup)),
// Pseudo-RNG
engine(std::move(model.engine)),
Expand Down Expand Up @@ -7285,9 +7280,6 @@ inline Model<TSeq> & Model<TSeq>::operator=(const Model<TSeq> & m)
tools_dist_funs = m.tools_dist_funs;

entities = m.entities;
prevalence_entity = m.prevalence_entity;
prevalence_entity_as_proportion = m.prevalence_entity_as_proportion;
entities_dist_funs = m.entities_dist_funs;
entities_backup = m.entities_backup;

rewire_fun = m.rewire_fun;
Expand Down Expand Up @@ -7347,6 +7339,12 @@ inline std::vector<Agent<TSeq>> & Model<TSeq>::get_agents()
return population;
}

template<typename TSeq>
inline Agent<TSeq> & Model<TSeq>::get_agent(size_t i)
{
return population[i];
}

template<typename TSeq>
inline std::vector< epiworld_fast_uint > Model<TSeq>::get_agents_states() const
{
Expand Down Expand Up @@ -7388,6 +7386,25 @@ inline std::vector<Entity<TSeq>> & Model<TSeq>::get_entities()
return entities;
}

template<typename TSeq>
inline Entity<TSeq> & Model<TSeq>::get_entity(size_t i, int * entity_pos)
{

for (size_t j = 0u; j < entities.size(); ++j)
if (entities[j].get_id() == static_cast<int>(i))
{

if (entity_pos)
*entity_pos = j;

return entities[j];

}

throw std::range_error("The entity with id " + std::to_string(i) + " was not found.");

}

template<typename TSeq>
inline Model<TSeq> & Model<TSeq>::agents_smallworld(
epiworld_fast_uint n,
Expand Down Expand Up @@ -7613,51 +7630,10 @@ template<typename TSeq>
inline void Model<TSeq>::dist_entities()
{

// Starting first infection
int n = size();
std::vector< size_t > idx(n);
for (epiworld_fast_uint e = 0; e < entities.size(); ++e)
for (auto & entity: entities)
{

if (entities_dist_funs[e])
{

entities_dist_funs[e](entities[e], this);

} else {

// Picking how many
int nsampled;
if (prevalence_entity_as_proportion[e])
{
nsampled = static_cast<int>(std::floor(prevalence_entity[e] * size()));
}
else
{
nsampled = static_cast<int>(prevalence_entity[e]);
}

if (nsampled > static_cast<int>(size()))
throw std::range_error("There are only " + std::to_string(size()) +
" individuals in the population. Cannot add the entity to " + std::to_string(nsampled));

Entity<TSeq> & entity = entities[e];

int n_left = n;
std::iota(idx.begin(), idx.end(), 0);
while (nsampled > 0)
{
int loc = static_cast<epiworld_fast_uint>(floor(runif() * n_left--));

population[idx[loc]].add_entity(entity, this, entity.state_init, entity.queue_init);

nsampled--;

std::swap(idx[loc], idx[n_left]);

}

}
entity.distribute();

// Apply the events
events_run();
Expand Down Expand Up @@ -7951,10 +7927,9 @@ inline void Model<TSeq>::add_entity_n(Entity<TSeq> e, epiworld_fast_uint preval)

e.model = this;
e.id = entities.size();
e.prevalence = preval;
e.prevalence_as_proportion = false;
entities.push_back(e);
prevalence_entity.push_back(preval);
prevalence_entity_as_proportion.push_back(false);
entities_dist_funs.push_back(nullptr);

}

Expand All @@ -7964,13 +7939,28 @@ inline void Model<TSeq>::add_entity_fun(Entity<TSeq> e, EntityToAgentFun<TSeq> f

e.model = this;
e.id = entities.size();
e.dist_fun = fun;
entities.push_back(e);
prevalence_entity.push_back(0.0);
prevalence_entity_as_proportion.push_back(false);
entities_dist_funs.push_back(fun);

}

template<typename TSeq>
inline void Model<TSeq>::rm_entity(size_t entity_id)
{

int entity_pos = 0;
auto & entity = this->get_entity(entity_id, &entity_pos);

// First, resetting the entity
entity.reset();

// How should
if (entity_pos != (entities.size() - 1))
std::swap(entities[entity_pos], entities[entities.size() - 1]);

entities.pop_back();
}

template<typename TSeq>
inline void Model<TSeq>::rm_virus(size_t virus_pos)
{
Expand Down Expand Up @@ -12054,8 +12044,13 @@ class Entity {
epiworld_fast_int queue_init = 0; ///< Change of state when added to agent.
epiworld_fast_int queue_post = 0; ///< Change of state when removed from agent.


public:

epiworld_double prevalence = 0.0;
bool prevalence_as_proportion = false;
EntityToAgentFun<TSeq> dist_fun = nullptr;

// Entity() = delete;
// Entity(Entity<TSeq> & e) = delete;
// Entity(const Entity<TSeq> & e);
Expand Down Expand Up @@ -12091,6 +12086,8 @@ class Entity {
bool operator==(const Entity<TSeq> & other) const;
bool operator!=(const Entity<TSeq> & other) const {return !operator==(other);};

void distribute();

};


Expand Down Expand Up @@ -12341,6 +12338,60 @@ inline bool Entity<TSeq>::operator==(const Entity<TSeq> & other) const

}

template<typename TSeq>
inline void Entity<TSeq>::distribute()
{

// Starting first infection
int n = this->model->size();
std::vector< size_t > idx(n);

if (dist_fun)
{

dist_fun(*this, model);

}
else
{

// Picking how many
int nsampled;
if (prevalence_as_proportion)
{
nsampled = static_cast<int>(std::floor(prevalence * size()));
}
else
{
nsampled = static_cast<int>(prevalence);
}

if (nsampled > static_cast<int>(model->size()))
throw std::range_error("There are only " + std::to_string(model->size()) +
" individuals in the population. Cannot add the entity to " + std::to_string(nsampled));

int n_left = n;
std::iota(idx.begin(), idx.end(), 0);
while (nsampled > 0)
{
int loc = static_cast<epiworld_fast_uint>(
floor(model->runif() * n_left--)
);

model->get_agent(idx[loc]).add_entity(
*this, this->model, this->state_init, this->queue_init
);

nsampled--;

std::swap(idx[loc], idx[n_left]);

}

}

}

#endif
/*//////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -13736,7 +13787,7 @@ inline void default_rm_entity(Event<TSeq> & a, Model<TSeq> * m)
p->entities_locations[p->n_entities];

Entity<TSeq> * last_entity =
&m->get_entities()[p->entities[p->n_entities]]; ///< Last entity of the agent
&m->get_entity(p->entities[p->n_entities]); ///< Last entity of the agent

// The end entity will be located where the removed was
last_entity->agents_location[agent_location_in_last_entity] =
Expand Down Expand Up @@ -14112,14 +14163,14 @@ inline void Agent<TSeq>::rm_entity(
"There is entity to remove here!"
);

CHECK_COALESCE_(state_new, model->entities[entity_idx].state_post, state);
CHECK_COALESCE_(queue, model->entities[entity_idx].queue_post, Queue<TSeq>::NoOne);
CHECK_COALESCE_(state_new, model->get_entity(entity_idx).state_post, state);
CHECK_COALESCE_(queue, model->get_entity(entity_idx).queue_post, Queue<TSeq>::NoOne);

model->events_add(
this,
nullptr,
nullptr,
&model->entities[entities[entity_idx]],
&model->get_entity(entity_idx),
state_new,
queue,
default_rm_entity<TSeq>,
Expand Down Expand Up @@ -14591,7 +14642,7 @@ inline const Entity<TSeq> & Agent<TSeq>::get_entity(size_t i) const
if (i >= n_entities)
throw std::range_error("Trying to get to an agent's entity outside of the range.");

return model->entities[entities[i]];
return model->get_entity(entities[i]);
}

template<typename TSeq>
Expand All @@ -14600,7 +14651,7 @@ inline Entity<TSeq> & Agent<TSeq>::get_entity(size_t i)
if (i >= n_entities)
throw std::range_error("Trying to get to an agent's entity outside of the range.");

return model->entities[entities[i]];
return model->get_entity(entities[i]);
}

template<typename TSeq>
Expand Down
2 changes: 1 addition & 1 deletion include/epiworld/agent-events-meat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ inline void default_rm_entity(Event<TSeq> & a, Model<TSeq> * m)
p->entities_locations[p->n_entities];

Entity<TSeq> * last_entity =
&m->get_entities()[p->entities[p->n_entities]]; ///< Last entity of the agent
&m->get_entity(p->entities[p->n_entities]); ///< Last entity of the agent

// The end entity will be located where the removed was
last_entity->agents_location[agent_location_in_last_entity] =
Expand Down
10 changes: 5 additions & 5 deletions include/epiworld/agent-meat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,14 +332,14 @@ inline void Agent<TSeq>::rm_entity(
"There is entity to remove here!"
);

CHECK_COALESCE_(state_new, model->entities[entity_idx].state_post, state);
CHECK_COALESCE_(queue, model->entities[entity_idx].queue_post, Queue<TSeq>::NoOne);
CHECK_COALESCE_(state_new, model->get_entity(entity_idx).state_post, state);
CHECK_COALESCE_(queue, model->get_entity(entity_idx).queue_post, Queue<TSeq>::NoOne);

model->events_add(
this,
nullptr,
nullptr,
&model->entities[entities[entity_idx]],
&model->get_entity(entity_idx),
state_new,
queue,
default_rm_entity<TSeq>,
Expand Down Expand Up @@ -811,7 +811,7 @@ inline const Entity<TSeq> & Agent<TSeq>::get_entity(size_t i) const
if (i >= n_entities)
throw std::range_error("Trying to get to an agent's entity outside of the range.");

return model->entities[entities[i]];
return model->get_entity(entities[i]);
}

template<typename TSeq>
Expand All @@ -820,7 +820,7 @@ inline Entity<TSeq> & Agent<TSeq>::get_entity(size_t i)
if (i >= n_entities)
throw std::range_error("Trying to get to an agent's entity outside of the range.");

return model->entities[entities[i]];
return model->get_entity(entities[i]);
}

template<typename TSeq>
Expand Down
Loading

0 comments on commit d89f393

Please sign in to comment.