Skip to content

Commit

Permalink
Merge pull request #3291 from stan-dev/fix/init-err-msgs
Browse files Browse the repository at this point in the history
update error message for different init types
  • Loading branch information
WardBrian authored Nov 14, 2024
2 parents 4ff44b8 + 39b8333 commit 4dc20ff
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 18 deletions.
47 changes: 40 additions & 7 deletions src/stan/services/util/initialize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,19 @@ std::vector<double> initialize(Model& model, const InitContext& init, RNG& rng,
model.transform_inits(context, disc_vector, unconstrained, &msg);
}
} catch (std::domain_error& e) {
if (msg.str().length() > 0)
if (msg.str().length() > 0) {
logger.info(msg);
}
logger.warn("Rejecting initial value:");
logger.warn(
" Error evaluating the log probability"
" at the initial value.");
logger.warn(e.what());
continue;
} catch (std::exception& e) {
if (msg.str().length() > 0)
if (msg.str().length() > 0) {
logger.info(msg);
}
logger.error(
"Unrecoverable error evaluating the log probability"
" at the initial value.");
Expand All @@ -127,8 +129,9 @@ std::vector<double> initialize(Model& model, const InitContext& init, RNG& rng,
// the parameters.
log_prob = model.template log_prob<false, Jacobian>(unconstrained,
disc_vector, &msg);
if (msg.str().length() > 0)
if (msg.str().length() > 0) {
logger.info(msg);
}
} catch (std::domain_error& e) {
if (msg.str().length() > 0)
logger.info(msg);
Expand All @@ -139,8 +142,9 @@ std::vector<double> initialize(Model& model, const InitContext& init, RNG& rng,
logger.warn(e.what());
continue;
} catch (std::exception& e) {
if (msg.str().length() > 0)
if (msg.str().length() > 0) {
logger.info(msg);
}
logger.error(
"Unrecoverable error evaluating the log probability"
" at the initial value.");
Expand All @@ -165,8 +169,9 @@ std::vector<double> initialize(Model& model, const InitContext& init, RNG& rng,
log_prob = stan::model::log_prob_grad<true, Jacobian>(
model, unconstrained, disc_vector, gradient, &log_prob_msg);
} catch (const std::exception& e) {
if (log_prob_msg.str().length() > 0)
if (log_prob_msg.str().length() > 0) {
logger.info(log_prob_msg);
}
logger.error(e.what());
throw;
}
Expand Down Expand Up @@ -210,8 +215,36 @@ std::vector<double> initialize(Model& model, const InitContext& init, RNG& rng,
return unconstrained;
}
}

if (!is_initialized_with_zero) {
if (is_fully_initialized) {
logger.info("");
logger.error("User-specified initialization failed.");
logger.error(
" Try specifying new initial values,"
" using partially specialized initialization,"
" reducing the range of constrained values,"
" or reparameterizing the model.");
} else if (any_initialized) {
logger.info("");
std::stringstream msg;
msg << "Partial user-specified initialization failed. "
"Initialization of non user specified parameters "
"between (-"
<< init_radius << ", " << init_radius << ") failed after"
<< " " << MAX_INIT_TRIES << " attempts. ";
logger.error(msg);
logger.error(
" Try specifying full initial values,"
" reducing the range of constrained values,"
" or reparameterizing the model.");
} else if (is_initialized_with_zero) {
logger.info("");
logger.error("Initial values of 0 failed to initialize.");
logger.error(
" Try specifying new initial values,"
" using partially specialized initialization,"
" reducing the range of constrained values,"
" or reparameterizing the model.");
} else {
logger.info("");
std::stringstream msg;
msg << "Initialization between (-" << init_radius << ", " << init_radius
Expand Down
6 changes: 6 additions & 0 deletions src/test/test-models/good/services/test_fail.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
parameters {
array[2] real<lower=-10, upper=10> y;
}
model {
reject("");
}
17 changes: 17 additions & 0 deletions src/test/unit/services/instrumented_callbacks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,23 @@ class instrumented_logger : public stan::callbacks::logger {
return count;
}

public:
std::vector<std::string> return_all_logs() {
std::vector<std::string> all_logs;
all_logs.reserve(debug_.size() + info_.size() + warn_.size() + error_.size()
+ fatal_.size() + 5);
all_logs.emplace_back("DEBUG");
all_logs.insert(all_logs.end(), debug_.begin(), debug_.end());
all_logs.emplace_back("INFO");
all_logs.insert(all_logs.end(), info_.begin(), info_.end());
all_logs.emplace_back("WARN");
all_logs.insert(all_logs.end(), warn_.begin(), warn_.end());
all_logs.emplace_back("ERROR");
all_logs.insert(all_logs.end(), error_.begin(), error_.end());
all_logs.emplace_back("FATAL");
all_logs.insert(all_logs.end(), fatal_.begin(), fatal_.end());
return all_logs;
}
std::vector<std::string> debug_;
std::vector<std::string> info_;
std::vector<std::string> warn_;
Expand Down
58 changes: 58 additions & 0 deletions src/test/unit/services/util/fail_init_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#include <stan/services/util/initialize.hpp>
#include <stan/services/util/create_rng.hpp>
#include <stan/io/empty_var_context.hpp>
#include <stan/io/array_var_context.hpp>
#include <stan/services/util/create_rng.hpp>
#include <stan/callbacks/stream_writer.hpp>
#include <stan/callbacks/stream_logger.hpp>
#include <test/test-models/good/services/test_fail.hpp>
#include <test/unit/util.hpp>
#include <test/unit/services/instrumented_callbacks.hpp>
#include <gtest/gtest.h>
#include <sstream>

class ServicesUtilInitialize : public testing::Test {
public:
ServicesUtilInitialize()
: model(empty_context, 12345, &model_ss),
message(message_ss),
rng(stan::services::util::create_rng(0, 1)) {}

stan_model model;
stan::io::empty_var_context empty_context;
std::stringstream model_ss;
std::stringstream message_ss;
stan::callbacks::stream_writer message;
stan::test::unit::instrumented_logger logger;
stan::test::unit::instrumented_writer init;
stan::rng_t rng;
};

TEST_F(ServicesUtilInitialize, model_throws__full_init) {
std::vector<std::string> names_r;
std::vector<double> values_r;
std::vector<std::vector<size_t> > dim_r;
names_r.push_back("y");
values_r.push_back(6.35149); // 1.5 unconstrained: -10 + 20 * inv.logit(1.5)
values_r.push_back(-2.449187); // -0.5 unconstrained
std::vector<size_t> d;
d.push_back(2);
dim_r.push_back(d);
stan::io::array_var_context init_context(names_r, values_r, dim_r);

double init_radius = 2;
bool print_timing = false;
EXPECT_THROW(
stan::services::util::initialize(model, init_context, rng, init_radius,
print_timing, logger, init),
std::domain_error);
/* Uncomment to print all logs
auto logs = logger.return_all_logs();
for (auto&& m : logs) {
std::cout << m << std::endl;
}
*/
EXPECT_EQ(6, logger.call_count());
EXPECT_EQ(3, logger.call_count_warn());
EXPECT_EQ(0, logger.find_warn("throwing within log_prob"));
}
21 changes: 10 additions & 11 deletions src/test/unit/services/util/initialize_test.cpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
#include <stan/services/util/initialize.hpp>
#include <stan/services/util/create_rng.hpp>
#include <gtest/gtest.h>
#include <test/unit/util.hpp>
#include <stan/callbacks/stream_writer.hpp>
#include <stan/callbacks/stream_logger.hpp>
#include <sstream>
#include <test/test-models/good/services/test_lp.hpp>
#include <stan/io/empty_var_context.hpp>
#include <stan/io/array_var_context.hpp>
#include <stan/services/util/create_rng.hpp>
#include <stan/callbacks/stream_writer.hpp>
#include <stan/callbacks/stream_logger.hpp>
#include <test/test-models/good/services/test_lp.hpp>
#include <test/unit/util.hpp>
#include <test/unit/services/instrumented_callbacks.hpp>
#include <gtest/gtest.h>
#include <sstream>

class ServicesUtilInitialize : public testing::Test {
public:
Expand All @@ -28,7 +28,7 @@ class ServicesUtilInitialize : public testing::Test {
stan::rng_t rng;
};

TEST_F(ServicesUtilInitialize, radius_zero__print_false) {
TEST_F(ServicesUtilInitialize, radius_zero_print_false) {
std::vector<double> params;

double init_radius = 0;
Expand Down Expand Up @@ -250,7 +250,7 @@ class mock_throwing_model : public stan::model::prob_grad {

} // namespace test

TEST_F(ServicesUtilInitialize, model_throws__radius_zero) {
TEST_F(ServicesUtilInitialize, model_throws_radius_zero) {
test::mock_throwing_model throwing_model;

double init_radius = 0;
Expand All @@ -259,8 +259,7 @@ TEST_F(ServicesUtilInitialize, model_throws__radius_zero) {
stan::services::util::initialize(throwing_model, empty_context, rng,
init_radius, print_timing, logger, init),
std::domain_error);

EXPECT_EQ(3, logger.call_count());
EXPECT_EQ(6, logger.call_count());
EXPECT_EQ(3, logger.call_count_warn());
EXPECT_EQ(1, logger.find_warn("throwing within log_prob"));
}
Expand Down Expand Up @@ -533,7 +532,7 @@ TEST_F(ServicesUtilInitialize, model_throws_in_write_array__radius_zero) {
init_radius, print_timing, logger, init),
std::domain_error);

EXPECT_EQ(3, logger.call_count());
EXPECT_EQ(6, logger.call_count());
EXPECT_EQ(3, logger.call_count_warn());
EXPECT_EQ(1, logger.find_warn("throwing within write_array"));
}
Expand Down

0 comments on commit 4dc20ff

Please sign in to comment.