From 9d83ca7bb22fc22dfb81a2b73b667c6179d22854 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Thu, 30 Mar 2023 10:05:35 -0400 Subject: [PATCH] feat: install sigint handler for run_cli_driver --- src/cpp/main.cpp | 39 ++++++++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/src/cpp/main.cpp b/src/cpp/main.cpp index 5873c44..45611fd 100644 --- a/src/cpp/main.cpp +++ b/src/cpp/main.cpp @@ -44,6 +44,7 @@ #include #include +#include #include #include #include @@ -592,11 +593,24 @@ void py_unsetup_example(VW::workspace& ws, std::vector& ex) for (auto& example : ex) { py_unsetup_example(ws, *example); } } +// Because of the GIL we can use globals here. +static bool SIGINT_CALLED = false; +static VW::workspace* CLI_DRIVER_WORKSPACE = nullptr; + // return type is an optional error information (nullopt if success), driver output, list of log messages // stdin is not supported std::tuple, std::string, std::vector> run_cli_driver( const std::vector& args, bool onethread) { + SIGINT_CALLED = false; + CLI_DRIVER_WORKSPACE = nullptr; + std::signal(SIGINT, + [](int) + { + if (CLI_DRIVER_WORKSPACE != nullptr) { VW::details::set_done(*CLI_DRIVER_WORKSPACE); } + SIGINT_CALLED = true; + }); + auto args_copy = args; args_copy.push_back("--no_stdin"); auto options = VW::make_unique(args_copy); @@ -620,18 +634,23 @@ std::tuple, std::string, std::vector> ru { auto all = VW::initialize_experimental(std::move(options), nullptr, driver_logger, &driver_log, &logger); all->vw_is_main = true; + CLI_DRIVER_WORKSPACE = all.get(); - if (onethread) { VW::LEARNER::generic_driver_onethread(*all); } - else + // If sigint was called before we got here, we should avoid running the driver. + if (!SIGINT_CALLED) { - VW::start_parser(*all); - VW::LEARNER::generic_driver(*all); - VW::end_parser(*all); + if (onethread) { VW::LEARNER::generic_driver_onethread(*all); } + else + { + VW::start_parser(*all); + VW::LEARNER::generic_driver(*all); + VW::end_parser(*all); + } + + if (all->example_parser->exc_ptr) { std::rethrow_exception(all->example_parser->exc_ptr); } + VW::sync_stats(*all); + all->finish(); } - - if (all->example_parser->exc_ptr) { std::rethrow_exception(all->example_parser->exc_ptr); } - VW::sync_stats(*all); - all->finish(); } catch (const std::exception& ex) { @@ -642,6 +661,8 @@ std::tuple, std::string, std::vector> ru return std::make_tuple("Unknown exception occurred", driver_log.str(), log_log); } + SIGINT_CALLED = false; + CLI_DRIVER_WORKSPACE = nullptr; return std::make_tuple(std::nullopt, driver_log.str(), log_log); }