Skip to content

Commit

Permalink
[bind] bind specified in #48
Browse files Browse the repository at this point in the history
  • Loading branch information
IsaccBarker committed Oct 30, 2024
1 parent e6b0d9b commit 4d730b7
Showing 1 changed file with 150 additions and 3 deletions.
153 changes: 150 additions & 3 deletions src/database.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "database.hpp"
#include "config.hpp"

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
Expand Down Expand Up @@ -160,17 +161,163 @@ static py::dict get_hist_transition_matrix(DataBase<int> &self,
return ret;
}

static py::dict get_hist_virus(DataBase<int> &self) {
std::vector<std::string> *states = new std::vector<std::string>();
std::vector<int> *dates = new std::vector<int>();
std::vector<int> *ids = new std::vector<int>();
std::vector<int> *counts = new std::vector<int>();

self.get_hist_virus(*dates, *ids, *states, *counts);

/* Return to Python. */
py::capsule pyc_dates(
dates, [](void *x) { delete reinterpret_cast<std::vector<int> *>(x); });
py::capsule pyc_ids(
ids, [](void *x) { delete reinterpret_cast<std::vector<int> *>(x); });
py::capsule pyc_counts(
counts, [](void *x) { delete reinterpret_cast<std::vector<int> *>(x); });

py::array py_dates(dates->size(), dates->data(), pyc_dates);
py::array py_ids(ids->size(), ids->data(), pyc_ids);
py::array py_counts(counts->size(), counts->data(), pyc_counts);

py::dict ret("dates"_a = dates, "ids"_a = py_ids, "states"_a = states,
"counts"_a = py_counts);

return ret;
}

static py::dict get_hist_tool(DataBase<int> &self) {
std::vector<std::string> *states = new std::vector<std::string>();
std::vector<int> *dates = new std::vector<int>();
std::vector<int> *ids = new std::vector<int>();
std::vector<int> *counts = new std::vector<int>();

self.get_hist_tool(*dates, *ids, *states, *counts);

/* Return to Python. */
py::capsule pyc_dates(
dates, [](void *x) { delete reinterpret_cast<std::vector<int> *>(x); });
py::capsule pyc_ids(
ids, [](void *x) { delete reinterpret_cast<std::vector<int> *>(x); });
py::capsule pyc_counts(
counts, [](void *x) { delete reinterpret_cast<std::vector<int> *>(x); });

py::array py_dates(dates->size(), dates->data(), pyc_dates);
py::array py_ids(ids->size(), ids->data(), pyc_ids);
py::array py_counts(counts->size(), counts->data(), pyc_counts);

py::dict ret("dates"_a = dates, "ids"_a = py_ids, "states"_a = states,
"counts"_a = py_counts);

return ret;
}

static py::dict get_today_transition_matrix(DataBase<int> &self) {
std::vector<int> *counts = new std::vector<int>();

self.get_today_transition_matrix(*counts);

/* Return to Python. */
py::capsule pyc_counts(
counts, [](void *x) { delete reinterpret_cast<std::vector<int> *>(x); });

py::array py_counts(counts->size(), counts->data(), pyc_counts);

py::dict ret("counts"_a = py_counts);

return ret;
}

static py::dict get_today_virus(DataBase<int> &self) {
std::vector<std::string> *states = new std::vector<std::string>();
std::vector<int> *ids = new std::vector<int>();
std::vector<int> *counts = new std::vector<int>();

self.get_today_virus(*states, *ids, *counts);

/* Return to Python. */
py::capsule pyc_ids(
ids, [](void *x) { delete reinterpret_cast<std::vector<int> *>(x); });
py::capsule pyc_counts(
counts, [](void *x) { delete reinterpret_cast<std::vector<int> *>(x); });

py::array py_ids(ids->size(), ids->data(), pyc_ids);
py::array py_counts(counts->size(), counts->data(), pyc_counts);

py::dict ret("states"_a = states, "ids"_a = py_ids, "counts"_a = py_counts);

return ret;
}

static py::dict get_today_total(DataBase<int> &self) {
std::vector<std::string> *states = new std::vector<std::string>();
std::vector<int> *counts = new std::vector<int>();

self.get_today_total(states, counts);

/* Return to Python. */
py::capsule pyc_counts(
counts, [](void *x) { delete reinterpret_cast<std::vector<int> *>(x); });

py::array py_counts(counts->size(), counts->data(), pyc_counts);

py::dict ret("states"_a = states, "counts"_a = py_counts);

return ret;
}

void epiworldpy::export_database(
py::class_<epiworld::DataBase<int>,
std::shared_ptr<epiworld::DataBase<int>>> &c) {
c.def("get_hist_total", &get_hist_total,
"Get historical totals for this model run.")
c.def("add_user_data",
pybind11::detail::overload_cast_impl<std::vector<epiworld_double>>()(
&epiworld::DataBase<int>::add_user_data),
"Add a list of user data.")
.def("add_user_data",
pybind11::detail::overload_cast_impl<epiworld_fast_uint,
epiworld_double>()(
&epiworld::DataBase<int>::add_user_data),
"Add a list of user data.")
.def("get_n_tools", &epiworld::DataBase<int>::get_n_tools,
"Get the number of tools.")
.def("get_n_viruses", &epiworld::DataBase<int>::get_n_viruses,
"Get the number of viruses.")
.def("record_transmission", &epiworld::DataBase<int>::record_transmission,
"Record a transmission event.")
.def("write_data", &epiworld::DataBase<int>::write_data,
"Write some data.")
.def("get_hist_virus", &get_hist_virus, "Get historical virus data.")
.def("get_hist_tool", &get_hist_tool, "Get historical tool data.")
.def("get_today_transition_matrix", &get_today_transition_matrix,
"Get today's transition matrix.")
.def("get_today_virus", &get_today_virus, "Get today's virus data.")
//.def("get_today_total", pybind11::detail::overload_cast_impl<epiworld_fast_uint>()(&epiworld::DataBase<int>::get_today_total), "Get today's total data.")
.def("get_today_total", &get_today_total, "Get today's total data.")
.def("size", &epiworld::DataBase<int>::size,
"Get the size (number of entries).")
.def("record", &epiworld::DataBase<int>::record,
"Register a new variant.") // ?
.def("reset", &epiworld::DataBase<int>::reset, "Reset the database.") // ?
/* FIXME: This should work, as PyBind should automatically convert the
* lambda into a py::cpp_function, but I should double check this. */
.def("set_seq_hasher", &epiworld::DataBase<int>::set_seq_hasher,
"Set the sequence hashing function.")
.def("record_tool", &epiworld::DataBase<int>::record_tool,
"Add a new tool to the database.")
.def("record_virus", &epiworld::DataBase<int>::record_virus,
"Add a new virus to the database.")
.def("get_hist_total", &get_hist_total,
"Get historical totals for this model run.")
.def("get_reproductive_number", &get_reproductive_number,
"Get reproductive numbers over time for every virus in the model.")
.def("get_transmissions", &get_transmissions,
"Get transmission data over time for every virus in the model.")
.def("get_generation_time", &get_generation_time,
"Get generation times over time for every virus in the model.")
.def("get_hist_transition_matrix", &get_hist_transition_matrix,
"Get historical transitions in a tabular format.");
"Get historical transitions in a tabular format.")
.def("get_transition_probability",
&epiworld::DataBase<int>::transition_probability,
"Get the transition probably.");
}

0 comments on commit 4d730b7

Please sign in to comment.