Skip to content

Commit

Permalink
Adding tests for the SEIR mixing model
Browse files Browse the repository at this point in the history
  • Loading branch information
gvegayon committed Apr 27, 2024
1 parent 18e11d3 commit 3b999b0
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 7 deletions.
2 changes: 1 addition & 1 deletion include/epiworld/agent-meat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ inline void Agent<TSeq>::rm_entity(
int entity_idx = -1;
for (size_t i = 0u; i < n_entities; ++i)
{
if (entities[i] == entity.get_id())
if (static_cast<int>(entities[i]) == entity.get_id())
{
entity_idx = i;
break;
Expand Down
45 changes: 40 additions & 5 deletions tests/05-mixing.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
// #define EPI_DEBUG
#include "../include/epiworld/epiworld.hpp"
#ifndef CATCH_CONFIG_MAIN
#define EPI_DEBUG
#endif

#include "tests.hpp"

using namespace epiworld;

template<typename TSeq = int>
EntityToAgentFun<TSeq> dist_factory(int from, int to) {
inline EntityToAgentFun<TSeq> dist_factory(int from, int to) {
return [from, to](Entity<TSeq> & e, Model<TSeq> * m) -> void {

auto & agents = m->get_agents();
Expand All @@ -19,7 +22,7 @@ EntityToAgentFun<TSeq> dist_factory(int from, int to) {
}

template<typename TSeq = int>
VirusToAgentFun<TSeq> dist_virus(int i)
inline VirusToAgentFun<TSeq> dist_virus(int i)
{
return [i](Virus<TSeq> & v, Model<TSeq> * m) -> void {

Expand All @@ -30,7 +33,7 @@ VirusToAgentFun<TSeq> dist_virus(int i)

}

int main() {
EPIWORLD_TEST_CASE("SEIRMixing", "[SEIR-dist]") {

std::vector< double > contact_matrix = {
1.0, 0.0, 0.0,
Expand Down Expand Up @@ -88,6 +91,10 @@ int main() {

}

#ifdef CATCH_CONFIG_MAIN
REQUIRE_FALSE((n_wrong != 0 | n_right != 3000));
#endif

// Reruning the model where individuals from group 0 transmit all to group 1
contact_matrix[0] = 0.0;
contact_matrix[6] = 1.0;
Expand Down Expand Up @@ -124,6 +131,34 @@ int main() {

}

#ifdef CATCH_CONFIG_MAIN
REQUIRE_FALSE((n_wrong != 0 | n_right != 3001));
#endif

// Rerunning with plain mixing
std::fill(contact_matrix.begin(), contact_matrix.end(), 1.0/3.0);
model.set_contact_matrix(contact_matrix);

// Running and checking the results
model.run(50, 123);
model.print();

std::vector< int > totals;
model.get_db().get_today_total(nullptr, &totals);

std::vector< int > expected_totals = {
0, 0, 0,
static_cast<int>(model.size())
};

#ifdef CATCH_CONFIG_MAIN
REQUIRE_THAT(totals, Catch::Equals(expected_totals));
#endif



#ifndef CATCH_CONFIG_MAIN
return 0;
#endif

}
3 changes: 2 additions & 1 deletion tests/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@
#include "01-sirconnected.cpp"
#include "02-reproducible-sir.cpp"
#include "02-reproducible-sirconn.cpp"
#include "04-initial-dist.cpp"
#include "04-initial-dist.cpp"
#include "05-mixing.cpp"

0 comments on commit 3b999b0

Please sign in to comment.