diff --git a/R/pool.r b/R/pool.r index 06d72fc..2a0284e 100644 --- a/R/pool.r +++ b/R/pool.r @@ -34,6 +34,7 @@ Pool = R6::R6Class("Pool", add = function(qsys, n, ...) { self$workers = qsys$new(addr=private$addr, master=private$master, n_jobs=n, ...) + private$master$add_pending_workers(n) }, env = function(...) { @@ -127,7 +128,10 @@ Pool = R6::R6Class("Pool", ), active = list( - workers_total = function() self$workers$n(), + workers_total = function() { + ls_w = private$master$list_workers() + length(ls_w$worker) + ls_w$pending + }, workers_running = function() length(private$master$list_workers()$worker), reusable = function() private$reuse ), diff --git a/src/CMQMaster.cpp b/src/CMQMaster.cpp index cea68d9..72ff2b9 100644 --- a/src/CMQMaster.cpp +++ b/src/CMQMaster.cpp @@ -15,6 +15,7 @@ RCPP_MODULE(cmq_master) { .method("add_env", &CMQMaster::add_env) .method("add_pkg", &CMQMaster::add_pkg) .method("list_env", &CMQMaster::list_env) + .method("add_pending_workers", &CMQMaster::add_pending_workers) .method("list_workers", &CMQMaster::list_workers) ; } diff --git a/src/CMQMaster.h b/src/CMQMaster.h index 7bd14f7..e47914d 100644 --- a/src/CMQMaster.h +++ b/src/CMQMaster.h @@ -44,8 +44,8 @@ class CMQMaster { } SEXP recv(int timeout=-1) { -// if (peers.size() == 0) -// Rf_error("Trying to receive data without workers"); + if (peers.size() + pending_workers <= 0) + Rf_error("Trying to receive data without workers"); int data_offset; std::vector msgs; @@ -156,6 +156,10 @@ class CMQMaster { Rcpp::_["size"] = Rcpp::wrap(sizes)); } + void add_pending_workers(int n) { + pending_workers += n; + } + Rcpp::List list_workers() { std::vector names; names.reserve(peers.size()); @@ -172,7 +176,8 @@ class CMQMaster { Rcpp::_["worker"] = Rcpp::wrap(names), Rcpp::_["status"] = Rcpp::wrap(status), Rcpp::_["time"] = wtime, - Rcpp::_["mem"] = mem + Rcpp::_["mem"] = mem, + Rcpp::_["pending"] = pending_workers ); } @@ -188,6 +193,7 @@ class CMQMaster { zmq::context_t *ctx {nullptr}; int has_proxy {0}; + int pending_workers {0}; zmq::socket_t sock; std::string cur; std::unordered_map peers; @@ -238,7 +244,9 @@ class CMQMaster { ++cur_i; cur = msgs[cur_i].to_string(); + int prev_size = peers.size(); auto &w = peers[cur]; + pending_workers -= peers.size() - prev_size; w.call = R_NilValue; if (cur_i == 1) w.via = msgs[0].to_string(); diff --git a/tests/testthat/test-2-worker.r b/tests/testthat/test-2-worker.r index 2f425b7..081239a 100644 --- a/tests/testthat/test-2-worker.r +++ b/tests/testthat/test-2-worker.r @@ -3,6 +3,7 @@ context("worker usage") test_that("timeouts are triggered correctly", { m = methods::new(CMQMaster) addr = m$listen("inproc://endpoint") + m$add_pending_workers(1L) expect_error(m$recv(0L)) m$close(0L) @@ -15,6 +16,7 @@ test_that("worker evaluation", { m = methods::new(CMQMaster) w = methods::new(CMQWorker, m$context()) addr = m$listen("inproc://endpoint") + m$add_pending_workers(1L) w$connect(addr, 0L) m$recv(0L) @@ -33,6 +35,7 @@ test_that("export variable to worker", { m = methods::new(CMQMaster) w = methods::new(CMQWorker, m$context()) addr = m$listen("inproc://endpoint") + m$add_pending_workers(1L) w$connect(addr, 0L) m$add_env("x", 3) @@ -58,6 +61,7 @@ test_that("load package on worker", { m = methods::new(CMQMaster) w = methods::new(CMQWorker, m$context()) addr = m$listen("inproc://endpoint") + m$add_pending_workers(1L) w$connect(addr, 0L) m$add_pkg("parallel") @@ -80,6 +84,7 @@ test_that("errors are sent back to master", { m = methods::new(CMQMaster) w = methods::new(CMQWorker, m$context()) addr = m$listen("inproc://endpoint") + m$add_pending_workers(1L) w$connect(addr, 0L) m$recv(0L) @@ -100,6 +105,7 @@ test_that("worker R API", { m = methods::new(CMQMaster) addr = m$listen("tcp://127.0.0.1:*") + m$add_pending_workers(1L) # addr = m$listen("inproc://endpoint") # mailbox.cpp assertion error p = parallel::mcparallel(worker(addr)) @@ -120,6 +126,7 @@ test_that("communication with two workers", { m = methods::new(CMQMaster) addr = m$listen("tcp://127.0.0.1:*") + m$add_pending_workers(2L) w1 = parallel::mcparallel(worker(addr)) w2 = parallel::mcparallel(worker(addr))