Skip to content

Commit

Permalink
Collapsing prop info to dist fun
Browse files Browse the repository at this point in the history
  • Loading branch information
gvegayon committed Jun 12, 2024
1 parent 68f8988 commit cb65638
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 204 deletions.
118 changes: 28 additions & 90 deletions epiworld.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12169,8 +12169,6 @@ class Entity {
epiworld_fast_int queue_init = 0; ///< Change of state when added to agent.
epiworld_fast_int queue_post = 0; ///< Change of state when removed from agent.

epiworld_double prevalence = 0.0;
bool prevalence_as_proportion = false;
EntityToAgentFun<TSeq> dist_fun = nullptr;

public:
Expand All @@ -12182,19 +12180,13 @@ class Entity {
* This constructor initializes an Entity object with the specified parameters.
*
* @param name The name of the entity.
* @param preval The prevalence of the entity.
* @param as_proportion A flag indicating whether the prevalence is given as a proportion.
* @param fun A function pointer to a function that maps the entity to an agent.
*/
Entity(
std::string name,
epiworld_double preval,
bool as_proportion,
EntityToAgentFun<TSeq> fun = nullptr
) :
entity_name(name),
prevalence(preval),
prevalence_as_proportion(as_proportion),
dist_fun(fun)
{};

Expand Down Expand Up @@ -12238,10 +12230,6 @@ class Entity {
std::vector< size_t > & get_agents();

void print() const;

void set_prevalence(epiworld_double p, bool as_proportion);
epiworld_double get_prevalence() const noexcept;
bool get_prevalence_as_proportion() const noexcept;
void set_dist_fun(EntityToAgentFun<TSeq> fun);

};
Expand Down Expand Up @@ -12274,33 +12262,54 @@ template <typename TSeq = EPI_DEFAULT_TSEQ>
/**
* Distributes an entity to unassigned agents in the model.
*
* @param prevalence The proportion of agents to distribute the entity to.
* @param as_proportion Flag indicating whether the prevalence is a proportion
* @param to_unassigned Flag indicating whether to distribute the entity only
* to unassigned agents.
* @return An EntityToAgentFun object that distributes the entity to unassigned
* agents.
*/
inline EntityToAgentFun<TSeq> distribute_entity_to_unassigned()
inline EntityToAgentFun<TSeq> distribute_entity_randomly(
epiworld_double prevalence,
bool as_proportion,
bool to_unassigned
)
{

return [](Entity<TSeq> & e, Model<TSeq> * m) -> void {
return [prevalence, as_proportion, to_unassigned](
Entity<TSeq> & e, Model<TSeq> * m
) -> void {


// Preparing the sampling space
std::vector< size_t > idx;
for (const auto & a: m->get_agents())
if (a.get_n_entities() == 0)
if (to_unassigned)
{
for (const auto & a: m->get_agents())
if (a.get_n_entities() == 0)
idx.push_back(a.get_id());
}
else
{

for (const auto & a: m->get_agents())
idx.push_back(a.get_id());

}

size_t n = idx.size();

// Figuring out how many to sample
int n_to_sample;
if (e.get_prevalence_as_proportion())
if (as_proportion)
{
n_to_sample = static_cast<int>(std::floor(e.get_prevalence() * n));
n_to_sample = static_cast<int>(std::floor(prevalence * n));
if (n_to_sample > static_cast<int>(n))
--n_to_sample;

} else
{
n_to_sample = static_cast<int>(e.get_prevalence());
n_to_sample = static_cast<int>(prevalence);
if (n_to_sample > static_cast<int>(n))
throw std::range_error("There are only " + std::to_string(n) +
" individuals in the population. Cannot add the entity to " +
Expand Down Expand Up @@ -12652,61 +12661,12 @@ template<typename TSeq>
inline void Entity<TSeq>::distribute()
{

// Starting first infection
int n = this->model->size();
std::vector< size_t > idx(n);

if (dist_fun)
{

dist_fun(*this, model);

}
else
{

// Picking how many
int n_to_assign;
if (prevalence_as_proportion)
{
n_to_assign = static_cast<int>(std::floor(prevalence * size()));
}
else
{
n_to_assign = static_cast<int>(prevalence);
}

if (n_to_assign > static_cast<int>(model->size()))
throw std::range_error("There are only " + std::to_string(model->size()) +
" individuals in the population. Cannot add the entity to " + std::to_string(n_to_assign));

int n_left = n;
std::iota(idx.begin(), idx.end(), 0);
while ((n_to_assign > 0) && (n_left > 0))
{
int loc = static_cast<epiworld_fast_uint>(
floor(model->runif() * n_left--)
);

// Correcting for possible overflow
if ((loc > 0) && (loc >= n_left))
loc = n_left - 1;

auto & agent = model->get_agent(idx[loc]);

if (!agent.has_entity(id))
{
agent.add_entity(
*this, this->model, this->state_init, this->queue_init
);
n_to_assign--;
}

std::swap(idx[loc], idx[n_left]);

}

}

}

Expand All @@ -12728,28 +12688,6 @@ inline void Entity<TSeq>::print() const
);
}

template<typename TSeq>
inline void Entity<TSeq>::set_prevalence(
epiworld_double p,
bool as_proportion
)
{
prevalence = p;
prevalence_as_proportion = as_proportion;
}

template<typename TSeq>
inline epiworld_double Entity<TSeq>::get_prevalence() const noexcept
{
return prevalence;
}

template<typename TSeq>
inline bool Entity<TSeq>::get_prevalence_as_proportion() const noexcept
{
return prevalence_as_proportion;
}

template<typename TSeq>
inline void Entity<TSeq>::set_dist_fun(EntityToAgentFun<TSeq> fun)
{
Expand Down
6 changes: 3 additions & 3 deletions examples/11-entities/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ int main() {
);

// Creating three groups
Entity<> e1("Entity 1", 0.0, false, dist_factory<>(0, 3000));
Entity<> e2("Entity 2", 0.0, false, dist_factory<>(3000, 6000));
Entity<> e3("Entity 3", 0.0, false, dist_factory<>(6000, 10000));
Entity<> e1("Entity 1", dist_factory<>(0, 3000));
Entity<> e2("Entity 2", dist_factory<>(3000, 6000));
Entity<> e3("Entity 3", dist_factory<>(6000, 10000));

model.add_entity(e1);
model.add_entity(e2);
Expand Down
12 changes: 0 additions & 12 deletions include/epiworld/entity-bones.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ class Entity {
epiworld_fast_int queue_init = 0; ///< Change of state when added to agent.
epiworld_fast_int queue_post = 0; ///< Change of state when removed from agent.

epiworld_double prevalence = 0.0;
bool prevalence_as_proportion = false;
EntityToAgentFun<TSeq> dist_fun = nullptr;

public:
Expand All @@ -72,19 +70,13 @@ class Entity {
* This constructor initializes an Entity object with the specified parameters.
*
* @param name The name of the entity.
* @param preval The prevalence of the entity.
* @param as_proportion A flag indicating whether the prevalence is given as a proportion.
* @param fun A function pointer to a function that maps the entity to an agent.
*/
Entity(
std::string name,
epiworld_double preval,
bool as_proportion,
EntityToAgentFun<TSeq> fun = nullptr
) :
entity_name(name),
prevalence(preval),
prevalence_as_proportion(as_proportion),
dist_fun(fun)
{};

Expand Down Expand Up @@ -128,10 +120,6 @@ class Entity {
std::vector< size_t > & get_agents();

void print() const;

void set_prevalence(epiworld_double p, bool as_proportion);
epiworld_double get_prevalence() const noexcept;
bool get_prevalence_as_proportion() const noexcept;
void set_dist_fun(EntityToAgentFun<TSeq> fun);

};
Expand Down
35 changes: 28 additions & 7 deletions include/epiworld/entity-distribute-meat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,54 @@ template <typename TSeq = EPI_DEFAULT_TSEQ>
/**
* Distributes an entity to unassigned agents in the model.
*
* @param prevalence The proportion of agents to distribute the entity to.
* @param as_proportion Flag indicating whether the prevalence is a proportion
* @param to_unassigned Flag indicating whether to distribute the entity only
* to unassigned agents.
* @return An EntityToAgentFun object that distributes the entity to unassigned
* agents.
*/
inline EntityToAgentFun<TSeq> distribute_entity_to_unassigned()
inline EntityToAgentFun<TSeq> distribute_entity_randomly(
epiworld_double prevalence,
bool as_proportion,
bool to_unassigned
)
{

return [](Entity<TSeq> & e, Model<TSeq> * m) -> void {
return [prevalence, as_proportion, to_unassigned](
Entity<TSeq> & e, Model<TSeq> * m
) -> void {


// Preparing the sampling space
std::vector< size_t > idx;
for (const auto & a: m->get_agents())
if (a.get_n_entities() == 0)
if (to_unassigned)
{
for (const auto & a: m->get_agents())
if (a.get_n_entities() == 0)
idx.push_back(a.get_id());
}
else
{

for (const auto & a: m->get_agents())
idx.push_back(a.get_id());

}

size_t n = idx.size();

// Figuring out how many to sample
int n_to_sample;
if (e.get_prevalence_as_proportion())
if (as_proportion)
{
n_to_sample = static_cast<int>(std::floor(e.get_prevalence() * n));
n_to_sample = static_cast<int>(std::floor(prevalence * n));
if (n_to_sample > static_cast<int>(n))
--n_to_sample;

} else
{
n_to_sample = static_cast<int>(e.get_prevalence());
n_to_sample = static_cast<int>(prevalence);
if (n_to_sample > static_cast<int>(n))
throw std::range_error("There are only " + std::to_string(n) +
" individuals in the population. Cannot add the entity to " +
Expand Down
71 changes: 0 additions & 71 deletions include/epiworld/entity-meat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,61 +230,12 @@ template<typename TSeq>
inline void Entity<TSeq>::distribute()
{

// Starting first infection
int n = this->model->size();
std::vector< size_t > idx(n);

if (dist_fun)
{

dist_fun(*this, model);

}
else
{

// Picking how many
int n_to_assign;
if (prevalence_as_proportion)
{
n_to_assign = static_cast<int>(std::floor(prevalence * size()));
}
else
{
n_to_assign = static_cast<int>(prevalence);
}

if (n_to_assign > static_cast<int>(model->size()))
throw std::range_error("There are only " + std::to_string(model->size()) +
" individuals in the population. Cannot add the entity to " + std::to_string(n_to_assign));

int n_left = n;
std::iota(idx.begin(), idx.end(), 0);
while ((n_to_assign > 0) && (n_left > 0))
{
int loc = static_cast<epiworld_fast_uint>(
floor(model->runif() * n_left--)
);

// Correcting for possible overflow
if ((loc > 0) && (loc >= n_left))
loc = n_left - 1;

auto & agent = model->get_agent(idx[loc]);

if (!agent.has_entity(id))
{
agent.add_entity(
*this, this->model, this->state_init, this->queue_init
);
n_to_assign--;
}

std::swap(idx[loc], idx[n_left]);

}

}

}

Expand All @@ -306,28 +257,6 @@ inline void Entity<TSeq>::print() const
);
}

template<typename TSeq>
inline void Entity<TSeq>::set_prevalence(
epiworld_double p,
bool as_proportion
)
{
prevalence = p;
prevalence_as_proportion = as_proportion;
}

template<typename TSeq>
inline epiworld_double Entity<TSeq>::get_prevalence() const noexcept
{
return prevalence;
}

template<typename TSeq>
inline bool Entity<TSeq>::get_prevalence_as_proportion() const noexcept
{
return prevalence_as_proportion;
}

template<typename TSeq>
inline void Entity<TSeq>::set_dist_fun(EntityToAgentFun<TSeq> fun)
{
Expand Down
Loading

0 comments on commit cb65638

Please sign in to comment.