From 73bdb1f7a9612410a55db67879db417ad2b64ac7 Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Tue, 19 Nov 2019 22:49:25 +0000 Subject: [PATCH 01/62] Return exit code 15 (SIGTERM) after SIGTERM. When marian receives signal SIGTERM and exits gracefully (save model & exit), it should then exit with a non-zero exit code, to signal to any parent process that it did not exit "naturally". --- src/command/marian_train.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/command/marian_train.cpp b/src/command/marian_train.cpp index 5d953243b..bddda5c74 100644 --- a/src/command/marian_train.cpp +++ b/src/command/marian_train.cpp @@ -68,5 +68,5 @@ int main(int argc, char** argv) { } } - return 0; + return getSigtermFlag() ? 15 : 0; } From 653b13d687cf0ece3823b11fd07accf1574edbc4 Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Fri, 22 Nov 2019 22:59:38 +0000 Subject: [PATCH 02/62] Added explanatory comment about exiting marian_train with non-zero status after SIGTERM. --- src/command/marian_train.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/command/marian_train.cpp b/src/command/marian_train.cpp index bddda5c74..9d312bbb8 100644 --- a/src/command/marian_train.cpp +++ b/src/command/marian_train.cpp @@ -1,3 +1,4 @@ +#include #include "marian.h" #include "training/graph_group_async.h" @@ -68,5 +69,13 @@ int main(int argc, char** argv) { } } - return getSigtermFlag() ? 15 : 0; + // If we exit due to SIGTERM, exit with 128 + the signal number, as suggested + // for bash in http://tldp.org/LDP/abs/html/exitcodes.html. This allows parent + // scripts to determine if training terminated naturally or via SIGTERM. + // Whith this approach we can accommodate additional signals in the future. 
+ // An alternative would be to return 124, which is what the timeout command + // returns for timeout -s SIGTERM ...., because exiting after SIGTERM + // is not technically a fatal error (which is what the 128+x convention usually + // stands for). + return getSigtermFlag() ? (128 + SIGTERM) : 0; } From 2586af7c1628826a806c4d61a47c0fbc8bd0f599 Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Mon, 6 Apr 2020 12:49:23 +0100 Subject: [PATCH 03/62] Bug fix: better handling of SIGTERM for graceful shutdown during training. Prior to this bug fix, BatchGenerator::fetchBatches, which runs in a separate thread, would ignore SIGTERM during training (training uses a custom signal handler for SIGTERM, which simply sets a global flag, to enable graceful shutdown (i.e., save models and current state of training before shutting down). The changes in this commit also facilitate custom handling of other signals in the future by providing a general singal handler for all signals with a signal number below 32 (setSignalFlag) and a generic flag checking function (getSignalFlag(sig)) for checking such flags. --- CHANGELOG.md | 2 ++ src/CMakeLists.txt | 6 ++--- src/command/marian_train.cpp | 3 ++- src/common/signal_handling.cpp | 21 +++++++++++++++++ src/common/signal_handling.h | 27 +++++++++++++++++++++ src/data/batch_generator.h | 23 +++++++++++++----- src/training/scheduler.cpp | 43 ---------------------------------- src/training/scheduler.h | 11 ++++----- src/training/training.h | 6 +++++ 9 files changed, 82 insertions(+), 60 deletions(-) create mode 100644 src/common/signal_handling.cpp create mode 100644 src/common/signal_handling.h delete mode 100644 src/training/scheduler.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index b01d797e0..3e62a9c00 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. 
### Changed - Make cublas and cusparse handle inits lazy to save memory when unused +- Improved handling for graceful shutdown upon receiving SIGTERM. + SIGTERM now also interrupts batch prefetching, which runs in a separate thread. ## [1.9.0] - 2020-03-10 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6d5a0b1f2..90e710717 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -24,8 +24,9 @@ add_library(marian STATIC common/io.cpp common/filesystem.cpp common/file_stream.cpp + common/signal_handling.cpp common/types.cpp - + data/alignment.cpp data/vocab.cpp data/default_vocab.cpp @@ -92,7 +93,6 @@ add_library(marian STATIC training/graph_group_multinode_sync.cpp training/validator.cpp training/communicator.cpp - training/scheduler.cpp # this is only compiled to catch build errors, but not linked microsoft/quicksand.cpp @@ -139,7 +139,7 @@ cuda_add_library(marian_cuda tensors/gpu/algorithm.cu tensors/gpu/prod.cpp tensors/gpu/element.cu - tensors/gpu/add.cu + tensors/gpu/add.cu tensors/gpu/add_all.cu tensors/gpu/tensor_operators.cu tensors/gpu/cudnn_wrappers.cu diff --git a/src/command/marian_train.cpp b/src/command/marian_train.cpp index c5aa8d298..f6c245f8d 100644 --- a/src/command/marian_train.cpp +++ b/src/command/marian_train.cpp @@ -1,6 +1,7 @@ #include #include "marian.h" +#include "common/signal_handling.h" #include "training/graph_group_async.h" #include "training/graph_group_multinode_sync.h" #include "training/graph_group_singleton.h" @@ -77,5 +78,5 @@ int main(int argc, char** argv) { // returns for timeout -s SIGTERM ...., because exiting after SIGTERM // is not technically a fatal error (which is what the 128+x convention usually // stands for). - return getSigtermFlag() ? (128 + SIGTERM) : 0; + return getSignalFlag(SIGTERM) ? 
(128 + SIGTERM) : 0; } diff --git a/src/common/signal_handling.cpp b/src/common/signal_handling.cpp new file mode 100644 index 000000000..a18d1e669 --- /dev/null +++ b/src/common/signal_handling.cpp @@ -0,0 +1,21 @@ +#include "common/logging.h" +#include "signal_handling.h" + +// We use signal() here instead of the usual strong recommendation for +// using sigaction, which apparently is not available for Windows (cf. +// https://stackoverflow.com/questions/231912/what-is-the-difference-between-sigaction-and-signal). + +namespace marian{ +volatile std::sig_atomic_t sigflags_{0}; + +bool getSignalFlag(const int sig) { + // sig_atomic_t has 32 bits. We don't accommodate signals beyond that. + ABORT_IF(sig >= 32, "Signal {} out of range (must be < 32).", sig); + return sigflags_ & (1< + +// SIGNAL HANDLING + +// The Marian signal handlers set global flags that thread can +// consider when a signal is received. This can be used for a graceful +// shutdown instead of a hard abandonment, e.g. after receiving +// SIGTERM during training. + +// When SIGTERM is received, the global (static member) flag sigterm_ +// (false by default) is set to true by signalHandler(). When sigterm_ +// is true, keepGoing() returns false, and the current state of +// training models is saved prior to exiting. This functionality is +// helpful when training on clusters with time limits on compute +// slots, e.g., on s clusters managed by slurm. Slurm can be asked to +// sending a (custom) warning signal to a process at a given point in +// time prior to the hard "time's up". +// +// Correspondingly, fetchBatches in the batch generator checks the flag +// frequently and quits after the overall process receives a SIGTERM. 
+ + +namespace marian { +bool getSignalFlag(int sig); // return true if sig was received, false otherwise +void setSignalFlag(int sig); // set custom handler (set flag) for sig +} diff --git a/src/data/batch_generator.h b/src/data/batch_generator.h index f16a7a81c..1a26baa04 100644 --- a/src/data/batch_generator.h +++ b/src/data/batch_generator.h @@ -1,6 +1,7 @@ #pragma once #include "common/options.h" +#include "common/signal_handling.h" #include "data/batch_stats.h" #include "data/rng_engine.h" #include "training/training_state.h" @@ -132,8 +133,14 @@ class BatchGenerator : public RNGEngine { if(current_ != data_->end()) ++current_; } + + std::deque tempBatches; + size_t sets = 0; while(current_ != data_->end() && maxiBatch->size() < maxSize) { // loop over data + if (getSignalFlag(SIGTERM)) { // received SIGTERM, abandon ship ... + return tempBatches; + } maxiBatch->push(*current_); sets = current_->size(); // do not consume more than required for the maxi batch as this causes @@ -149,8 +156,6 @@ class BatchGenerator : public RNGEngine { size_t currentWords = 0; std::vector lengths(sets, 0); // records maximum length observed within current batch - std::deque tempBatches; - // process all loaded sentences in order of increasing length // @TODO: we could just use a vector and do a sort() here; would make the cost more explicit const size_t mbWords = options_->get("mini-batch-words", 0); @@ -158,7 +163,13 @@ class BatchGenerator : public RNGEngine { BatchStats::const_iterator cachedStatsIter; if (stats_) cachedStatsIter = stats_->begin(); + while(!maxiBatch->empty()) { // while there are sentences in the queue + + if (getSignalFlag(SIGTERM)) { // received SIGTERM, abandon ship ... 
+ return tempBatches; + } + // push item onto batch batchVector.push_back(maxiBatch->top()); maxiBatch->pop(); // fetch next-shortest @@ -242,13 +253,13 @@ class BatchGenerator : public RNGEngine { ABORT_IF(!futureBufferedBatches_.valid(), "Attempted to wait for futureBufferedBatches_ when none pending.\n" "This error often occurs when Marian tries to restore the training data iterator, but the corpus has been changed or replaced.\n" "If you have changed the training corpus, add --no-restore-corpus to the training command and run it again."); + bufferedBatches_ = std::move(futureBufferedBatches_.get()); - // if bg thread returns an empty swath, we hit the end of the epoch - if (bufferedBatches_.empty()) { + if (bufferedBatches_.empty() // i.e., end of Epoch + || getSignalFlag(SIGTERM)) { // process received SIGTERM, abandon ship ... return nullptr; } - // and kick off the next bg operation - fetchBatchesAsync(); + fetchBatchesAsync(); // pre-fetch next slew of batches in separate thread } auto batch = bufferedBatches_.front(); bufferedBatches_.pop_front(); diff --git a/src/training/scheduler.cpp b/src/training/scheduler.cpp deleted file mode 100644 index 4c30cb04e..000000000 --- a/src/training/scheduler.cpp +++ /dev/null @@ -1,43 +0,0 @@ -#include "scheduler.h" -#include -#include - -namespace marian { - -// SIGNAL HANDLING, see scheduler.cpp for definitions -// Currently, only the following is handled by a custom signal handler: -// SIGTERM: When SIGTERM is received, the global (static member) flag sigterm_ (false by default) is set to true -// by signalHandler(). When sigterm_ is true, keepGoing() returns false, and the current state of training models -// is saved prior to exiting. -// This functionality is helpful when training on clusters with time limits on compute slots, e.g., on s -// clusters managed by slurm. Slurm can be asked to sending a (custom) warning signal to a process at a given -// point in time prior to the hard "time's up". 
- -bool sigterm_{false}; // flag signalling that SIGTERM has been received false by default, set to true by signalHandler(SIGTERM) - -void signalHandler(int sig) { - // Note: sys_siglist[sig] or stdsignal() describe the effect (e.g., - // 'Terminated' rather than provide the signal name (which are #define(s) - // in signal.h), so we have to do custom log messages here. - switch (sig) { - case SIGTERM: // save models and exit - LOG(info, "[training] Scheduler received signal SIGTERM"); // @TODO: figure out if this is safe. The logs are global and thread-safe, so should be OK? - sigterm_ = true; - break; - default: - ABORT("No action defined for signal {}", sig); - } -} - -// installs signalHandler() for select signals (currently only SIGTERM) -void installSignalHandlers() { - // TODO: use sigaction instead of signal, - // cf. https://stackoverflow.com/questions/231912/what-is-the-difference-between-sigaction-and-signal - signal(SIGTERM, signalHandler); -} - -bool getSigtermFlag() { - return sigterm_; -} - -} diff --git a/src/training/scheduler.h b/src/training/scheduler.h index 7e601632c..651c34b31 100755 --- a/src/training/scheduler.h +++ b/src/training/scheduler.h @@ -1,6 +1,7 @@ #pragma once #include "common/options.h" +#include "common/signal_handling.h" #include "training/training_state.h" #include "training/validator.h" #include "training/communicator.h" @@ -8,9 +9,6 @@ namespace marian { -bool getSigtermFlag(); -void installSignalHandlers(); - class Scheduler : public TrainingObserver { private: Ptr options_; @@ -149,12 +147,11 @@ class Scheduler : public TrainingObserver { : options_(options), state_(state) { ABORT_IF(state_->factor != 1, "state.factor unexpectedly not 1 at this point??"); updateLearningRate(*state); - installSignalHandlers(); } bool keepGoing() { - if(getSigtermFlag()) // received signal SIGERM => exit gracefully + if(getSignalFlag(SIGTERM)) // received signal SIGERM => exit gracefully return false; // stop if it reached the maximum number 
of epochs @@ -184,7 +181,7 @@ class Scheduler : public TrainingObserver { void started() { LOG(info, "Training started"); } void finished() { - if (getSigtermFlag()) + if (getSignalFlag(SIGTERM)) LOG(info, "Training interrupted (SIGTERM)."); else LOG(info, "Training finished"); @@ -217,7 +214,7 @@ class Scheduler : public TrainingObserver { bool isFinal = false) { // Do not validate if already validated (for instance, after the model is // loaded) or if validation is scheduled for another update, or when signal SIGTERM was received - if(getSigtermFlag() // SIGTERM was received + if(getSignalFlag(SIGTERM) // SIGTERM was received || state_->validated // already validated (in resumed training, for example) || (!state_->enteredNewPeriodOf(options_->get("valid-freq")) && !isFinal)) // not now return; diff --git a/src/training/training.h b/src/training/training.h index 5a2be7635..c68602ec9 100644 --- a/src/training/training.h +++ b/src/training/training.h @@ -77,6 +77,12 @@ class Train : public ModelTask { bool restored = !options_->get("no-restore-corpus") && batchGenerator->restore(trainState); + // Install custom handler for SIGTERM, to allow for a graceful + // shutdown that saves the current state of training before exiting. + // This signal handler simply sets a flag that can be checked from + // everywhere (getSignalFLAG(SIGTERM); #include common/signal_handling.h) + signal(SIGTERM,setSignalFlag); + // -- main training loop scheduler->started(); while(scheduler->keepGoing()) { From 66711b515769fede145335d270597a883d097285 Mon Sep 17 00:00:00 2001 From: Martin Junczys-Dowmunt Date: Sat, 14 Mar 2020 00:07:37 +0000 Subject: [PATCH 04/62] Merged PR 11929: Move around code to make later comparison with FP16 code easier This does not introduce any new functionality, just moves code around, so that future PRs are easier to compare. Moving old GraphGroup code to training/deprecated. Once it is clear there is nothing in there that's worth saving, this will be deleted. 
Replace -Ofast with -O3 and make sure ffinite-math is turned off. --- CMakeLists.txt | 8 +- src/CMakeLists.txt | 6 +- src/command/marian_train.cpp | 34 +--- src/tensors/cpu/fbgemm/expanded_gemm.h | 3 +- .../gradient_dropping/dropper.h | 0 .../gradient_dropping/gpu/dropper.cu | 0 .../gradient_dropping/gpu/sparse_algorithm.cu | 0 .../gradient_dropping/gpu/sparse_algorithm.h | 0 .../gradient_dropping/sparse_tensor.h | 0 .../graph_group_async_drop.cpp | 0 .../{ => deprecated}/graph_group_async_drop.h | 0 .../graph_group_multinode.cpp | 0 .../{ => deprecated}/graph_group_multinode.h | 0 .../graph_group_multinode_sync.cpp | 0 .../graph_group_multinode_sync.h | 0 src/training/graph_group.cpp | 89 ++++++++++ src/training/graph_group.h | 167 +----------------- 17 files changed, 107 insertions(+), 200 deletions(-) rename src/training/{ => deprecated}/gradient_dropping/dropper.h (100%) rename src/training/{ => deprecated}/gradient_dropping/gpu/dropper.cu (100%) rename src/training/{ => deprecated}/gradient_dropping/gpu/sparse_algorithm.cu (100%) rename src/training/{ => deprecated}/gradient_dropping/gpu/sparse_algorithm.h (100%) rename src/training/{ => deprecated}/gradient_dropping/sparse_tensor.h (100%) rename src/training/{ => deprecated}/graph_group_async_drop.cpp (100%) rename src/training/{ => deprecated}/graph_group_async_drop.h (100%) rename src/training/{ => deprecated}/graph_group_multinode.cpp (100%) rename src/training/{ => deprecated}/graph_group_multinode.h (100%) rename src/training/{ => deprecated}/graph_group_multinode_sync.cpp (100%) rename src/training/{ => deprecated}/graph_group_multinode_sync.h (100%) create mode 100644 src/training/graph_group.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 46d9c6c91..4cf8bf922 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -167,9 +167,9 @@ else(MSVC) endif(CMAKE_COMPILER_IS_GNUCC) set(CMAKE_CXX_FLAGS "-std=c++11 -pthread ${CMAKE_GCC_FLAGS} -fPIC ${DISABLE_GLOBALLY} -march=${BUILD_ARCH} ${INTRINSICS}") - 
set(CMAKE_CXX_FLAGS_RELEASE "-Ofast -m64 -funroll-loops -ffinite-math-only -g ${CMAKE_RDYNAMIC_FLAG}") + set(CMAKE_CXX_FLAGS_RELEASE "-O3 -m64 -funroll-loops -g ${CMAKE_RDYNAMIC_FLAG}") set(CMAKE_CXX_FLAGS_DEBUG "-O0 -g ${CMAKE_RDYNAMIC_FLAG}") - set(CMAKE_CXX_FLAGS_SLIM "-Ofast -m64 -funroll-loops -ffinite-math-only -DNDEBUG") + set(CMAKE_CXX_FLAGS_SLIM "-O3 -m64 -funroll-loops -DNDEBUG") set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELEASE}") set(CMAKE_CXX_FLAGS_PROFILE "${CMAKE_CXX_FLAGS_RELEASE} -pg") set(CMAKE_CXX_FLAGS_PROFGEN "${CMAKE_CXX_FLAGS_RELEASE} -fprofile-generate -fprofile-correction") @@ -177,9 +177,9 @@ else(MSVC) # these need to be set separately set(CMAKE_C_FLAGS "-pthread ${CMAKE_GCC_FLAGS} -fPIC ${DISABLE_GLOBALLY} -march=${BUILD_ARCH} ${INTRINSICS}") - set(CMAKE_C_FLAGS_RELEASE "-O3 -m64 -funroll-loops -ffinite-math-only -g ${CMAKE_RDYNAMIC_FLAG}") + set(CMAKE_C_FLAGS_RELEASE "-O3 -m64 -funroll-loops -g ${CMAKE_RDYNAMIC_FLAG}") set(CMAKE_C_FLAGS_DEBUG "-O0 -g ${CMAKE_RDYNAMIC_FLAG}") - set(CMAKE_C_FLAGS_SLIM "-O3 -m64 -funroll-loops -ffinite-math-only -DNDEBUG") + set(CMAKE_C_FLAGS_SLIM "-O3 -m64 -funroll-loops -DNDEBUG") set(CMAKE_C_FLAGS_RELWITHDEBINFO "${CMAKE_C_FLAGS_RELEASE}") set(CMAKE_C_FLAGS_PROFILE "${CMAKE_C_FLAGS_RELEASE} -pg") set(CMAKE_C_FLAGS_PROFGEN "${CMAKE_C_FLAGS_RELEASE} -fprofile-generate -fprofile-correction") diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 90e710717..acae400fd 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -86,11 +86,9 @@ add_library(marian STATIC translator/scorers.cpp training/graph_group_async.cpp - training/graph_group_async_drop.cpp training/graph_group_sync.cpp + training/graph_group.cpp training/graph_group_singleton.cpp - training/graph_group_multinode.cpp - training/graph_group_multinode_sync.cpp training/validator.cpp training/communicator.cpp @@ -145,8 +143,6 @@ cuda_add_library(marian_cuda tensors/gpu/cudnn_wrappers.cu translator/nth_element.cu 
translator/helpers.cu - training/gradient_dropping/gpu/dropper.cu - training/gradient_dropping/gpu/sparse_algorithm.cu STATIC) target_compile_options(marian_cuda PUBLIC ${ALL_WARNINGS}) diff --git a/src/command/marian_train.cpp b/src/command/marian_train.cpp index f6c245f8d..46bd05e84 100644 --- a/src/command/marian_train.cpp +++ b/src/command/marian_train.cpp @@ -3,16 +3,10 @@ #include "common/signal_handling.h" #include "training/graph_group_async.h" -#include "training/graph_group_multinode_sync.h" #include "training/graph_group_singleton.h" #include "training/graph_group_sync.h" #include "training/training.h" -#ifdef CUDA_FOUND -#include "training/graph_group_async_drop.h" -#include "training/graph_group_multinode.h" -#endif - #include "3rd_party/ExceptionWithCallStack.h" int main(int argc, char** argv) { @@ -28,18 +22,7 @@ int main(int argc, char** argv) { // MultiNodeGraphGroupSync. if(options->get("multi-node")) { LOG(warn, "[experimental] Using old multi-node training implementations that are not up-to-date"); - - if(options->get("sync-sgd")) { - LOG(info, "[training] Using multi-node synchronous training"); - New>(options)->run(); - } else { -#ifdef CUDA_FOUND - LOG(info, "[training] Using multi-node asynchronous training"); - New>(options)->run(); -#else - ABORT("Asynchronous multi-node training requires CUDA"); -#endif - } + ABORT("Old multi-node training code disabled"); } // --sync-sgd always selects SyncGraphGroup // @@ -47,7 +30,7 @@ int main(int argc, char** argv) { // processes x (single, multiple) GPUs per MPI process. This variant is presently up-to-date and // best supported. 
else if (options->get("sync-sgd")) { - LOG(info, "[training] Using synchronous training"); + LOG(info, "Using synchronous SGD"); New>(options)->run(); } else { @@ -56,17 +39,8 @@ int main(int argc, char** argv) { LOG(info, "[training] Using single-device training"); New>(options)->run(); } else { - if(options->get("grad-dropping-rate") > 0.0) { -#ifdef CUDA_FOUND - LOG(info, "[training] Using asynchronous training with gradient dropping"); - New>(options)->run(); -#else - ABORT("Asynchronous training with gradient dropping requires CUDA"); -#endif - } else { - LOG(info, "[training] Using asynchronous training"); - New>(options)->run(); - } + LOG(info, "Using asynchronous training"); + New>(options)->run(); } } diff --git a/src/tensors/cpu/fbgemm/expanded_gemm.h b/src/tensors/cpu/fbgemm/expanded_gemm.h index 32cc6b122..38c543c75 100644 --- a/src/tensors/cpu/fbgemm/expanded_gemm.h +++ b/src/tensors/cpu/fbgemm/expanded_gemm.h @@ -123,7 +123,7 @@ struct FbgemmPacked16PackNodeOp : public UnaryNodeOp { #endif // USE_FBGEMM } }; - ; + // Pack a matrix (int8) into cache utilization efficient way (block format) together with quantization into int8 // PackMatrix packMat_: the type of packed matrix - A or B matrix // marian::Type packType_: the type the input matrix is packed - packed8avx2 or packed8avx512 @@ -132,7 +132,6 @@ struct FbgemmPacked16PackNodeOp : public UnaryNodeOp { // int ncol_: the number of columns // uint64_t packsize_: the size of the packed matrix // (the size of int8 packed B from fbgemm:PackAWithQuantRowOffset + quantization scale, offset and zero point) - struct FbgemmPacked8PackNodeOp : public UnaryNodeOp { PackMatrix packMat_; marian::Type packType_; diff --git a/src/training/gradient_dropping/dropper.h b/src/training/deprecated/gradient_dropping/dropper.h similarity index 100% rename from src/training/gradient_dropping/dropper.h rename to src/training/deprecated/gradient_dropping/dropper.h diff --git a/src/training/gradient_dropping/gpu/dropper.cu 
b/src/training/deprecated/gradient_dropping/gpu/dropper.cu similarity index 100% rename from src/training/gradient_dropping/gpu/dropper.cu rename to src/training/deprecated/gradient_dropping/gpu/dropper.cu diff --git a/src/training/gradient_dropping/gpu/sparse_algorithm.cu b/src/training/deprecated/gradient_dropping/gpu/sparse_algorithm.cu similarity index 100% rename from src/training/gradient_dropping/gpu/sparse_algorithm.cu rename to src/training/deprecated/gradient_dropping/gpu/sparse_algorithm.cu diff --git a/src/training/gradient_dropping/gpu/sparse_algorithm.h b/src/training/deprecated/gradient_dropping/gpu/sparse_algorithm.h similarity index 100% rename from src/training/gradient_dropping/gpu/sparse_algorithm.h rename to src/training/deprecated/gradient_dropping/gpu/sparse_algorithm.h diff --git a/src/training/gradient_dropping/sparse_tensor.h b/src/training/deprecated/gradient_dropping/sparse_tensor.h similarity index 100% rename from src/training/gradient_dropping/sparse_tensor.h rename to src/training/deprecated/gradient_dropping/sparse_tensor.h diff --git a/src/training/graph_group_async_drop.cpp b/src/training/deprecated/graph_group_async_drop.cpp similarity index 100% rename from src/training/graph_group_async_drop.cpp rename to src/training/deprecated/graph_group_async_drop.cpp diff --git a/src/training/graph_group_async_drop.h b/src/training/deprecated/graph_group_async_drop.h similarity index 100% rename from src/training/graph_group_async_drop.h rename to src/training/deprecated/graph_group_async_drop.h diff --git a/src/training/graph_group_multinode.cpp b/src/training/deprecated/graph_group_multinode.cpp similarity index 100% rename from src/training/graph_group_multinode.cpp rename to src/training/deprecated/graph_group_multinode.cpp diff --git a/src/training/graph_group_multinode.h b/src/training/deprecated/graph_group_multinode.h similarity index 100% rename from src/training/graph_group_multinode.h rename to 
src/training/deprecated/graph_group_multinode.h diff --git a/src/training/graph_group_multinode_sync.cpp b/src/training/deprecated/graph_group_multinode_sync.cpp similarity index 100% rename from src/training/graph_group_multinode_sync.cpp rename to src/training/deprecated/graph_group_multinode_sync.cpp diff --git a/src/training/graph_group_multinode_sync.h b/src/training/deprecated/graph_group_multinode_sync.h similarity index 100% rename from src/training/graph_group_multinode_sync.h rename to src/training/deprecated/graph_group_multinode_sync.h diff --git a/src/training/graph_group.cpp b/src/training/graph_group.cpp new file mode 100644 index 000000000..8950521cd --- /dev/null +++ b/src/training/graph_group.cpp @@ -0,0 +1,89 @@ +#include "training/graph_group.h" + +namespace marian { + +GraphGroup::GraphGroup(Ptr options) : options_(options), opt_(Optimizer(options)) {} + +void GraphGroup::validate() { + ABORT_IF(finalized_, "Training has already finished."); +} + +void GraphGroup::finalize() { + finalized_ = true; +} + +Ptr GraphGroup::collectStats(Ptr graph, + Ptr model, + const std::vector>& vocabs, + double multiplier) { + auto stats = New(); + + size_t numFiles = options_->get>("train-sets").size(); + + // Initialize first batch to step size + size_t first = options_->get("mini-batch-fit-step"); + + // Increase batch size and sentence length by this step size + size_t step = options_->get("mini-batch-fit-step"); + + size_t maxLength = options_->get("max-length"); + maxLength = (size_t)(std::ceil(maxLength / (float)step) * step); + + // this should be only one class label per line on input, hence restricting length to 1 + std::vector localMaxes(numFiles, maxLength); + auto inputTypes = options_->get>("input-types", {}); + for(int i = 0; i < inputTypes.size(); ++i) + if(inputTypes[i] == "class") + localMaxes[i] = 1; + + size_t maxBatch = 512; + bool fits = true; + while(fits) { + std::vector lengths(numFiles, first); + for(int j = 0; j < lengths.size(); ++j) 
// apply length restrictions + lengths[j] = std::min(lengths[j], localMaxes[j]); + + auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, maxBatch, options_); + auto cost = model->build(graph, batch); + fits = graph->fits(); + if(fits) + maxBatch *= 2; + } + + // Do a binary search for maximum batch size that fits into given workspace memory + // for a tested sentence length. + for(size_t i = step; i <= maxLength; i += step) { + size_t start = 1; + size_t end = maxBatch; + + std::vector lengths(numFiles, i); + for(int j = 0; j < lengths.size(); ++j) // apply length restrictions + lengths[j] = std::min(lengths[j], localMaxes[j]); + fits = true; + + do { + size_t current = (start + end) / 2; + auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, current, options_); + auto cost = model->build(graph, batch); + fits = graph->fits(); + + LOG(debug, "[batching] length: {} - size: {} - fits: {}", lengths[0], current, fits); + + if(fits) { + stats->add(batch, multiplier); + start = current + 1; + } else { + end = current - 1; + } + } while(end - start > step); + + maxBatch = start; + } + return stats; +} + +void GraphGroup::setTypicalTrgBatchWords(size_t typicalTrgBatchWords) { // needed for dynamic MB scaling + typicalTrgBatchWords_ = typicalTrgBatchWords; +} + +} \ No newline at end of file diff --git a/src/training/graph_group.h b/src/training/graph_group.h index 56b8afe3d..012f78ef9 100644 --- a/src/training/graph_group.h +++ b/src/training/graph_group.h @@ -19,12 +19,14 @@ class GraphGroup { protected: Ptr options_; Ptr opt_; // the optimizer + Ptr scheduler_; // scheduler that keeps track of how much has been processed + bool finalized_{false}; // 'true' if training has completed (further updates are no longer allowed) size_t typicalTrgBatchWords_{ 0 }; // for dynamic batch sizing: typical batch size in words public: - GraphGroup(Ptr options) : options_(options), opt_(Optimizer(options)) {} + GraphGroup(Ptr options); virtual ~GraphGroup() {} @@ -34,13 
+36,9 @@ class GraphGroup { virtual void save(bool isFinal = false) = 0; - void validate() { - ABORT_IF(finalized_, "Training has already finished."); - } + void validate(); - virtual void finalize() { - finalized_ = true; - } + virtual void finalize(); virtual void setScheduler(Ptr scheduler) = 0; @@ -57,158 +55,9 @@ class GraphGroup { Ptr collectStats(Ptr graph, Ptr model, const std::vector>& vocabs, - double multiplier = 1.) { - auto stats = New(); - - size_t numFiles = options_->get>("train-sets").size(); - - // Initialize first batch to step size - size_t first = options_->get("mini-batch-fit-step"); - - // Increase batch size and sentence length by this step size - size_t step = options_->get("mini-batch-fit-step"); - - size_t maxLength = options_->get("max-length"); - maxLength = (size_t)(std::ceil(maxLength / (float)step) * step); - - // this should be only one class label per line on input, hence restricting length to 1 - std::vector localMaxes(numFiles, maxLength); - auto inputTypes = options_->get>("input-types", {}); - for(int i = 0; i < inputTypes.size(); ++i) - if(inputTypes[i] == "class") - localMaxes[i] = 1; - - size_t maxBatch = 512; - bool fits = true; - while(fits) { - std::vector lengths(numFiles, first); - for(int j = 0; j < lengths.size(); ++j) // apply length restrictions - lengths[j] = std::min(lengths[j], localMaxes[j]); - - auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, maxBatch, options_); - auto cost = model->build(graph, batch); - fits = graph->fits(); - if(fits) - maxBatch *= 2; - } - - // Do a binary search for maxmimum batch size that fits into given workspace memory - // for a tested sentence length. 
- for(size_t i = step; i <= maxLength; i += step) { - size_t start = 1; - size_t end = maxBatch; - - std::vector lengths(numFiles, i); - for(int j = 0; j < lengths.size(); ++j) // apply length restrictions - lengths[j] = std::min(lengths[j], localMaxes[j]); - fits = true; - - do { - size_t current = (start + end) / 2; - auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, current, options_); - auto cost = model->build(graph, batch); - fits = graph->fits(); - - LOG(debug, "[batching] length: {} - size: {} - fits: {}", lengths[0], current, fits); - - if(fits) { - stats->add(batch, multiplier); - start = current + 1; - } else { - end = current - 1; - } - } while(end - start > step); - - maxBatch = start; - } - return stats; - } - - void setTypicalTrgBatchWords(size_t typicalTrgBatchWords) { // needed for dynamic MB scaling - typicalTrgBatchWords_ = typicalTrgBatchWords; - } -}; - -/** - * Base class for multi-node versions of GraphGroups. - */ -class MultiNodeGraphGroupBase : public GraphGroup { - using Base = GraphGroup; - -protected: - Ptr mpi_; // all MPI-like communication goes through this + double multiplier = 1.); - /** Devices (GPUs) on this node. */ - std::vector devices_; // [num local GPUs] - - /** Graph builders for clients (which run forward and backward passes). */ - std::vector> clientBuilders_; - - /** Graphs of clients. One entry per GPU on this node. */ - std::vector> clientGraphs_; // [num local GPUs] - -public: - MultiNodeGraphGroupBase(Ptr options, Ptr mpi) - : Base(options), mpi_(mpi) { - - // Set up devices for this node - std::vector devices; // set of GPU device ids for this MPI process - for (auto& d : Config::getDevices(options_)) - devices.push_back(d.no); - loadDeviceConfig(devices); // set up numberClientsOfNodes_[] and devices_[] - - // Create builders and graphs for clients; that is, for each GPU we use on this node. 
- for (size_t i = 0; i < devices_.size(); i++) { - clientGraphs_.push_back(New()); - clientGraphs_[i]->setDevice({ devices_[i], DeviceType::gpu }); - clientGraphs_[i]->reserveWorkspaceMB(options_->get("workspace")); - clientBuilders_.push_back(models::createCriterionFunctionFromOptions(options_, models::usage::training)); - } - } - - /** - * Load the GPU configuration of this node (i.e. which GPUs to use) and the - * number of GPUs on the other nodes. - */ - // deviceConfig has this format - // - for each node - // - number of GPUs on that node - // - GPU ids for that node - // e.g. 0:0 1 1: 2 3 -> (2, (0, 1)) (2, (2,3)) - void loadDeviceConfig(std::vector deviceConfig) { - // parse device config array - size_t index = 0; // cursor for next() - auto next = [&]() { // helper function to get the next item - ABORT_IF(index == deviceConfig.size(), "mal-formed device config array??"); - return deviceConfig[index++]; - }; - std::vector> allDevices(mpi_->numMPIProcesses()); - for (auto& devices : allDevices) { - devices.resize(next()); - for (auto& device : devices) - device = next(); - } - ABORT_IF(index != deviceConfig.size(), "mal-formed device config array??"); - - // validate - ABORT_IF(allDevices.front().size() == 0, "no devices specified??"); - for (auto& devices : allDevices) { - ABORT_IF(devices.size() != allDevices.front().size(), "all MPI nodes must use the same number of devices"); - } - - // get our own config - devices_ = allDevices[mpi_->myMPIRank()]; - - // log - LOG(info, "[mpi rank {}] device configuration", mpi_->myMPIRank()); - for (auto& device : devices_) - LOG(info, "[mpi rank {}] - {}", mpi_->myMPIRank(), device); - } - - virtual void finalize() override { - if (mpi_) - finalizeMPI(std::move(mpi_)); - Base::finalize(); - } + void setTypicalTrgBatchWords(size_t typicalTrgBatchWords); }; + } // namespace marian From dd065420cbdf98244c3b4383f33411ad2067bc65 Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Sat, 14 Mar 2020 09:53:54 -0700 
Subject: [PATCH 05/62] bump version --- CHANGELOG.md | 2 ++ VERSION | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e62a9c00..427526486 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] ### Changed +- Changed compile flags -Ofast to -O3 and remove --ffinite-math +- Moved old graph groups to deprecated folder - Make cublas and cusparse handle inits lazy to save memory when unused - Improved handling for graceful shutdown upon receiving SIGTERM. SIGTERM now also interrupts batch prefetching, which runs in a separate thread. diff --git a/VERSION b/VERSION index ba1e8bf0b..b95e90dc7 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -v1.9.1 +v1.9.2 From 68581a6a4aa64c7e878636fb33873e79a8be202c Mon Sep 17 00:00:00 2001 From: Young Jin Kim Date: Wed, 25 Mar 2020 02:52:17 +0000 Subject: [PATCH 06/62] Merged PR 11831: Change the weight matrix quantization to use 7-bit min/max quantization to avoid overflow 1. Change the weight matrix quantization to use 7-bit min/max quantization -> This resolves all the overflow issue, because weight and activations are quantized by min/max range. 2. Clip fp16 quantization to avoid overflow 3. Fix windows build errors (cmake options, vcproj file) 4.
int8 pack model (encoder -> fp16) --- src/common/config_parser.cpp | 4 + .../cpu/fbgemm/expression_graph_packable.h | 16 +- src/tensors/cpu/fbgemm/packed_gemm.cpp | 187 ++++++++++++------ src/tensors/cpu/fbgemm/packed_gemm.h | 10 +- src/translator/nth_element.cpp | 2 +- vs/Marian.vcxproj | 13 +- vs/Marian.vcxproj.filters | 45 +---- 7 files changed, 160 insertions(+), 117 deletions(-) diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index 9c711eaad..d0155be9a 100755 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -841,11 +841,15 @@ Ptr ConfigParser::parseOptions(int argc, char** argv, bool doValidate){ auto buildInfo = get("build-info"); if(!buildInfo.empty() && buildInfo != "false") { +#ifndef _MSC_VER // cmake build options are not available on MSVC based build. if(buildInfo == "all") std::cerr << cmakeBuildOptionsAdvanced() << std::endl; else std::cerr << cmakeBuildOptions() << std::endl; exit(0); +#else // _MSC_VER + ABORT("build-info is not available on MSVC based build."); +#endif // _MSC_VER } // get paths to extra config files diff --git a/src/tensors/cpu/fbgemm/expression_graph_packable.h b/src/tensors/cpu/fbgemm/expression_graph_packable.h index 743b7c8cb..4c2828955 100644 --- a/src/tensors/cpu/fbgemm/expression_graph_packable.h +++ b/src/tensors/cpu/fbgemm/expression_graph_packable.h @@ -35,10 +35,13 @@ class ExpressionGraphPackable : public ExpressionGraph { Tensor val = p.second->val(); // save as packed format - // @TODO Hardcoded to find packable weights - all the weights used for affine op (fp16), all the weights used for affine op and dot op (int8) + // @TODO Hardcoded to find packable weights + // int8 - quantize decoder only for better quality, all the weights used for affine op and dot op (int8) + // fp16 - all the weights used for affine op (fp16) if ((gemmElementType == Type::packed8avx2 || gemmElementType == Type::packed8avx512) - && (pName.find("_W") == pName.length() - 3 || pName.find("_W") == 
pName.length() - 2)) { - #if USE_FBGEMM + && (pName.find("_W") == pName.length() - 3 || pName.find("_W") == pName.length() - 2) + && pName.find("encoder") == std::string::npos) { +#if USE_FBGEMM using namespace marian::cpu::variant; // packing information - size int nrow; @@ -82,7 +85,10 @@ class ExpressionGraphPackable : public ExpressionGraph { #else ABORT("Packed type {} only supported when compiled with -DUSE_FBGEMM=on", gemmElementType); #endif - } else if (gemmElementType == Type::packed16 && pName.find("_W") == pName.length() - 3) { + // fp16 quantization option + encoders for int8 quantized models + } else if ((gemmElementType == Type::packed16 && pName.find("_W") == pName.length() - 3) + || ((gemmElementType == Type::packed8avx2 || gemmElementType == Type::packed8avx512) + && (pName.find("_W") == pName.length() - 3 || pName.find("_W") == pName.length() - 2))) { #if USE_FBGEMM using namespace marian::cpu::variant; @@ -123,7 +129,7 @@ class ExpressionGraphPackable : public ExpressionGraph { io::Item item; item.name = pName; item.shape = val->shape(); - item.type = gemmElementType; + item.type = Type::packed16; // Use the actual memory as this will be aligned and padded. // When memory mapping this is required. Shape keeps track of diff --git a/src/tensors/cpu/fbgemm/packed_gemm.cpp b/src/tensors/cpu/fbgemm/packed_gemm.cpp index a98d5e4ac..064c3c2be 100644 --- a/src/tensors/cpu/fbgemm/packed_gemm.cpp +++ b/src/tensors/cpu/fbgemm/packed_gemm.cpp @@ -76,22 +76,31 @@ const int PACK16_PADDING = 1024; // This is a memory space to store auxiliary variables for FBGEMM (e.g. block row, block column, kernel_ncol_blocks and etc.) const int PACK16_SPECIALMEM = 256; +// This is the maximum value of FP16 type. There is a template type implementation, but it doesn't work on windows. +// To keep the consistent result, just use the constant value instead of #ifdef _MSC_VER. 
+// Template type implementation: float FP16_MAX = NumericLimits(Type::float16).max; +const float FP16_MAX = 65504.f; + +// This function clips a value into a [min, max] range +inline float clip(float value, float min, float max) { + return std::max(min, std::min(value, max)); +} + // This is copied from FBGEMM code // A better way? // will be removed, when FBGEMM api is changed // blocked row-major format address arithmetic -/** - * Returns the memory address in the packed (block formatted) matrix array of a specific element - * indexed by the original non-packed array. - * - * @param r_ row index in the original matrix - * @param c_ column index in the original matrix - * @param brow_ row wide block index - * @param bcol_ column wide block index - * @param nbrow_ number of blocks in row - * @param nbcol_ number of blocks in column - * @param last_brow_ row number of the last block - */ +// +// Returns the memory address in the packed (block formatted) matrix array of a specific element +// indexed by the original non-packed array. +// +// @param r_ row index in the original matrix +// @param c_ column index in the original matrix +// @param brow_ row wide block index +// @param bcol_ column wide block index +// @param nbrow_ number of blocks in row +// @param nbcol_ number of blocks in column +// @param last_brow_ row number of the last block inline uint64_t addr(const int r_, const int c_, const int brow_, @@ -114,6 +123,15 @@ inline uint64_t addr(const int r_, return index; } +// Returns a value in 2D array with the row, column index (i, j) and transposed flag. +// The number of rows and columns needs to be passed. +// The transposed flag indicates if the underlying data needs to be accessed in a tranposed layout or not. 
+inline float getVal2dArr(const float* data, size_t i, size_t j, size_t rows, size_t cols, bool transposed) { + ABORT_IF(i >= rows, "Row index {} exceeds the number of rows {}.", i, rows); + ABORT_IF(j >= cols, "Column index {} exceeds the number of columns {}.", j, cols); + return transposed ? data[j * rows + i] : data[i * cols + j]; +} + // Memory blocking factors (parameters) for packing into AVX2 int8 static const fbgemm::BlockingFactors Packed8Avx2BlockingFactors = { PackingTraits::MR, @@ -147,6 +165,12 @@ inline const fbgemm::BlockingFactors* getBlockingFactors(marian::Type packType) } } +// Returns the byte size of packed matrix in fp16. It's calculated by fbgemm's internal logic due to the paddings and different layouts. +// Packing with fp16 only targets AVX2 instruction sets for now. +// See '3rd_party/fbgemm/include/fbgemm/FbgemmFP16.h'. +// shape: shape of the tensor to be packed +// transpose: the matrix is transposed +// packsize (out): the size of the packed matrix in byte void fbgemmPacked16PackInfo(const marian::Shape& shape, const bool transpose, uint64_t& packsize) { @@ -154,6 +178,21 @@ void fbgemmPacked16PackInfo(const marian::Shape& shape, fbgemmPacked16PackInfo(shape, transpose, nrow, ncol, kernel_ncol_blocks, brow, bcol, last_brow, nbrow, nbcol, packsize); } +// Returns the byte size of packed matrix in fp16. It's calculated by fbgemm's internal logic due to the paddings and different layouts. +// This function returns some other extra variables +// Packing with fp16 only targets AVX2 instruction sets for now. +// See '3rd_party/fbgemm/include/fbgemm/FbgemmFP16.h'. 
+// shape: shape of the tensor to be packed +// transpose: the matrix is transposed +// nrow (out): the number of rows +// ncol (out): the number of columns +// kernel_ncol_blocks (out): the number of column blocks +// brow (out): the number of rows in a block +// bcol (out): the number of columns in a block +// last_brow (out): the number of rows in the last block +// nbrow (out): row index in a block +// nbcol (out): column index in a block +// packsize (out): the size of the packed matrix in byte void fbgemmPacked16PackInfo(const marian::Shape& shape, const bool transpose, int& nrow, @@ -178,6 +217,14 @@ void fbgemmPacked16PackInfo(const marian::Shape& shape, + PACK16_SPECIALMEM; } +// Returns the byte size of packed matrix in int8. It's calculated by fbgemm's internal logic due to the paddings and different layouts. +// See '3rd_party/fbgemm/src/PackBMatrix.cc'. +// shape: shape of the tensor to be packed +// packType: Type to be packed - packed8avx2 or packed8avx512 +// transpose: the matrix is transposed +// nrow (out): the number of rows +// ncol (out): the number of columns +// packsize (out): the size of the packed matrix in byte void fbgemmPacked8PackInfo(const marian::Shape& shape, const marian::Type packType, const bool transpose, @@ -221,6 +268,20 @@ inline void col_offsets_with_zero_pt_s8acc32( } } +// Pack a matrix (fp16) into cache utilization efficient way (block format) into fp16 +// out: output tensor - packed format +// inData: input tensor data - pointer of float data +// transpose: the matrix is transposed +// nrow: the number of rows +// ncol: the number of columns +// kernel_ncol_blocks: the number of column blocks +// brow: the number of rows in a block +// bcol: the number of columns in a block +// last_brow: the number of rows in the last block +// nbrow: row index in a block +// nbcol: column index in a block +// packsize: the size of the packed matrix +// (the number of fp16 elements + padding (1024) + extra temporary memory (256)) void 
fbgemmPacked16Pack(marian::Tensor out, const float* inData, // Packing is only available for 2D weight matrix in Marian. Otherwise, it's aborted in expanded_gemm.h. const bool transpose, @@ -258,20 +319,37 @@ void fbgemmPacked16Pack(marian::Tensor out, // pack the matrix for(int i = 0; i < nrow; i++) { for(int j = 0; j < ncol; j++) { - outmem[addr(i, j, brow, bcol, nbrow, nbcol, last_brow)] - = tconv(!transpose ? inData[i * ncol + j] : inData[i + nrow * j], *dummy); + float src = clip(transpose ? inData[i + nrow * j] : inData[i * ncol + j], -FP16_MAX, FP16_MAX); + outmem[addr(i, j, brow, bcol, nbrow, nbcol, last_brow)] = tconv(src, *dummy); } } delete dummy; } +// Pack a matrix (int8) into cache utilization efficient way (block format) together with quantization into int8 +// out: output tensor - packed format and quantized into int8 +// inData: input tensor data - pointer of float data +// packType: Type to be packed - packed8avx2 or packed8avx512 +// transpose: the matrix is transposed +// nrow: the number of rows +// ncol: the number of columns +// packsize: the size of the packed matrix +// (the size of int8 packed B from fbgemm:PackAWithQuantRowOffset + quantization scale, offset and zero point) +// quantRangeStdDevs: the range to be quantized for the original float data in multiples standard deviation +// the default value is 0.0f which means min/max quantization +// only a half range of normal int8 which is [-64, 63] used to avoid overflow +// during the accumulation in VPMADDUBSW instruction +// https://intel.github.io/mkl-dnn/dev_guide_int8_computations.html +// (e.g. 
3.f means the original tensor is quantized +// from [mean - 3.f * standard deviation, mean + 3.f * standard deviation] to [-64, 63]) void fbgemmPacked8Pack(marian::Tensor out, const float* inData, const marian::Type packType, const bool transpose, const int nrow, const int ncol, - const uint64_t packsize) { + const uint64_t packsize, + const float quantRangeStdDevs) { int k = nrow; int n = ncol; int len = k * n; @@ -282,46 +360,43 @@ void fbgemmPacked8Pack(marian::Tensor out, const float* data = inData; float val = 0; - - if (transpose) { - for (int jj = 0; jj < n; jj++) { - float min = std::numeric_limits::max(), max = std::numeric_limits::min(); - double mean = 0, sqrsum = 0; - for (int ii = 0; ii < k; ii++) { - val = data[jj * k + ii]; - mean += val; - sqrsum += val * val; - } - mean /= k; - sqrsum /= k; - sqrsum -= mean * mean; - sqrsum = sqrt(sqrsum); - - min = (float)(mean - 7.0f*sqrsum); - max = (float)(mean + 7.0f*sqrsum); - bqScale[jj] = (max - min) / 255; - bqZeropoint[jj] = (int32_t)(127 - max / bqScale[jj]); - } - } else { - for (int jj = 0; jj < n; jj++) { - float min = std::numeric_limits::max(), max = std::numeric_limits::min(); - double mean = 0, sqrsum = 0; - for (int ii = 0; ii < k; ii++) { - val = data[jj + ii * n]; - mean += val; - sqrsum += val * val; - } - mean /= k; - sqrsum /= k; - sqrsum -= mean * mean; - sqrsum = sqrt(sqrsum); - - min = (float)(mean - 7.0f*sqrsum); - max = (float)(mean + 7.0f*sqrsum); - bqScale[jj] = (max - min) / 255; - bqZeropoint[jj] = (int32_t)(127 - max / bqScale[jj]); - } - } + + // Use half of the quantization range to prevent overflow of VPMADDUBSW + constexpr static int quantizedRange = 127; + constexpr static int quantizedMax = 63; + + // This routine compute the quantization range for each column - either one of min/max range or quantRangeStdDevs sigma range. + for (size_t jj = 0; jj < n; jj++) { // for each column, collect stats (min/max or mean/std.dev.) 
+ float min = std::numeric_limits::max(), max = std::numeric_limits::min(); + double mean = 0, sqrsum = 0; + for (size_t ii = 0; ii < k; ii++) { // in a column, go through all the rows and collect stats + val = getVal2dArr(data, ii, jj, k, n, transpose); + // If quantRangeStdDevs is 0.f, min/max values of the columns are used as a quantization range + if(quantRangeStdDevs == 0.f) { + if(min > val) + min = val; + if(max < val) + max = val; + } else { + // Quantize by std.dev. range + mean += val; + sqrsum += val * val; + } + } + // If a quantization range (in multiples of std. dev.) is given with a non-zero value, + // it calculates the range for this column (different quantization scale/offset are used for each column) + if(quantRangeStdDevs != 0.f) { + mean /= k; + sqrsum /= k; + sqrsum -= mean * mean; + sqrsum = sqrt(sqrsum); + min = (float)(mean - quantRangeStdDevs * sqrsum); + max = (float)(mean + quantRangeStdDevs * sqrsum); + } + // based on the quantization range, this computes the scale and offset for the quantization + bqScale[jj] = (max - min) / quantizedRange; + bqZeropoint[jj] = (int32_t)(quantizedMax - max / bqScale[jj]); + } // 2. 
quantize int8_t* quantized = 0; @@ -335,7 +410,7 @@ void fbgemmPacked8Pack(marian::Tensor out, TensorQuantizationParams bQuantParam; bQuantParam.scale = bqScale[jj]; bQuantParam.zero_point = bqZeropoint[jj]; - bQuantParam.precision = 8; + bQuantParam.precision = 7; // Use half of the quantization range to prevent overflow of VPMADDUBSW if (transpose) fbgemm::Quantize(data + jj * k, quantized + jj * k, k, bQuantParam); diff --git a/src/tensors/cpu/fbgemm/packed_gemm.h b/src/tensors/cpu/fbgemm/packed_gemm.h index d0a63ea99..694860d48 100644 --- a/src/tensors/cpu/fbgemm/packed_gemm.h +++ b/src/tensors/cpu/fbgemm/packed_gemm.h @@ -94,13 +94,21 @@ void fbgemmPacked16Pack(marian::Tensor out, // ncol: the number of columns // packsize: the size of the packed matrix // (the size of int8 packed B from fbgemm:PackAWithQuantRowOffset + quantization scale, offset and zero point) +// quantRangeStdDevs: the range to be quantized for the original float data in multiples standard deviation +// the default value is 0.0f which means min/max quantization +// only a half range of normal int8 which is [-64, 63] used to avoid overflow +// during the accumulation in VPMADDUBSW instruction +// https://intel.github.io/mkl-dnn/dev_guide_int8_computations.html +// (e.g. 
3.f means the original tensor is quantized +// from [mean - 3.f * standard deviation, mean + 3.f * standard deviation] to [-64, 63]) void fbgemmPacked8Pack(marian::Tensor out, const float* inData, const marian::Type packType, const bool transpose, const int nrow, const int ncol, - const uint64_t packsize); // @TODO: change to size_t where appropriate + const uint64_t packsize, + const float quantRangeStdDevs = 0.f); // @TODO: change to size_t where appropriate // GEMM operation on the packed B matrix // C: output matrix diff --git a/src/translator/nth_element.cpp b/src/translator/nth_element.cpp index 8b2f89476..237d9b9da 100644 --- a/src/translator/nth_element.cpp +++ b/src/translator/nth_element.cpp @@ -56,7 +56,7 @@ class NthElementCPU { for(size_t i = 0; i < N; ++i) { int idx = idxs[i]; // since idxs is re-used for each batch, add batch offset to each idx to get absolute position - h_res_idx[pos] = idx + batchIdx * batchOffset; + h_res_idx[pos] = (int) (idx + batchIdx * batchOffset); h_res[pos] = scoresData[idx]; ++pos; } diff --git a/vs/Marian.vcxproj b/vs/Marian.vcxproj index 241aa307d..0cb4a5de0 100755 --- a/vs/Marian.vcxproj +++ b/vs/Marian.vcxproj @@ -1445,7 +1445,7 @@ true - + @@ -1454,10 +1454,8 @@ - - @@ -1653,7 +1651,6 @@ - @@ -1755,18 +1752,12 @@ - - - - - - @@ -1906,8 +1897,6 @@ false false - - diff --git a/vs/Marian.vcxproj.filters b/vs/Marian.vcxproj.filters index a4cbc827a..bb6080ae8 100755 --- a/vs/Marian.vcxproj.filters +++ b/vs/Marian.vcxproj.filters @@ -94,18 +94,12 @@ training - - training - training training - - training - training @@ -226,9 +220,6 @@ 3rd_party\yaml-cpp - - training - command @@ -883,6 +874,9 @@ tensors\cpu\fbgemm + + training + @@ -1288,9 +1282,6 @@ common - - common - common @@ -1531,12 +1522,6 @@ training - - training - - - training - training @@ -1555,15 +1540,6 @@ training - - training\gradient_dropping - - - training\gradient_dropping - - - training\gradient_dropping\gpu - translator @@ -1642,9 +1618,6 @@ translator - 
- training - command @@ -2373,12 +2346,6 @@ {880c8f51-3306-4d80-a682-7242341b0098} - - {880c8f51-3306-4d80-a682-7242341b0101} - - - {880c8f51-3306-4d80-a682-7242341b0104} - {880c8f51-3306-4d80-a682-7242341b0107} @@ -2703,11 +2670,5 @@ tensors\gpu - - training\gradient_dropping\gpu - - - training\gradient_dropping\gpu - \ No newline at end of file From d0fa14e2640814a02ec8c99ed028c9d3b50744c6 Mon Sep 17 00:00:00 2001 From: Young Jin Kim Date: Fri, 27 Mar 2020 21:44:31 +0000 Subject: [PATCH 07/62] Merged PR 12243: For int8 quantized model, use int8 quantization for encoders as well For int8 quantized model, use int8 quantization for encoders as well. The quality difference between fp16 encoder and int8 encoder is small, but they have quite amount of speed difference. --- src/tensors/cpu/fbgemm/expression_graph_packable.h | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/tensors/cpu/fbgemm/expression_graph_packable.h b/src/tensors/cpu/fbgemm/expression_graph_packable.h index 4c2828955..f5b05c302 100644 --- a/src/tensors/cpu/fbgemm/expression_graph_packable.h +++ b/src/tensors/cpu/fbgemm/expression_graph_packable.h @@ -36,11 +36,10 @@ class ExpressionGraphPackable : public ExpressionGraph { // save as packed format // @TODO Hardcoded to find packable weights - // int8 - quantize decoder only for better quality, all the weights used for affine op and dot op (int8) - // fp16 - all the weights used for affine op (fp16) + // int8 - all the weights used for affine op and dot op + // fp16 - all the weights used for affine op if ((gemmElementType == Type::packed8avx2 || gemmElementType == Type::packed8avx512) - && (pName.find("_W") == pName.length() - 3 || pName.find("_W") == pName.length() - 2) - && pName.find("encoder") == std::string::npos) { + && (pName.find("_W") == pName.length() - 3 || pName.find("_W") == pName.length() - 2)) { #if USE_FBGEMM using namespace marian::cpu::variant; // packing information - size @@ -85,10 +84,8 @@ class 
ExpressionGraphPackable : public ExpressionGraph { #else ABORT("Packed type {} only supported when compiled with -DUSE_FBGEMM=on", gemmElementType); #endif - // fp16 quantization option + encoders for int8 quantized models - } else if ((gemmElementType == Type::packed16 && pName.find("_W") == pName.length() - 3) - || ((gemmElementType == Type::packed8avx2 || gemmElementType == Type::packed8avx512) - && (pName.find("_W") == pName.length() - 3 || pName.find("_W") == pName.length() - 2))) { + // fp16 quantization option + } else if (gemmElementType == Type::packed16 && pName.find("_W") == pName.length() - 3) { #if USE_FBGEMM using namespace marian::cpu::variant; From 71e0f0b33fd60cf6a12df148f0675d153e41fd23 Mon Sep 17 00:00:00 2001 From: Roman Grundkiewicz Date: Fri, 10 Apr 2020 21:01:56 +0100 Subject: [PATCH 08/62] Support tab-separated inputs (#617) * Add basic support for TSV inputs * Fix mini-batch-fit for TSV inputs * Abort if shuffling data from stdin * Fix terminating training with data from STDIN * Allow creating vocabs from TSV files * Add comments; clean creation of vocabs from TSV files * Guess --tsv-size based on the model type * Add shortcut for STDIN inputs * Rename --tsv-size to --tsv-fields * Allow only one 'stdin' in --train-sets * Properly create separate vocabularies from a TSV file * Clearer logging message * Add error message for wrong number of valid sets if --tsv is used * Use --no-shuffle instead of --shuffle in the error message * Fix continuing training from STDIN * Update CHANGELOG * Support both 'stdin' and '-' * Guess --tsv-fields from dim-vocabs if special:model.yml available * Update error messages * Move variable outside the loop * Refactorize utils::splitTsv; add unit tests * Support '-' as stdin; refactorize; add comments * Abort if excessive field(s) in the TSV input * Add a TODO on passing one vocab with fully-tied embeddings * Remove the unit test with excessive tab-separated fields --- CHANGELOG.md | 16 +-- src/CMakeLists.txt | 1 
+ src/common/config.cpp | 40 ++++++- src/common/config.h | 4 +- src/common/config_parser.cpp | 35 +++++- src/common/config_parser.h | 1 + src/common/config_validator.cpp | 25 +++- src/common/file_stream.cpp | 6 +- src/common/file_stream.h | 2 + src/common/file_utils.cpp | 28 +++++ src/common/file_utils.h | 18 +++ src/common/filesystem.h | 4 +- src/common/utils.cpp | 28 +++++ src/common/utils.h | 4 + src/data/corpus.cpp | 25 +++- src/data/corpus_base.cpp | 197 +++++++++++++++++++++++++------- src/data/corpus_base.h | 5 +- src/data/default_vocab.cpp | 1 - src/tests/units/CMakeLists.txt | 1 + src/tests/units/utils_tests.cpp | 36 ++++++ src/training/graph_group.h | 83 +++++++++++++- src/training/scheduler.h | 17 ++- src/training/training_state.h | 2 +- 23 files changed, 502 insertions(+), 77 deletions(-) create mode 100644 src/common/file_utils.cpp create mode 100644 src/common/file_utils.h create mode 100644 src/tests/units/utils_tests.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 427526486..15c5c7387 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [1.9.0] - 2020-03-10 ### Added +- Training and scoring from STDIN +- Support for tab-separated inputs, added options --tsv and --tsv-fields - An option to print cached variables from CMake - Add support for compiling on Mac (and clang) - An option for resetting stalled validation metrics @@ -38,15 +40,15 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Support for 16-bit packed models with FBGEMM - Multiple separated parameter types in ExpressionGraph, currently inference-only - Safe handling of sigterm signal -- Automatic vectorization of elementwise operations on CPU for tensors dims that +- Automatic vectorization of elementwise operations on CPU for tensors dims that are divisible by 4 (AVX) and 8 (AVX2) -- Replacing std::shared_ptr with custom IntrusivePtr for small objects like +- Replacing std::shared_ptr with custom IntrusivePtr for small objects like Tensors, Hypotheses and Expressions. - Fp16 inference working for translation - Gradient-checkpointing ### Fixed -- Replace value for INVALID_PATH_SCORE with std::numer_limits::lowest() +- Replace value for INVALID_PATH_SCORE with std::numer_limits::lowest() to avoid overflow with long sequences - Break up potential circular references for GraphGroup* - Fix empty source batch entries with batch purging @@ -57,16 +59,16 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - FastOpt now reads "n" and "y" values as strings, not as boolean values - Fixed multiple reduction kernels on GPU - Fixed guided-alignment training with cross-entropy -- Replace IntrusivePtr with std::uniq_ptr in FastOpt, fixes random segfaults +- Replace IntrusivePtr with std::uniq_ptr in FastOpt, fixes random segfaults due to thread-non-safty of reference counting. - Make sure that items are 256-byte aligned during saving - Make explicit matmul functions respect setting of cublasMathMode - Fix memory mapping for mixed paramter models - Removed naked pointer and potential memory-leak from file_stream.{cpp,h} - Compilation for GCC >= 7 due to exception thrown in destructor -- Sort parameters by lexicographical order during allocation to ensure consistent +- Sort parameters by lexicographical order during allocation to ensure consistent memory-layout during allocation, loading, saving. -- Output empty line when input is empty line. 
Previous behavior might result in +- Output empty line when input is empty line. Previous behavior might result in hallucinated outputs. - Compilation with CUDA 10.1 @@ -77,7 +79,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Return error signal on SIGTERM - Dropped support for CUDA 8.0, CUDA 9.0 is now minimal requirement - Removed autotuner for now, will be switched back on later -- Boost depdendency is now optional and only required for marian_server +- Boost depdendency is now optional and only required for marian_server - Dropped support for g++-4.9 - Simplified file stream and temporary file handling - Unified node intializers, same function API. diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index acae400fd..fde525600 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -25,6 +25,7 @@ add_library(marian STATIC common/filesystem.cpp common/file_stream.cpp common/signal_handling.cpp + common/file_utils.cpp common/types.cpp data/alignment.cpp diff --git a/src/common/config.cpp b/src/common/config.cpp index a6ce44c4d..ed6b8267a 100644 --- a/src/common/config.cpp +++ b/src/common/config.cpp @@ -49,11 +49,12 @@ void Config::initialize(ConfigParser const& cp) { } // load model parameters + bool loaded = false; if(mode == cli::mode::translation || mode == cli::mode::server) { auto model = get>("models")[0]; try { if(!get("ignore-model-config")) - loadModelParameters(model); + loaded = loadModelParameters(model); } catch(std::runtime_error& ) { LOG(info, "[config] No model configuration found in model file"); } @@ -64,13 +65,42 @@ void Config::initialize(ConfigParser const& cp) { if(filesystem::exists(model) && !get("no-reload")) { try { if(!get("ignore-model-config")) - loadModelParameters(model); + loaded = loadModelParameters(model); } catch(std::runtime_error&) { LOG(info, "[config] No model configuration found in model file"); } } } + // guess --tsv-fields (the number of streams) if not set + if(get("tsv") && 
get("tsv-fields") == 0) { + size_t tsvFields = 0; + if(loaded) { + // model.npz has properly set vocab dimensions in special:model.yml, + // so we may use them to determine the number of streams + for(auto dim : get>("dim-vocabs")) + if(dim != 0) // language models have a fake extra vocab + ++tsvFields; + // For translation there is no target stream + if((mode == cli::mode::translation || mode == cli::mode::server) && tsvFields > 1) + --tsvFields; + } else { + // TODO: This is very brittle, find a better solution + // If parameters from model.npz special:model.yml were not loaded, + // guess the number of inputs and outputs based on the model type name. + auto modelType = get("type"); + + tsvFields = 1; + if(modelType.find("multi-", 0) != std::string::npos) // is a dual-source model + tsvFields += 1; + if(mode == cli::mode::training || mode == cli::mode::scoring) + if(modelType.rfind("lm", 0) != 0) // unless it is a language model + tsvFields += 1; + } + + config_["tsv-fields"] = tsvFields; + } + // echo full configuration log(); @@ -124,16 +154,18 @@ void Config::save(const std::string& name) { out << *this; } -void Config::loadModelParameters(const std::string& name) { +bool Config::loadModelParameters(const std::string& name) { YAML::Node config; io::getYamlFromModel(config, "special:model.yml", name); override(config); + return true; } -void Config::loadModelParameters(const void* ptr) { +bool Config::loadModelParameters(const void* ptr) { YAML::Node config; io::getYamlFromModel(config, "special:model.yml", ptr); override(config); + return true; } void Config::override(const YAML::Node& params) { diff --git a/src/common/config.h b/src/common/config.h index d4784af73..255c50add 100644 --- a/src/common/config.h +++ b/src/common/config.h @@ -77,8 +77,8 @@ class Config { } YAML::Node getModelParameters(); - void loadModelParameters(const std::string& name); - void loadModelParameters(const void* ptr); + bool loadModelParameters(const std::string& name); + bool 
loadModelParameters(const void* ptr); std::vector getDevices(size_t myMPIRank = 0, size_t numRanks = 1); diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index d0155be9a..05a6ccb1d 100755 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -375,6 +375,7 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) { "10000u"); addSuboptionsInputLength(cli); + addSuboptionsTSV(cli); // data management options cli.add("--shuffle", @@ -497,8 +498,10 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) { {"float32", "float32", "float32"}); cli.add>("--cost-scaling", "Dynamic cost scaling for mixed precision training: " - "power of 2, scaling window, scaling factor, tolerance, range, minimum factor")->implicit_val("7.f 2000 2.f 0.05f 10 1.f"); - cli.add("--normalize-gradient", "Normalize gradient by multiplying with no. devices / total labels"); + "power of 2, scaling window, scaling factor, tolerance, range, minimum factor") + ->implicit_val("7.f 2000 2.f 0.05f 10 1.f"); + cli.add("--normalize-gradient", + "Normalize gradient by multiplying with no. 
devices / total labels"); // multi-node training cli.add("--multi-node", @@ -623,8 +626,9 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) { "Keep the output segmented into SentencePiece subwords"); #endif - addSuboptionsDevices(cli); addSuboptionsInputLength(cli); + addSuboptionsTSV(cli); + addSuboptionsDevices(cli); addSuboptionsBatching(cli); cli.add("--optimize", @@ -684,6 +688,7 @@ void ConfigParser::addOptionsScoring(cli::CLIWrapper& cli) { ->implicit_val("1"), addSuboptionsInputLength(cli); + addSuboptionsTSV(cli); addSuboptionsDevices(cli); addSuboptionsBatching(cli); @@ -791,6 +796,15 @@ void ConfigParser::addSuboptionsInputLength(cli::CLIWrapper& cli) { // clang-format on } +void ConfigParser::addSuboptionsTSV(cli::CLIWrapper& cli) { + // clang-format off + cli.add("--tsv", + "Tab-separated input"); + cli.add("--tsv-fields", + "Number of fields in the TSV input, guessed based on the model type"); + // clang-format on +} + void ConfigParser::addSuboptionsULR(cli::CLIWrapper& cli) { // clang-format off // support for universal encoder ULR https://arxiv.org/pdf/1802.05368.pdf @@ -865,6 +879,21 @@ Ptr ConfigParser::parseOptions(int argc, char** argv, bool doValidate){ cli::processPaths(config_, cli::InterpolateEnvVars, PATHS); } + // Option shortcuts for input from STDIN for trainer and scorer + if(mode_ == cli::mode::training || mode_ == cli::mode::scoring) { + auto trainSets = get>("train-sets"); + YAML::Node config; + // Assume the input will come from STDIN if --tsv is set but no --train-sets are given + if(get("tsv") && trainSets.empty()) { + config["train-sets"].push_back("stdin"); + // Assume the input is in TSV format if --train-sets is set to "stdin" + } else if(trainSets.size() == 1 && (trainSets[0] == "stdin" || trainSets[0] == "-")) { + config["tsv"] = true; + } + if(!config.IsNull()) + cli_.updateConfig(config, cli::OptionPriority::CommandLine, "A shortcut for STDIN failed."); + } + if(doValidate) { 
ConfigValidator(config_).validateOptions(mode_); } diff --git a/src/common/config_parser.h b/src/common/config_parser.h index 652a1d249..798ec6227 100644 --- a/src/common/config_parser.h +++ b/src/common/config_parser.h @@ -135,6 +135,7 @@ class ConfigParser { void addSuboptionsDevices(cli::CLIWrapper&); void addSuboptionsBatching(cli::CLIWrapper&); void addSuboptionsInputLength(cli::CLIWrapper&); + void addSuboptionsTSV(cli::CLIWrapper&); void addSuboptionsULR(cli::CLIWrapper&); // Extract paths to all config files found in the config object. diff --git a/src/common/config_validator.cpp b/src/common/config_validator.cpp index 609172489..cf46f7381 100644 --- a/src/common/config_validator.cpp +++ b/src/common/config_validator.cpp @@ -70,9 +70,20 @@ void ConfigValidator::validateOptionsParallelData() const { auto trainSets = get>("train-sets"); ABORT_IF(trainSets.empty(), "No train sets given in config file or on command line"); - auto vocabs = get>("vocabs"); - ABORT_IF(!vocabs.empty() && vocabs.size() != trainSets.size(), - "There should be as many vocabularies as training sets"); + auto numVocabs = get>("vocabs").size(); + ABORT_IF(!get("tsv") && numVocabs > 0 && numVocabs != trainSets.size(), + "There should be as many vocabularies as training files"); + + // disallow, for example --tsv --train-sets file1.tsv file2.tsv + ABORT_IF(get("tsv") && trainSets.size() != 1, + "A single file must be provided with --train-sets (or stdin) for a tab-separated input"); + + // disallow, for example --train-sets stdin stdin or --train-sets stdin file.tsv + ABORT_IF(trainSets.size() > 1 + && std::any_of(trainSets.begin(), + trainSets.end(), + [](const std::string& s) { return (s == "stdin") || (s == "-"); }), + "Only one 'stdin' or '-' in --train-sets is allowed"); } void ConfigValidator::validateOptionsScoring() const { @@ -94,7 +105,7 @@ void ConfigValidator::validateOptionsTraining() const { ABORT_IF(has("embedding-vectors") && get>("embedding-vectors").size() != 
trainSets.size() && !get>("embedding-vectors").empty(), - "There should be as many embedding vector files as training sets"); + "There should be as many embedding vector files as training files"); filesystem::Path modelPath(get("model")); @@ -105,10 +116,14 @@ void ConfigValidator::validateOptionsTraining() const { ABORT_IF(!modelDir.empty() && !filesystem::isDirectory(modelDir), "Model directory does not exist"); + std::string errorMsg = "There should be as many validation files as training files"; + if(get("tsv")) + errorMsg += ". If the training set is in the TSV format, validation sets have to also be a single TSV file"; + ABORT_IF(has("valid-sets") && get>("valid-sets").size() != trainSets.size() && !get>("valid-sets").empty(), - "There should be as many validation sets as training sets"); + errorMsg); // validations for learning rate decaying ABORT_IF(get("lr-decay") > 1.f, "Learning rate decay factor greater than 1.0 is unusual"); diff --git a/src/common/file_stream.cpp b/src/common/file_stream.cpp index 815048304..8717de4cc 100755 --- a/src/common/file_stream.cpp +++ b/src/common/file_stream.cpp @@ -22,7 +22,7 @@ InputFileStream::InputFileStream(const std::string &file) ABORT_IF(!marian::filesystem::exists(file_), "File '{}' does not exist", file); streamBuf1_.reset(new std::filebuf()); - auto ret = static_cast(streamBuf1_.get())->open(file.c_str(), std::ios::in | std::ios::binary); + auto ret = static_cast(streamBuf1_.get())->open(file.c_str(), std::ios::in | std::ios::binary); ABORT_IF(!ret, "File cannot be opened", file); ABORT_IF(ret != streamBuf1_.get(), "Return value is not equal to streambuf pointer, that is weird"); @@ -84,6 +84,10 @@ OutputFileStream::~OutputFileStream() { this->flush(); } +std::string OutputFileStream::getFileName() const { + return file_.string(); +} + /////////////////////////////////////////////////////////////////////////////////////////////// TemporaryFile::TemporaryFile(const std::string &base, bool earlyUnlink) : 
OutputFileStream(), unlink_(earlyUnlink) { diff --git a/src/common/file_stream.h b/src/common/file_stream.h index 9bbf33599..8c4588342 100644 --- a/src/common/file_stream.h +++ b/src/common/file_stream.h @@ -62,6 +62,8 @@ class OutputFileStream : public std::ostream { explicit OutputFileStream(const std::string& file); virtual ~OutputFileStream(); + std::string getFileName() const; + template size_t write(const T* ptr, size_t num = 1) { std::ostream::write((char*)ptr, num * sizeof(T)); diff --git a/src/common/file_utils.cpp b/src/common/file_utils.cpp new file mode 100644 index 000000000..0ee844262 --- /dev/null +++ b/src/common/file_utils.cpp @@ -0,0 +1,28 @@ +#include "common/file_utils.h" +#include "common/utils.h" + +namespace marian { +namespace fileutils { + +void cut(const std::string& tsvIn, + Ptr tsvOut, + const std::vector& fields, + size_t numFields, + const std::string& sep /*= "\t"*/) { + std::vector tsvFields(numFields); + std::string line; + io::InputFileStream ioIn(tsvIn); + while(getline(ioIn, line)) { + tsvFields.clear(); + utils::splitTsv(line, tsvFields, numFields); // split tab-separated fields + for(size_t i = 0; i < fields.size(); ++i) { + *tsvOut << tsvFields[fields[i]]; + if(i < fields.size() - 1) + *tsvOut << sep; // concatenating fields with the custom separator + } + *tsvOut << std::endl; + } +}; + +} // namespace fileutils +} // namespace marian diff --git a/src/common/file_utils.h b/src/common/file_utils.h new file mode 100644 index 000000000..d8ab407d2 --- /dev/null +++ b/src/common/file_utils.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + +#include "common/file_stream.h" + +namespace marian { +namespace fileutils { + +void cut(const std::string& tsvIn, + Ptr tsvOut, + const std::vector& fields, + size_t numFields, + const std::string& sep = "\t"); + +} // namespace utils +} // namespace marian diff --git a/src/common/filesystem.h b/src/common/filesystem.h index d7cb3da68..05315c332 100644 --- a/src/common/filesystem.h +++ 
b/src/common/filesystem.h @@ -115,5 +115,5 @@ namespace filesystem { using FilesystemError = Pathie::PathieError; -} -} +} // namespace filesystem +} // namespace marian diff --git a/src/common/utils.cpp b/src/common/utils.cpp index 3acb756d1..aded13c5d 100755 --- a/src/common/utils.cpp +++ b/src/common/utils.cpp @@ -67,6 +67,28 @@ void split(const std::string& line, } } +// the function guarantees that the output has as many elements as requested +void splitTsv(const std::string& line, std::vector& fields, size_t numFields) { + fields.clear(); + + size_t begin = 0; + size_t pos = 0; + for(size_t i = 0; i < numFields; ++i) { + pos = line.find('\t', begin); + if(pos == std::string::npos) { + fields.push_back(line.substr(begin)); + break; + } + fields.push_back(line.substr(begin, pos - begin)); + begin = pos + 1; + } + + if(fields.size() < numFields) // make sure there is as many elements as requested + fields.resize(numFields); + + ABORT_IF(pos != std::string::npos, "Excessive field(s) in the tab-separated line: '{}'", line); +} + std::vector split(const std::string& line, const std::string& del /*= " "*/, bool keepEmpty /*= false*/, @@ -103,6 +125,12 @@ std::string join(const std::vector& words, const std::string& del / return ss.str(); } +std::string join(const std::vector& nums, const std::string& del /*= " "*/) { + std::vector words(nums.size()); + std::transform(nums.begin(), nums.end(), words.begin(), [](int i) { return std::to_string(i); }); + return join(words, del); +} + // escapes a string for passing to popen, which uses /bin/sh to parse its argument string static std::string escapeForPOpen(const std::string& arg) { // e.g. 
abc -> 'abc'; my file.txt -> 'my file.txt'; $10 -> '$10'; it's -> 'it'\''s' diff --git a/src/common/utils.h b/src/common/utils.h index c3266bbf4..d576214bd 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -20,6 +20,9 @@ void splitAny(const std::string& line, const std::string& del = " ", bool keepEmpty = false); +// Split tab-separated line into the specified number of fields +void splitTsv(const std::string& line, std::vector& fields, size_t numFields); + std::vector split(const std::string& line, const std::string& del = " ", bool keepEmpty = false, @@ -29,6 +32,7 @@ std::vector splitAny(const std::string& line, bool keepEmpty = false); std::string join(const std::vector& words, const std::string& del = " "); +std::string join(const std::vector& words, const std::string& del = " "); std::string exec(const std::string& cmd, const std::vector& args = {}, const std::string& arg = ""); diff --git a/src/data/corpus.cpp b/src/data/corpus.cpp index aea232778..9490d0323 100755 --- a/src/data/corpus.cpp +++ b/src/data/corpus.cpp @@ -46,6 +46,8 @@ void Corpus::preprocessLine(std::string& line, size_t streamId) { } SentenceTuple Corpus::next() { + std::vector fields(tsvNumFields_); // used for handling TSV inputs + for(;;) { // (this is a retry loop for skipping invalid sentences) // get index of the current sentence size_t curId = pos_; // note: at end, pos_ == total size @@ -78,13 +80,21 @@ SentenceTuple Corpus::next() { } } - if(i > 0 && i == alignFileIdx_) { // @TODO: alignFileIdx == 0 possible? 
+ if(i > 0 && i == alignFileIdx_) { addAlignmentToSentenceTuple(line, tup); } else if(i > 0 && i == weightFileIdx_) { addWeightsToSentenceTuple(line, tup); } else { - preprocessLine(line, i); - addWordsToSentenceTuple(line, i, tup); + if(tsv_) { // split TSV input and add each field into the sentence tuple + utils::splitTsv(line, fields, tsvNumFields_); + for(size_t j = 0; j < tsvNumFields_; ++j) { + preprocessLine(fields[j], j); + addWordsToSentenceTuple(fields[j], j, tup); + } + } else { + preprocessLine(line, i); + addWordsToSentenceTuple(line, i, tup); + } } } @@ -112,7 +122,8 @@ void Corpus::shuffle() { // reset to regular, non-shuffled reading // Call either reset() or shuffle(). -// @TODO: make shuffle() private, instad pass a shuffle() flag to reset(), to clarify mutual exclusiveness with shuffle() +// @TODO: make shuffle() private, instead pass a shuffle() flag to reset(), to clarify mutual +// exclusiveness with shuffle() void Corpus::reset() { corpusInRAM_.clear(); ids_.clear(); @@ -120,7 +131,7 @@ void Corpus::reset() { return; pos_ = 0; for (size_t i = 0; i < paths_.size(); ++i) { - if(paths_[i] == "stdin") { + if(paths_[i] == "stdin" || paths_[i] == "-") { files_[i].reset(new std::istream(std::cin.rdbuf())); // Probably not necessary, unless there are some buffers // that we want flushed. @@ -143,6 +154,10 @@ void Corpus::restore(Ptr ts) { void Corpus::shuffleData(const std::vector& paths) { LOG(info, "[data] Shuffling data"); + ABORT_IF(tsv_ && (paths[0] == "stdin" || paths[0] == "-"), + "Shuffling training data from STDIN is not supported. 
Add --no-shuffle or provide " + "training sets with --train-sets"); + size_t numStreams = paths.size(); size_t numSentences; diff --git a/src/data/corpus_base.cpp b/src/data/corpus_base.cpp index bb5fe735c..5eac4668d 100755 --- a/src/data/corpus_base.cpp +++ b/src/data/corpus_base.cpp @@ -1,5 +1,6 @@ #include +#include "common/file_utils.h" #include "data/corpus.h" #include "data/factored_vocab.h" @@ -36,9 +37,17 @@ CorpusBase::CorpusBase(const std::vector& paths, vocabs_(vocabs), maxLength_(options_->get("max-length")), maxLengthCrop_(options_->get("max-length-crop")), - rightLeft_(options_->get("right-left")) { - ABORT_IF(paths_.size() != vocabs_.size(), - "Number of corpus files and vocab files does not agree"); + rightLeft_(options_->get("right-left")), + tsv_(options_->get("tsv", false)), + tsvNumFields_(options->get("tsv-fields", 0)) { + // TODO: support passing only one vocab file if we have fully-tied embeddings + if(tsv_) { + ABORT_IF(tsvNumFields_ != vocabs_.size(), + "Number of TSV fields and vocab files does not agree"); + } else { + ABORT_IF(paths_.size() != vocabs_.size(), + "Number of corpus files and vocab files does not agree"); + } for(auto path : paths_) { UPtr strm(new io::InputFileStream(path)); @@ -53,7 +62,9 @@ CorpusBase::CorpusBase(Ptr options, bool translate) : DatasetBase(options), maxLength_(options_->get("max-length")), maxLengthCrop_(options_->get("max-length-crop")), - rightLeft_(options_->get("right-left")) { + rightLeft_(options_->get("right-left")), + tsv_(options_->get("tsv", false)), + tsvNumFields_(options->get("tsv-fields", 0)) { bool training = !translate; if(training) @@ -68,8 +79,13 @@ CorpusBase::CorpusBase(Ptr options, bool translate) vocabPaths = options_->get>("vocabs"); if(training) { - ABORT_IF(!vocabPaths.empty() && paths_.size() != vocabPaths.size(), - "Number of corpus files and vocab files does not agree"); + if(tsv_) { + ABORT_IF(!vocabPaths.empty() && tsvNumFields_ != vocabPaths.size(), + "Number of TSV fields 
and vocab files does not agree"); + } else { + ABORT_IF(!vocabPaths.empty() && paths_.size() != vocabPaths.size(), + "Number of corpus files and vocab files does not agree"); + } } // @TODO: check if size_t can be used instead of int @@ -77,60 +93,157 @@ CorpusBase::CorpusBase(Ptr options, bool translate) // training or scoring if(training) { + // Marian can create vocabularies automatically if no vocabularies are given or they do not + // exists under the specified paths. + // + // Possible cases: + // * -t train1 train2 -v vocab1 vocab2 + // If vocab1 or vocab2 exists, they are loaded, otherwise separate .yml vocabularies are + // created only from train1 or train2 respectively. + // + // * -t train1 train2 -v vocab vocab + // If vocab exists, it is loaded, otherwise it is created from concatenated train1 and train2 + // files. + // + // * -t train1 train2 + // If no path is given, separate vocabularies train1.yml and train2.yml are created from + // train1 and train2 respectively. + // + // * --tsv -t train.tsv -v vocab1 vocab2 + // If vocab1 or vocab2 exists, it is loaded; otherwise each vocabulary is created from the + // appropriate fields in train.tsv. + // + // * --tsv -t train.tsv -v vocab vocab + // If vocab exist, it is loaded; otherwise it is created from all fields in train.tsv. + // + // * --tsv -t train.tsv + // If no path is given, a train.tsv.yml is created from all fields in train.tsv. + // + // * cat file.tsv | --tsv -t stdin -v vocab1 vocab2 + // If either vocab1 or vocab2 does not exist, an error is shown that creation of vocabularies + // from stdin is not supported. + // + // * cat file.tsv | --tsv -t stdin -v vocab vocab + // If vocab does not exist, an error is shown that creation of a vocabulary from stdin is not + // supported. + // + // * cat file.tsv | --tsv -t stdin + // As above, an error is shown that creation of a vocabulary from stdin is not supported. + // + // There is more cases for multi-encoder models not listed above. 
+ // if(vocabPaths.empty()) { + size_t numStreams = tsv_ ? tsvNumFields_ : paths_.size(); + + // Creating a vocabulary from stdin is not supported + ABORT_IF(tsv_ && (paths_[0] == "stdin" || paths_[0] == "-"), + "Creating vocabularies automatically from a data stream from STDIN is not supported. " + "Create vocabularies first and provide them with --vocabs"); + if(maxVocabs.size() < paths_.size()) maxVocabs.resize(paths_.size(), 0); - LOG(info, "No vocabulary files given, trying to find or build based on training data. " - "Vocabularies will be built separately for each file."); + LOG(info, + "[data] No vocabulary files given, trying to find or build based on training data."); + if(!tsv_) + LOG(info, "[data] Vocabularies will be built separately for each file."); + else + LOG(info, "[data] A joint vocabulary will be built from the TSV file."); + + std::vector vocabDims(numStreams, 0); + std::vector vocabPaths1(numStreams); - std::vector vocabDims(paths_.size(), 0); - std::vector vocabPaths1(paths_.size()); // Create vocabs if not provided - for(size_t i = 0; i < paths_.size(); ++i) { + for(size_t i = 0; i < numStreams; ++i) { Ptr vocab = New(options_, i); - std::vector trainPaths = { paths_[i] }; + + const auto& path = paths_[tsv_ ? 0 : i]; // idx 0 because there is always only one TSV file + std::vector trainPaths = {path}; + vocabPaths1[i] = path + ".yml"; + vocabDims[i] = (int) vocab->loadOrCreate("", trainPaths, maxVocabs[i]); - vocabPaths1[i] = paths_[i] + ".yml"; vocabs_.emplace_back(vocab); } // TODO: this is not nice as it modifies the option object and needs to expose the changes // outside the corpus as models need to know about the vocabulary size; extract the vocab // creation functionality from the class. options_->set("dim-vocabs", vocabDims, "vocabs", vocabPaths1); - } else { + + } else { // Vocabulary paths are given + size_t numStreams = tsv_ ? 
tsvNumFields_ : paths_.size(); + // Load all vocabs size_t numVocs = vocabPaths.size(); if(maxVocabs.size() < numVocs) - maxVocabs.resize(paths_.size(), 0); + maxVocabs.resize(numStreams, 0); - // Helper object to for grouping training data based on vocabulary file name - struct PathsAndSize { - std::set paths; // contains all paths that are used for training the vocabulary - size_t size; // contains the maximum vocabulary size + // Helper object for grouping training data based on vocabulary file name + struct VocabDetails { + std::set paths; // all paths that are used for training the vocabulary + std::vector streams; // index of the vocabulary in the --vocab option + size_t size; // the maximum vocabulary size }; // Group training files based on vocabulary path. If the same // vocab path corresponds to different training files, this means // that a single vocab should combine tokens from all files. - std::map groupVocab; + std::map groupVocab; // vocabPath -> (trainPaths[], vocabSize) for(size_t i = 0; i < numVocs; ++i) { - groupVocab[vocabPaths[i]].paths.insert(paths_[i]); + // Index 0 because there is always only a single TSV input file + groupVocab[vocabPaths[i]].paths.insert(paths_[tsv_ ? 0 : i]); + groupVocab[vocabPaths[i]].streams.push_back(i); if(groupVocab[vocabPaths[i]].size < maxVocabs[i]) groupVocab[vocabPaths[i]].size = maxVocabs[i]; } auto vocabDims = options_->get>("dim-vocabs"); - vocabDims.resize(numVocs, 0); + vocabDims.resize(numVocs, 0); // make sure there is as many dims as vocab paths + for(size_t i = 0; i < numVocs; ++i) { - Ptr vocab = New(options_, i); + // Creating a vocabulary from stdin is not supported + ABORT_IF(tsv_ && (paths_[0] == "stdin" || paths_[0] == "-") + && (vocabPaths[i].empty() || !filesystem::exists(vocabPaths[i])), + "Creating vocabulary automatically from a data stream from STDIN is not supported. " + "Create vocabularies first and provide them with --vocabs"); // Get the set of files that corresponds to the vocab. 
If the next file is the same vocab, - // it wild not be created again, but just correctly loaded. - auto pathsAndSize = groupVocab[vocabPaths[i]]; - std::vector groupedPaths(pathsAndSize.paths.begin(), pathsAndSize.paths.end()); - vocabDims[i] = (int) vocab->loadOrCreate(vocabPaths[i], groupedPaths, pathsAndSize.size); + // it will not be created again, but just correctly loaded. + auto vocabDetails = groupVocab[vocabPaths[i]]; + std::vector groupedPaths(vocabDetails.paths.begin(), vocabDetails.paths.end()); + Ptr tsvTempFile; // temporary handler for cut fields from TSV input + + // For a TSV input, multiple vocabularies with different names mean separate + // vocabularies for source(s) and target. + // If a vocabulary does not exist, it will be created in the next step. To be able to create + // a separate vocabulary, we cut tab-separated field(s) from the TSV file, e.g. all source + // or target sentences, into a temporary file. + if(tsv_ && groupVocab.size() > 1 && !filesystem::exists(vocabPaths[i])) { + ABORT_IF(groupedPaths.size() > 1, "There should not be multiple TSV input files!"); + + tsvTempFile.reset(new io::TemporaryFile(options_->get("tempdir"), false)); + LOG(info, + "[data] Cutting field(s) {} from {} into a temporary file {}", + utils::join(vocabDetails.streams, ", "), + groupedPaths[0], + tsvTempFile->getFileName()); + + fileutils::cut(groupedPaths[0], // Index 0 because there is only one TSV file + tsvTempFile, + vocabDetails.streams, + tsvNumFields_, + " "); // Notice that tab-separated fields are joined with a whitespace + + groupedPaths.clear(); + groupedPaths.push_back(tsvTempFile->getFileName()); + } + + // Load or create the vocabulary + Ptr vocab = New(options_, i); + vocabDims[i] = (int) vocab->loadOrCreate(vocabPaths[i], groupedPaths, vocabDetails.size); vocabs_.emplace_back(vocab); + + if(tsvTempFile) + tsvTempFile.reset(); } // TODO: this is not nice as it modifies the option object and needs to expose the changes // outside the corpus 
as models need to know about the vocabulary size; extract the vocab @@ -140,8 +253,7 @@ CorpusBase::CorpusBase(Ptr options, bool translate) } if(translate) { - ABORT_IF(vocabPaths.empty(), - "Translating, but vocabularies are not given!"); + ABORT_IF(vocabPaths.empty(), "Translating, but vocabularies are not given!"); size_t numVocs = vocabPaths.size(); if(maxVocabs.size() < numVocs) @@ -161,7 +273,7 @@ CorpusBase::CorpusBase(Ptr options, bool translate) } for(auto path : paths_) { - if(path == "stdin") + if(path == "stdin" || path == "-") files_.emplace_back(new std::istream(std::cin.rdbuf())); else { io::InputFileStream *strm = new io::InputFileStream(path); @@ -170,7 +282,7 @@ CorpusBase::CorpusBase(Ptr options, bool translate) } } - ABORT_IF(vocabs_.size() != files_.size(), + ABORT_IF(!tsv_ && vocabs_.size() != files_.size(), "Number of {} files ({}) and vocab files ({}) does not agree", training ? "corpus" : "input", files_.size(), @@ -206,7 +318,6 @@ CorpusBase::CorpusBase(Ptr options, bool translate) void CorpusBase::addWordsToSentenceTuple(const std::string& line, size_t batchIndex, SentenceTuple& tup) const { - // This turns a string in to a sequence of numerical word ids. Depending // on the vocabulary type, this can be non-trivial, e.g. when SentencePiece // is used. @@ -298,26 +409,28 @@ void CorpusBase::addWeightsToBatch(Ptr batch, void CorpusBase::initEOS(bool training = true) { // Labels fed into sub-batches that are just class-labels, not sequence labels do not require to - // add a EOS symbol. Hence decision to add EOS is now based on input stream positions and correspoding - // input type. + // add a EOS symbol. Hence decision to add EOS is now based on input stream positions and + // correspoding input type. + + size_t numStreams = tsv_ ? 
tsvNumFields_ : paths_.size(); // determine number of streams - addEOS_.resize(paths_.size(), true); + addEOS_.resize(numStreams, true); // @TODO: think if this should be checked and processed here or in a validation step in config? auto inputTypes = options_->get>("input-types", {}); // empty list by default - // make sure there is an input type for each path - ABORT_IF(inputTypes.size() > 0 && inputTypes.size() < paths_.size(), + // make sure there is an input type for each stream + ABORT_IF(inputTypes.size() > 0 && inputTypes.size() < numStreams, "Input types have been specified ({}), you need to specify one per input ({})", inputTypes.size(), - paths_.size()); + numStreams); - // make sure there is an equal number of input types and paths when training - ABORT_IF(training && inputTypes.size() > 0 && inputTypes.size() != paths_.size(), + // make sure there is an equal number of input types and streams when training + ABORT_IF(training && inputTypes.size() > 0 && inputTypes.size() != numStreams, "Input types have been specified ({}), you need to specify one per input ({})", inputTypes.size(), - paths_.size()); + numStreams); - for(int i = 0; i < paths_.size(); ++i) + for(int i = 0; i < numStreams; ++i) if(inputTypes.size() > i) { if(inputTypes[i] == "class") addEOS_[i] = false; diff --git a/src/data/corpus_base.h b/src/data/corpus_base.h index e85e378ab..5efd2211b 100755 --- a/src/data/corpus_base.h +++ b/src/data/corpus_base.h @@ -533,7 +533,7 @@ class CorpusBase std::vector> vocabs_; /** - * brief Determines if a EOS symbol should be added. By default this is true for any sequence, + * @brief Determines if a EOS symbol should be added. By default this is true for any sequence, * but should be false for instance for classifier labels. This is set per input stream, hence a * vector. 
*/ @@ -545,6 +545,9 @@ class CorpusBase bool maxLengthCrop_{false}; bool rightLeft_{false}; + bool tsv_{false}; // true if the input is a single file with tab-separated values + size_t tsvNumFields_{0}; // number of fields in the TSV input (only if tsv_) + /** * @brief Index of the file with weights in paths_ and files_; zero means no * weights file provided. diff --git a/src/data/default_vocab.cpp b/src/data/default_vocab.cpp index 590e9931b..30bf219fe 100644 --- a/src/data/default_vocab.cpp +++ b/src/data/default_vocab.cpp @@ -217,7 +217,6 @@ class DefaultVocab : public IVocab { std::string line; while(getline(*trainStrm, line)) { auto toks = utils::split(line, " "); - for(const std::string& tok : toks) { auto iter = counter.find(tok); if(iter == counter.end()) diff --git a/src/tests/units/CMakeLists.txt b/src/tests/units/CMakeLists.txt index 3814b4810..654355b7f 100644 --- a/src/tests/units/CMakeLists.txt +++ b/src/tests/units/CMakeLists.txt @@ -5,6 +5,7 @@ set(UNIT_TESTS rnn_tests attention_tests fastopt_tests + utils_tests ) foreach(test ${UNIT_TESTS}) diff --git a/src/tests/units/utils_tests.cpp b/src/tests/units/utils_tests.cpp new file mode 100644 index 000000000..9b21f511c --- /dev/null +++ b/src/tests/units/utils_tests.cpp @@ -0,0 +1,36 @@ +#include "catch.hpp" +#include "common/utils.h" + +using namespace marian; + +TEST_CASE("utils::splitTsv", "[utils]") { + std::string line1 = "foo bar"; + std::string line2 = "foo bar\tbazz"; + std::string line3 = "foo bar\tbazz\tfoo quux"; + + std::vector fields; + + SECTION("the tab-separated input is split") { + utils::splitTsv(line1, fields, 1); + CHECK( fields.size() == 1 ); + CHECK( fields[0] == "foo bar" ); + + utils::splitTsv(line3, fields, 3); + CHECK( fields == std::vector({"foo bar", "bazz", "foo quux"}) ); + } + + SECTION("the output has at least as many elements as requested") { + utils::splitTsv(line1, fields, 1); + CHECK( fields.size() == 1 ); + + utils::splitTsv(line1, fields, 3); + CHECK( fields.size() 
== 3 ); + CHECK( fields == std::vector({"foo bar", "", ""}) ); + + utils::splitTsv(line1, fields, 2); + CHECK( fields.size() == 2 ); + CHECK( fields == std::vector({"foo bar", ""}) ); + } + + //SECTION("excessive tab-separated fields abort the execution") {} +} diff --git a/src/training/graph_group.h b/src/training/graph_group.h index 012f78ef9..83873edab 100644 --- a/src/training/graph_group.h +++ b/src/training/graph_group.h @@ -55,7 +55,88 @@ class GraphGroup { Ptr collectStats(Ptr graph, Ptr model, const std::vector>& vocabs, - double multiplier = 1.); + double multiplier = 1.) { + auto stats = New(); + + size_t numFiles = options_->get("tsv", false) + ? options_->get("tsv-fields") + : options_->get>("train-sets").size(); + + // Initialize first batch to step size + size_t first = options_->get("mini-batch-fit-step"); + + // Increase batch size and sentence length by this step size + size_t step = options_->get("mini-batch-fit-step"); + + size_t maxLength = options_->get("max-length"); + maxLength = (size_t)(std::ceil(maxLength / (float)step) * step); + + // this should be only one class label per line on input, hence restricting length to 1 + std::vector localMaxes(numFiles, maxLength); + auto inputTypes = options_->get>("input-types", {}); + for(int i = 0; i < inputTypes.size(); ++i) + if(inputTypes[i] == "class") + localMaxes[i] = 1; + + size_t maxBatch = 512; + bool fits = true; + while(fits) { + std::vector lengths(numFiles, first); + for(int j = 0; j < lengths.size(); ++j) // apply length restrictions + lengths[j] = std::min(lengths[j], localMaxes[j]); + + auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, maxBatch, options_); + auto cost = model->build(graph, batch); + fits = graph->fits(); + if(fits) + maxBatch *= 2; + } + + // Do a binary search for maxmimum batch size that fits into given workspace memory + // for a tested sentence length. 
+ for(size_t i = step; i <= maxLength; i += step) { + size_t start = 1; + size_t end = maxBatch; + + std::vector lengths(numFiles, i); + for(int j = 0; j < lengths.size(); ++j) // apply length restrictions + lengths[j] = std::min(lengths[j], localMaxes[j]); + fits = true; + + do { + size_t current = (start + end) / 2; + auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, current, options_); + auto cost = model->build(graph, batch); + fits = graph->fits(); + + LOG(debug, "[batching] length: {} - size: {} - fits: {}", lengths[0], current, fits); + + if(fits) { + stats->add(batch, multiplier); + start = current + 1; + } else { + end = current - 1; + } + } while(end - start > step); + + maxBatch = start; + } + return stats; + } + + void setTypicalTrgBatchWords(size_t typicalTrgBatchWords) { // needed for dynamic MB scaling + typicalTrgBatchWords_ = typicalTrgBatchWords; + } +}; + +/** + * Base class for multi-node versions of GraphGroups. + */ +class MultiNodeGraphGroupBase : public GraphGroup { + using Base = GraphGroup; + +protected: + Ptr mpi_; // all MPI-like communication goes through this void setTypicalTrgBatchWords(size_t typicalTrgBatchWords); }; diff --git a/src/training/scheduler.h b/src/training/scheduler.h index 651c34b31..2ec9f1ab0 100755 --- a/src/training/scheduler.h +++ b/src/training/scheduler.h @@ -15,11 +15,16 @@ class Scheduler : public TrainingObserver { Ptr state_; std::vector> validators_; - bool first_{true}; + bool first_{true}; // true if this is the first update after renewing the training timer::Timer timer_; timer::Timer heartBeatTimer_; + // The variable helps to keep track of the end of the current epoch + // (regardless if it's the 1st or nth epoch and if it's a new or continued training), + // which indicates the end of the training data stream from STDIN + bool endOfStdin_{false}; // true at the end of the epoch if training from STDIN; + // determine scheduled LR decay factor (--lr-decay-inv-sqrt option) float 
getScheduledLRDecayFactor(const TrainingState& state) const { auto args = options_->get>("lr-decay-inv-sqrt"); @@ -150,7 +155,6 @@ class Scheduler : public TrainingObserver { } bool keepGoing() { - if(getSignalFlag(SIGTERM)) // received signal SIGERM => exit gracefully return false; @@ -170,6 +174,10 @@ class Scheduler : public TrainingObserver { && stalled() >= stopAfterStalled) return false; + // stop if data streaming from STDIN is stopped + if(endOfStdin_) + return false; + return true; } @@ -402,6 +410,11 @@ class Scheduler : public TrainingObserver { } void actAfterEpoch(TrainingState& state) override { + // stop if data streaming from STDIN is stopped for a TSV input + std::string firstPath = options_->get>("train-sets")[0]; + if(options_->get("tsv", false) && (firstPath == "stdin" || firstPath == "-")) + endOfStdin_ = true; + float factor = options_->get("lr-decay"); updateLearningRate(state); diff --git a/src/training/training_state.h b/src/training/training_state.h index c1ddfc837..f8b9e5632 100644 --- a/src/training/training_state.h +++ b/src/training/training_state.h @@ -14,7 +14,7 @@ class TrainingState; class TrainingObserver { public: virtual ~TrainingObserver() {} - + virtual void init(TrainingState&) {} virtual void actAfterEpoch(TrainingState&) {} virtual void actAfterBatches(TrainingState&) {} From c95676e081da4d488560b38fb0d01bba47272b66 Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Fri, 10 Apr 2020 13:50:22 -0700 Subject: [PATCH 09/62] bump version --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index b95e90dc7..c4e620172 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -v1.9.2 +v1.9.4 From 71cc43a2ff19f15f9cfe5260536881321aba036c Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Fri, 10 Apr 2020 13:53:21 -0700 Subject: [PATCH 10/62] actually save the merge file --- src/training/graph_group.h | 83 +------------------------------------- 1 file changed, 1 insertion(+), 82 
deletions(-) diff --git a/src/training/graph_group.h b/src/training/graph_group.h index 83873edab..012f78ef9 100644 --- a/src/training/graph_group.h +++ b/src/training/graph_group.h @@ -55,88 +55,7 @@ class GraphGroup { Ptr collectStats(Ptr graph, Ptr model, const std::vector>& vocabs, - double multiplier = 1.) { - auto stats = New(); - - size_t numFiles = options_->get("tsv", false) - ? options_->get("tsv-fields") - : options_->get>("train-sets").size(); - - // Initialize first batch to step size - size_t first = options_->get("mini-batch-fit-step"); - - // Increase batch size and sentence length by this step size - size_t step = options_->get("mini-batch-fit-step"); - - size_t maxLength = options_->get("max-length"); - maxLength = (size_t)(std::ceil(maxLength / (float)step) * step); - - // this should be only one class label per line on input, hence restricting length to 1 - std::vector localMaxes(numFiles, maxLength); - auto inputTypes = options_->get>("input-types", {}); - for(int i = 0; i < inputTypes.size(); ++i) - if(inputTypes[i] == "class") - localMaxes[i] = 1; - - size_t maxBatch = 512; - bool fits = true; - while(fits) { - std::vector lengths(numFiles, first); - for(int j = 0; j < lengths.size(); ++j) // apply length restrictions - lengths[j] = std::min(lengths[j], localMaxes[j]); - - auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, maxBatch, options_); - auto cost = model->build(graph, batch); - fits = graph->fits(); - if(fits) - maxBatch *= 2; - } - - // Do a binary search for maxmimum batch size that fits into given workspace memory - // for a tested sentence length. 
- for(size_t i = step; i <= maxLength; i += step) { - size_t start = 1; - size_t end = maxBatch; - - std::vector lengths(numFiles, i); - for(int j = 0; j < lengths.size(); ++j) // apply length restrictions - lengths[j] = std::min(lengths[j], localMaxes[j]); - fits = true; - - do { - size_t current = (start + end) / 2; - auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, current, options_); - auto cost = model->build(graph, batch); - fits = graph->fits(); - - LOG(debug, "[batching] length: {} - size: {} - fits: {}", lengths[0], current, fits); - - if(fits) { - stats->add(batch, multiplier); - start = current + 1; - } else { - end = current - 1; - } - } while(end - start > step); - - maxBatch = start; - } - return stats; - } - - void setTypicalTrgBatchWords(size_t typicalTrgBatchWords) { // needed for dynamic MB scaling - typicalTrgBatchWords_ = typicalTrgBatchWords; - } -}; - -/** - * Base class for multi-node versions of GraphGroups. - */ -class MultiNodeGraphGroupBase : public GraphGroup { - using Base = GraphGroup; - -protected: - Ptr mpi_; // all MPI-like communication goes through this + double multiplier = 1.); void setTypicalTrgBatchWords(size_t typicalTrgBatchWords); }; From 09904e0f023c7b4c7334655dd7990e60a5c140f7 Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Fri, 10 Apr 2020 15:27:34 -0700 Subject: [PATCH 11/62] use float values for catch::Approx --- src/tests/units/attention_tests.cpp | 2 +- src/tests/units/operator_tests.cpp | 4 ++-- src/tests/units/rnn_tests.cpp | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tests/units/attention_tests.cpp b/src/tests/units/attention_tests.cpp index e13e7943d..4fbed7b52 100644 --- a/src/tests/units/attention_tests.cpp +++ b/src/tests/units/attention_tests.cpp @@ -23,7 +23,7 @@ void tests(DeviceType type, Type floatType = Type::float32) { } #endif - auto floatApprox = [](T x, T y) { return x == Approx(y).epsilon(0.01); }; + auto floatApprox = [](T x, T y) { return x == 
Approx(y).epsilon(0.01f).scale(1.f); }; Config::seed = 1234; diff --git a/src/tests/units/operator_tests.cpp b/src/tests/units/operator_tests.cpp index 682ef4805..8403a84ea 100644 --- a/src/tests/units/operator_tests.cpp +++ b/src/tests/units/operator_tests.cpp @@ -22,7 +22,7 @@ void tests(DeviceType device, Type floatType = Type::float32) { } #endif - auto floatApprox = [](T x, T y) -> bool { return x == Approx(y).epsilon(0.01); }; + auto floatApprox = [](T x, T y) -> bool { return x == Approx(y).epsilon(0.01f).scale(1.f); }; auto floatEqual = [](T x, T y) -> bool { return x == y; }; Config::seed = 1234; @@ -794,7 +794,7 @@ TEST_CASE("Expression graph supports basic math operations (cpu)", "[operator]") #ifdef CUDA_FOUND TEST_CASE("Compare aggregate operator", "[graph]") { - auto floatApprox = [](float x, float y) -> bool { return x == Approx(y).epsilon(0.01); }; + auto floatApprox = [](float x, float y) -> bool { return x == Approx(y).epsilon(0.01f).scale(1.f); }; Config::seed = 1234; diff --git a/src/tests/units/rnn_tests.cpp b/src/tests/units/rnn_tests.cpp index 56a2d1fdd..6405ef7a6 100644 --- a/src/tests/units/rnn_tests.cpp +++ b/src/tests/units/rnn_tests.cpp @@ -22,7 +22,7 @@ void tests(DeviceType type, Type floatType = Type::float32) { } #endif - auto floatApprox = [](T x, T y) { return x == Approx(y).epsilon(0.01); }; + auto floatApprox = [](T x, T y) { return x == Approx(y).epsilon(0.01f).scale(1.f); }; std::vector vWords = { 43, 2, 83, 78, From 4d12ffa96c9335e715a39cfa9017d48d2b8a39a3 Mon Sep 17 00:00:00 2001 From: Roman Grundkiewicz Date: Sat, 11 Apr 2020 16:04:20 +0100 Subject: [PATCH 12/62] Fix TSV training with mini-batch-fit after the last merge --- CHANGELOG.md | 6 ++++-- VERSION | 2 +- src/training/graph_group.cpp | 6 ++++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 15c5c7387..62e1365b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ and this project adheres to [Semantic 
Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] +### Added +- Training and scoring from STDIN +- Support for tab-separated inputs, added options --tsv and --tsv-fields + ### Changed - Changed compile flags -Ofast to -O3 and remove --ffinite-math - Moved old graph groups to depracated folder @@ -18,8 +22,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [1.9.0] - 2020-03-10 ### Added -- Training and scoring from STDIN -- Support for tab-separated inputs, added ptions --tsv and --tsv-fields - An option to print cached variables from CMake - Add support for compiling on Mac (and clang) - An option for resetting stalled validation metrics diff --git a/VERSION b/VERSION index c4e620172..57d503d1e 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -v1.9.4 +v1.9.5 diff --git a/src/training/graph_group.cpp b/src/training/graph_group.cpp index 8950521cd..616bb9911 100644 --- a/src/training/graph_group.cpp +++ b/src/training/graph_group.cpp @@ -18,7 +18,9 @@ Ptr GraphGroup::collectStats(Ptr graph, double multiplier) { auto stats = New(); - size_t numFiles = options_->get>("train-sets").size(); + size_t numFiles = options_->get("tsv", false) + ? 
options_->get("tsv-fields") + : options_->get>("train-sets").size(); // Initialize first batch to step size size_t first = options_->get("mini-batch-fit-step"); @@ -86,4 +88,4 @@ void GraphGroup::setTypicalTrgBatchWords(size_t typicalTrgBatchWords) { // neede typicalTrgBatchWords_ = typicalTrgBatchWords; } -} \ No newline at end of file +} From 855c94a55daa547041a9bc6dfbe9667022aa5ec5 Mon Sep 17 00:00:00 2001 From: Roman Grundkiewicz Date: Sat, 11 Apr 2020 16:06:34 +0100 Subject: [PATCH 13/62] Update submodule regression-tests --- regression-tests | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/regression-tests b/regression-tests index 6a08849b2..5cfede4cd 160000 --- a/regression-tests +++ b/regression-tests @@ -1 +1 @@ -Subproject commit 6a08849b23f6c14eefbe12f4eb73dc638b962587 +Subproject commit 5cfede4cde26479903aa29edff779ecec14bcd85 From c18fc71e8cb3a7e8bed1c5a84b66fd22d2b6843e Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Sat, 11 Apr 2020 09:23:56 -0700 Subject: [PATCH 14/62] fix 0 * nan behavior in concatention --- src/tensors/cpu/tensor_operators.cpp | 45 ++++++++++++++-------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/src/tensors/cpu/tensor_operators.cpp b/src/tensors/cpu/tensor_operators.cpp index ae5eeed57..5aa695204 100755 --- a/src/tensors/cpu/tensor_operators.cpp +++ b/src/tensors/cpu/tensor_operators.cpp @@ -77,6 +77,7 @@ void ConcatCont(Tensor out, const std::vector& inputs, int axis) { } } +template inline void gInsertCols(float* out, const float* in, size_t rows, @@ -84,13 +85,15 @@ inline void gInsertCols(float* out, size_t cols_out, size_t cols_in, size_t offset_out, - size_t offset_in, - float beta) { + size_t offset_in) { for(size_t j = 0; j < rows; ++j) { float* rowOut = out + j * cols_out + offset_out; const float* rowIn = in + j * cols_in + offset_in; for(size_t i = 0; i < cols; ++i) { - rowOut[i] = rowIn[i] + beta * rowOut[i]; + if(add) // this was solved earlier via beta * 
rowOut[i] with beta in {0,1} but 0 * nan in uninitialized tensors will result in nan. + rowOut[i] += rowIn[i]; + else + rowOut[i] = rowIn[i]; } } } @@ -105,21 +108,20 @@ void Concatenate1(Tensor out, const std::vector& inputs) { ABORT_IF(rows != in->shape().elements() / in->shape().back(), "First dimension must be equal"); int cols_in = in->shape().back(); - cpu::gInsertCols(out->data(), - in->data(), - rows, - cols_in, - cols_out, - cols_in, - offset, - 0, - 0); + cpu::gInsertCols(out->data(), + in->data(), + rows, + cols_in, + cols_out, + cols_in, + offset, + 0); offset += cols_in; } } void Concatenate(Tensor out, const std::vector& inputs, int ax) { - if(ax == (int)out->shape().size() - 1) + if(ax == (int)out->shape().size() - 1) Concatenate1(out, inputs); else ConcatCont(out, inputs, ax); @@ -136,15 +138,14 @@ void Split1(std::vector& outputs, const Tensor in) { // set last parameter to 1 to enable += instead of = // @TODO: do this in a more principled ways accross all/most kernels - cpu::gInsertCols(out->data(), - in->data(), - rows, - cols_out, - cols_out, - cols_in, - 0, - offset, - 1); + cpu::gInsertCols(out->data(), + in->data(), + rows, + cols_out, + cols_out, + cols_in, + 0, + offset); offset += cols_out; } } From 0ba438c463b32831eeae3901b28da0d0cc5bf146 Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Sat, 11 Apr 2020 09:45:57 -0700 Subject: [PATCH 15/62] Fix 0 * nan behavior due to using -O3 instead of -OFast (#630) * fix 0 * nan behavior in concatention * bump patch * change epsilon to margin --- CHANGELOG.md | 8 +++++++- src/tests/units/attention_tests.cpp | 2 +- src/tests/units/operator_tests.cpp | 4 ++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 62e1365b2..40bee3d70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. 
### Added - Training and scoring from STDIN -- Support for tab-separated inputs, added options --tsv and --tsv-fields +- Support for reading from TSV files from STDIN and other sources during training + and translation with options --tsv and --tsv-fields n. + +### Fixed +- In concatenation make sure that we do not multiply 0 with nan (which results in nan) +- Change Approx.epsilon(0.01) to Approx.margin(0.001) in unit tests. Tolerance is now + absolute and not relative. We assumed incorrectly that epsilon is absolute tolerance. ### Changed - Changed compile flags -Ofast to -O3 and remove --ffinite-math diff --git a/src/tests/units/attention_tests.cpp b/src/tests/units/attention_tests.cpp index 4fbed7b52..fe11bf2f0 100644 --- a/src/tests/units/attention_tests.cpp +++ b/src/tests/units/attention_tests.cpp @@ -23,7 +23,7 @@ void tests(DeviceType type, Type floatType = Type::float32) { } #endif - auto floatApprox = [](T x, T y) { return x == Approx(y).epsilon(0.01f).scale(1.f); }; + auto floatApprox = [](T x, T y) { return x == Approx(y).margin(0.001f); }; Config::seed = 1234; diff --git a/src/tests/units/operator_tests.cpp b/src/tests/units/operator_tests.cpp index 8403a84ea..581cd05c7 100644 --- a/src/tests/units/operator_tests.cpp +++ b/src/tests/units/operator_tests.cpp @@ -22,7 +22,7 @@ void tests(DeviceType device, Type floatType = Type::float32) { } #endif - auto floatApprox = [](T x, T y) -> bool { return x == Approx(y).epsilon(0.01f).scale(1.f); }; + auto floatApprox = [](T x, T y) -> bool { return x == Approx(y).margin(0.001f); }; auto floatEqual = [](T x, T y) -> bool { return x == y; }; Config::seed = 1234; @@ -794,7 +794,7 @@ TEST_CASE("Expression graph supports basic math operations (cpu)", "[operator]") #ifdef CUDA_FOUND TEST_CASE("Compare aggregate operator", "[graph]") { - auto floatApprox = [](float x, float y) -> bool { return x == Approx(y).epsilon(0.01f).scale(1.f); }; + auto floatApprox = [](float x, float y) -> bool { return x == 
Approx(y).margin(0.001f); }; Config::seed = 1234; From 93a27dcdd25d7da126ebda5290c579be7bf68974 Mon Sep 17 00:00:00 2001 From: Roman Grundkiewicz Date: Sat, 11 Apr 2020 18:47:17 +0100 Subject: [PATCH 16/62] Update submodule regression-tests --- regression-tests | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/regression-tests b/regression-tests index 5cfede4cd..67281c736 160000 --- a/regression-tests +++ b/regression-tests @@ -1 +1 @@ -Subproject commit 5cfede4cde26479903aa29edff779ecec14bcd85 +Subproject commit 67281c736fcffb074e35665fe6c52be9a4cf5ca8 From 733cb505bc7353635ee02fdddc7eb9b6465d976b Mon Sep 17 00:00:00 2001 From: Roman Grundkiewicz Date: Sun, 12 Apr 2020 18:56:11 +0100 Subject: [PATCH 17/62] Support relative paths in shortlist and sqlite options (#612) * Refactorize processPaths * Fix relative paths for shortlist and sqlite options * Rename InterpolateEnvVars to interpolateEnvVars * Update CHANGELOG --- CHANGELOG.md | 3 +- src/common/cli_helper.h | 57 +++++++++++++++++------------------- src/common/config_parser.cpp | 22 +++++++------- 3 files changed, 39 insertions(+), 43 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 40bee3d70..df1b05c79 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,8 +9,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] ### Added +- Supporting relative paths in shortlist and sqlite options - Training and scoring from STDIN -- Support for reading from TSV files from STDIN and other sources during training +- Support for reading from TSV files from STDIN and other sources during training and translation with options --tsv and --tsv-fields n. 
### Fixed diff --git a/src/common/cli_helper.h b/src/common/cli_helper.h index 4477f0c02..dc8eafdf6 100644 --- a/src/common/cli_helper.h +++ b/src/common/cli_helper.h @@ -10,7 +10,7 @@ namespace cli { // helper to replace environment-variable expressions of the form ${VARNAME} in // a string -static inline std::string InterpolateEnvVars(std::string str) { +static inline std::string interpolateEnvVars(std::string str) { // temporary workaround for MS-internal PhillyOnAzure cluster: warm storage // presently has the form /hdfs/VC instead of /{gfs,hdfs}/CLUSTER/VC @@ -58,43 +58,40 @@ static inline std::string InterpolateEnvVars(std::string str) { } } -// helper to implement interpolate-env-vars and relative-paths options +// Helper to implement interpolate-env-vars and relative-paths options static inline void processPaths( YAML::Node& node, const std::function& TransformPath, const std::set& PATHS, - bool isPath = false) { - if(isPath) { - if(node.Type() == YAML::NodeType::Scalar) { - std::string nodePath = node.as(); - // transform the path - if(!nodePath.empty()) - node = TransformPath(nodePath); - } + bool isPath = false, + const std::string parentKey = "") { + // For a scalar node (leaves in the config), just transform the path + if(isPath && node.IsScalar()) { + std::string nodePath = node.as(); + if(!nodePath.empty()) + node = TransformPath(nodePath); + } + // For a sequence node, recursively iterate each value + else if(node.IsSequence()) { + for(auto&& sub : node) { + processPaths(sub, TransformPath, PATHS, isPath); - if(node.Type() == YAML::NodeType::Sequence) { - for(auto&& sub : node) { - processPaths(sub, TransformPath, PATHS, true); - } - } - } else { - switch(node.Type()) { - case YAML::NodeType::Sequence: - for(auto&& sub : node) { - processPaths(sub, TransformPath, PATHS, false); - } - break; - case YAML::NodeType::Map: - for(auto&& sub : node) { - std::string key = sub.first.as(); - processPaths(sub.second, TransformPath, PATHS, PATHS.count(key) > 
0); - } - break; - default: - // it is OK + // Exception for the shortlist option, which keeps a path and three numbers; + // we want to process the path only and keep the rest untouched + if(isPath && parentKey == "shortlist") break; } } + // For a map node that is not a path, recursively iterate each value + else if(!isPath && node.IsMap()) { + for(auto&& sub : node) { + std::string key = sub.first.as(); + // Exception for the sqlite option, which has a special value of 'temporary' + if(key == "sqlite" && sub.second.as() == "temporary") + continue; + processPaths(sub.second, TransformPath, PATHS, PATHS.count(key) > 0, key); + } + } } // helper to convert a YAML node recursively into a string diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index 05a6ccb1d..2f56d8870 100755 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -39,16 +39,14 @@ const std::set PATHS = { "valid-script-args", "valid-log", "valid-translation-output", - "input", // except: stdin - "output", // except: stdout + "input", // except: 'stdin', handled in makeAbsolutePaths and interpolateEnvVars + "output", // except: 'stdout', handled in makeAbsolutePaths and interpolateEnvVars "pretrained-model", "data-weighting", - "log" - // TODO: Handle the special value in helper functions - //"sqlite", // except: temporary - // TODO: This is a vector with a path and some numbers, handle this in helper - // functions or separate shortlist path to a separate command-line option - //"shortlist", + "log", + "sqlite", // except: 'temporary', handled in the processPaths function + "shortlist", // except: only the first element in the sequence is a path, handled in the + // processPaths function }; // clang-format on @@ -876,7 +874,7 @@ Ptr ConfigParser::parseOptions(int argc, char** argv, bool doValidate){ } if(get("interpolate-env-vars")) { - cli::processPaths(config_, cli::InterpolateEnvVars, PATHS); + cli::processPaths(config_, cli::interpolateEnvVars, PATHS); } 
// Option shortcuts for input from STDIN for trainer and scorer @@ -931,12 +929,12 @@ std::vector ConfigParser::findConfigPaths() { for(auto& path : paths) { // (note: this updates the paths array) if(interpolateEnvVars) - path = cli::InterpolateEnvVars(path); + path = cli::interpolateEnvVars(path); } } else if(mode_ == cli::mode::training) { auto path = config_["model"].as() + ".yml"; if(interpolateEnvVars) - path = cli::InterpolateEnvVars(path); + path = cli::interpolateEnvVars(path); bool reloadConfig = filesystem::exists(path) && !get("no-reload"); if(reloadConfig) @@ -962,7 +960,7 @@ YAML::Node ConfigParser::loadConfigFiles(const std::vector& paths) && config["interpolate-env-vars"].as()) || get("interpolate-env-vars"); if(interpolateEnvVars) - cli::processPaths(config, cli::InterpolateEnvVars, PATHS); + cli::processPaths(config, cli::interpolateEnvVars, PATHS); // replace relative path w.r.t. the config file cli::makeAbsolutePaths(config, path, PATHS); From 7bf486ad61232b7d0294f8d8a12eca72547e0e97 Mon Sep 17 00:00:00 2001 From: Roman Grundkiewicz Date: Sun, 12 Apr 2020 18:58:33 +0100 Subject: [PATCH 18/62] Fix Iris example on CPU (#623) --- src/examples/iris/iris.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/examples/iris/iris.cpp b/src/examples/iris/iris.cpp index 9878a1ff0..328a4dfae 100644 --- a/src/examples/iris/iris.cpp +++ b/src/examples/iris/iris.cpp @@ -79,7 +79,12 @@ int main() { auto graph = New(); // Set general options - graph->setDevice({0, DeviceType::gpu}); +#ifdef CUDA_FOUND + auto deviceType = DeviceType::gpu; +#else + auto deviceType = DeviceType::cpu; +#endif + graph->setDevice({0, deviceType}); graph->reserveWorkspaceMB(128); // Choose optimizer (Sgd, Adagrad, Adam) and initial learning rate From 34bc47cd3df7ae74013604abcbd4dea5017fe261 Mon Sep 17 00:00:00 2001 From: Roman Grundkiewicz Date: Sun, 12 Apr 2020 19:14:03 +0100 Subject: [PATCH 19/62] Dump version --- VERSION | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 57d503d1e..e3f63fa09 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -v1.9.5 +v1.9.7 From bc8b6fa162b0840387e195cc3073680bbf854862 Mon Sep 17 00:00:00 2001 From: Martin Junczys-Dowmunt Date: Tue, 14 Apr 2020 00:28:44 +0000 Subject: [PATCH 20/62] Merged PR 12442: cherry pick a few improvements/fixes from Frank's branch Cherry pick a few improvements/fixes from Frank's branch * Adds Frank's fix for label-based mini-batch sizing from Frank's current experimental branch. * Also copies minor improvements and a few comments. --- src/layers/loss.h | 6 ++++++ src/optimizers/optimizers.cpp | 9 +++++++++ src/training/graph_group_sync.cpp | 10 ++++------ src/training/validator.h | 7 ++++++- 4 files changed, 25 insertions(+), 7 deletions(-) diff --git a/src/layers/loss.h b/src/layers/loss.h index 43e89c1d3..315eda388 100755 --- a/src/layers/loss.h +++ b/src/layers/loss.h @@ -93,6 +93,12 @@ struct StaticLoss { StaticLoss(const RationalLoss& dynamic) : loss(dynamic.loss()), count(dynamic.count()) {} + StaticLoss operator +(const StaticLoss& other) const { + StaticLoss res(*this); + res += other; + return res; + } + StaticLoss& operator +=(const StaticLoss& other) { loss = loss + other.loss; count = count + other.count; diff --git a/src/optimizers/optimizers.cpp b/src/optimizers/optimizers.cpp index 083f94c01..d1dbcf59e 100755 --- a/src/optimizers/optimizers.cpp +++ b/src/optimizers/optimizers.cpp @@ -139,6 +139,15 @@ void Adam::updateImpl(Tensor params, Tensor grads, size_t actualMBSize, size_t r double Tref = (double)refMBWords; // adjust for minibatch-size changes if Adam parameters are given a reference size (else do nothing) + // Why the T/Tref factor on eta? The Adam optimizer adds an RMS-normalized gradient + // value (times learning rate) to the model. We know that for Tref, that learning rate is good. 
+ // If we increase the batch size by (T/Tref), then without adjustment, we would still add an + // RMS-normalized gradient value. That means that the contribution of an individual label is + // now weighted down by (T/Tref). However, batch-size agnostic hyper-parameterization aims to keep + // the weight on the contribution of each label gradient invariant. Thus, we must undo that + // down-weighting, by multiplying the RMS-normalized gradient value by an additional factor + // of (T/Tref). This is implemented here by locally multiplying the learning rate + // with that factor. double eta = eta_ * (T/Tref); double beta1 = beta1_; double beta2 = beta2_; diff --git a/src/training/graph_group_sync.cpp b/src/training/graph_group_sync.cpp index 1b3c16ded..1457faffc 100755 --- a/src/training/graph_group_sync.cpp +++ b/src/training/graph_group_sync.cpp @@ -192,7 +192,7 @@ bool SyncGraphGroup::tryGetSubBatches(Ptr newBatch, // If a reference is given, then at progress == mbWarmup.n (ratio=1), we would like to have refBatchLabels instead of whichever // the actual batch size is. Since we cannot know the future actual batch sizes that will be delivered // by the reader, we approximate them with (typicalTrgBatchWords * updateMultiplier), and scale ratio accordingly. - auto refBatchLabels = options_->get("mini-batch-words-ref"); + auto refBatchLabels = options_->get("mini-batch-words"); if (refBatchLabels != 0) { LOG_ONCE(info, "[scheduler] Scaling to {} reference labels, using actual-batch-word estimate of {}", refBatchLabels, typicalTrgBatchWords_); ABORT_IF(typicalTrgBatchWords_ == 0, "Dynamic scaling with words target requires MB size to be known in words"); // happens if MB size is specified in sentences @@ -338,7 +338,7 @@ void SyncGraphGroup::update(std::vector> subBatches, size_t num // actual model update auto updateTrgWords = /*if*/(options_->get("cost-type") == "ce-sum") ? 
- batchTrgWords + batchTrgWords // total number of labels across all GPUs and nodes /*else*/: OptimizerBase::mbSizeNotProvided; shardOpt_[idx]->update(curParam, curGrad, updateTrgWords); @@ -350,10 +350,8 @@ void SyncGraphGroup::update(std::vector> subBatches, size_t num }; // cost across all local devices (scheduler will aggregate cross-process) - StaticLoss localLoss; - for(auto& l : localDeviceLosses) // localDeviceLosses is already summed up over delay steps - localLoss += l; - + StaticLoss localLoss = std::accumulate(localDeviceLosses.begin(), localDeviceLosses.end(), StaticLoss()); + // model update if (std::isfinite(localLoss.loss) || mpi_->numMPIProcesses() > 1) { // guard against NaN (except with MPI, as this simple way could hang it) comm_->scatterReduceAndResetGrads(); // reduce gradients across all devices and MPI nodes into shards diff --git a/src/training/validator.h b/src/training/validator.h index 1658dff3f..afb359ba6 100755 --- a/src/training/validator.h +++ b/src/training/validator.h @@ -66,8 +66,13 @@ class Validator : public ValidatorBase { options_->set("max-length", options_->get("valid-max-length")); options_->set("max-length-crop", true); // @TODO: make this configureable } - if(options_->has("valid-mini-batch")) + + // @TODO: make this work with mini-batch-fit etc. 
+ if(options_->has("valid-mini-batch")) { options_->set("mini-batch", options_->get("valid-mini-batch")); + options_->set("mini-batch-words", 0); + } + options_->set("mini-batch-sort", "src"); options_->set("maxi-batch", 10); } From ce94fe989243d7aa1d5f445e1181ea2fcbc7c7a2 Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Mon, 13 Apr 2020 17:31:06 -0700 Subject: [PATCH 21/62] update changelog and version --- CHANGELOG.md | 1 + VERSION | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index df1b05c79..439c330ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. and translation with options --tsv and --tsv-fields n. ### Fixed +- Make mini-batch scaling depend on mini-batch-words and not on mini-batch-words-ref - In concatenation make sure that we do not multiply 0 with nan (which results in nan) - Change Approx.epsilon(0.01) to Approx.margin(0.001) in unit tests. Tolerance is now absolute and not relative. We assumed incorrectly that epsilon is absolute tolerance. 
diff --git a/VERSION b/VERSION index e3f63fa09..b0376728d 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -v1.9.7 +v1.9.8 From 59dad14ed1a1657b4d1cda9756aa08ae2bea70e5 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Thu, 16 Apr 2020 11:15:42 +0100 Subject: [PATCH 22/62] python3 shebang from #620 (#621) * python3 shebang from #620 * Add changelog entry for python3 change --- CHANGELOG.md | 1 + scripts/checkpoints/average.py | 2 +- scripts/contrib/inject_ctt.py | 2 +- scripts/contrib/inject_model_params.py | 2 +- scripts/embeddings/prepare_corpus.py | 2 +- scripts/embeddings/process_word2vec.py | 2 +- scripts/server/client_example.py | 2 +- 7 files changed, 7 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 439c330ce..a34915636 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. absolute and not relative. We assumed incorrectly that epsilon is absolute tolerance. ### Changed +- Python scripts start with #!/usr/bin/env python3 instead of python - Changed compile flags -Ofast to -O3 and remove --ffinite-math - Moved old graph groups to depracated folder - Make cublas and cusparse handle inits lazy to save memory when unused diff --git a/scripts/checkpoints/average.py b/scripts/checkpoints/average.py index 53bff1862..da1ca2526 100755 --- a/scripts/checkpoints/average.py +++ b/scripts/checkpoints/average.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 """ This script takes multiple Marian *.npz model files and outputs an elementwise average of the model, meant to do check-point averaging from: diff --git a/scripts/contrib/inject_ctt.py b/scripts/contrib/inject_ctt.py index 751ee1c60..620c31526 100755 --- a/scripts/contrib/inject_ctt.py +++ b/scripts/contrib/inject_ctt.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 from __future__ import print_function diff --git a/scripts/contrib/inject_model_params.py 
b/scripts/contrib/inject_model_params.py index 46096eb8b..a0e637a39 100755 --- a/scripts/contrib/inject_model_params.py +++ b/scripts/contrib/inject_model_params.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 from __future__ import print_function diff --git a/scripts/embeddings/prepare_corpus.py b/scripts/embeddings/prepare_corpus.py index 98326218b..0c54be5d0 100755 --- a/scripts/embeddings/prepare_corpus.py +++ b/scripts/embeddings/prepare_corpus.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # -*- coding: utf-8 -*- from __future__ import print_function diff --git a/scripts/embeddings/process_word2vec.py b/scripts/embeddings/process_word2vec.py index 4f5ba493f..685f8d23d 100755 --- a/scripts/embeddings/process_word2vec.py +++ b/scripts/embeddings/process_word2vec.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # -*- coding: utf-8 -*- from __future__ import print_function diff --git a/scripts/server/client_example.py b/scripts/server/client_example.py index 7f9e0ae37..4c194d74f 100755 --- a/scripts/server/client_example.py +++ b/scripts/server/client_example.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 from __future__ import print_function, unicode_literals, division From 342db58b7f25430c21563d414180f8620123cb59 Mon Sep 17 00:00:00 2001 From: Roman Grundkiewicz Date: Sun, 26 Apr 2020 16:43:36 +0100 Subject: [PATCH 23/62] Update submodule regression-tests --- regression-tests | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/regression-tests b/regression-tests index 67281c736..d1db7ea10 160000 --- a/regression-tests +++ b/regression-tests @@ -1 +1 @@ -Subproject commit 67281c736fcffb074e35665fe6c52be9a4cf5ca8 +Subproject commit d1db7ea10071252fa669c034c9c99acf159c8920 From 3f7b459d18e5fdc12e44122bef9b8807ec0554ac Mon Sep 17 00:00:00 2001 From: Roman Grundkiewicz Date: Mon, 27 Apr 2020 10:34:10 +0100 Subject: [PATCH 24/62] Update Simple-WebSocket-Server and move it to 
submodules (#639) * Fix server build with current boost, move simple-websocket-server to submodule * Change submodule to marian-nmt/Simple-WebSocket-Server * Update submodule simple-websocket-server Co-authored-by: Gleb Tv --- .gitmodules | 3 + CHANGELOG.md | 2 + VERSION | 2 +- scripts/server/client_example.py | 1 + src/3rd_party/simple-websocket-server | 1 + .../simple-websocket-server/crypto.hpp | 251 ------ .../simple-websocket-server/server_ws.hpp | 823 ------------------ .../simple-websocket-server/status_code.hpp | 191 ---- .../simple-websocket-server/utility.hpp | 381 -------- src/command/marian_server.cpp | 4 +- 10 files changed, 10 insertions(+), 1649 deletions(-) create mode 160000 src/3rd_party/simple-websocket-server delete mode 100644 src/3rd_party/simple-websocket-server/crypto.hpp delete mode 100644 src/3rd_party/simple-websocket-server/server_ws.hpp delete mode 100644 src/3rd_party/simple-websocket-server/status_code.hpp delete mode 100644 src/3rd_party/simple-websocket-server/utility.hpp diff --git a/.gitmodules b/.gitmodules index b7c67befc..6cb63fc0b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -14,3 +14,6 @@ path = src/3rd_party/fbgemm url = https://github.com/marian-nmt/FBGEMM branch = master +[submodule "src/3rd_party/simple-websocket-server"] + path = src/3rd_party/simple-websocket-server + url = https://github.com/marian-nmt/Simple-WebSocket-Server diff --git a/CHANGELOG.md b/CHANGELOG.md index a34915636..715f83df3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,12 +15,14 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. and translation with options --tsv and --tsv-fields n. ### Fixed +- Fix building server with Boost 1.72 - Make mini-batch scaling depend on mini-batch-words and not on mini-batch-words-ref - In concatenation make sure that we do not multiply 0 with nan (which results in nan) - Change Approx.epsilon(0.01) to Approx.margin(0.001) in unit tests. Tolerance is now absolute and not relative. 
We assumed incorrectly that epsilon is absolute tolerance. ### Changed +- Move Simple-WebSocket-Server to submodule - Python scripts start with #!/usr/bin/env python3 instead of python - Changed compile flags -Ofast to -O3 and remove --ffinite-math - Moved old graph groups to depracated folder diff --git a/VERSION b/VERSION index b0376728d..dd63e963b 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -v1.9.8 +v1.9.9 diff --git a/scripts/server/client_example.py b/scripts/server/client_example.py index 4c194d74f..a68ffc2ab 100755 --- a/scripts/server/client_example.py +++ b/scripts/server/client_example.py @@ -6,6 +6,7 @@ import time import argparse +# pip install websocket_client from websocket import create_connection diff --git a/src/3rd_party/simple-websocket-server b/src/3rd_party/simple-websocket-server new file mode 160000 index 000000000..417a2a9e9 --- /dev/null +++ b/src/3rd_party/simple-websocket-server @@ -0,0 +1 @@ +Subproject commit 417a2a9e9dbd720b8d2dfa1dafe57cf1b37ca0d7 diff --git a/src/3rd_party/simple-websocket-server/crypto.hpp b/src/3rd_party/simple-websocket-server/crypto.hpp deleted file mode 100644 index 0e3dbf34c..000000000 --- a/src/3rd_party/simple-websocket-server/crypto.hpp +++ /dev/null @@ -1,251 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2014-2017 Ole Christian Eidheim - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. 
- * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -#ifndef SIMPLE_WEB_CRYPTO_HPP -#define SIMPLE_WEB_CRYPTO_HPP - -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace SimpleWeb { -// TODO 2017: remove workaround for MSVS 2012 -#if _MSC_VER == 1700 // MSVS 2012 has no definition for round() - inline double round(double x) noexcept { // Custom definition of round() for positive numbers - return floor(x + 0.5); - } -#endif - - class Crypto { - const static std::size_t buffer_size = 131072; - - public: - class Base64 { - public: - static std::string encode(const std::string &ascii) noexcept { - std::string base64; - - BIO *bio, *b64; - BUF_MEM *bptr = BUF_MEM_new(); - - b64 = BIO_new(BIO_f_base64()); - BIO_set_flags(b64, BIO_FLAGS_BASE64_NO_NL); - bio = BIO_new(BIO_s_mem()); - BIO_push(b64, bio); - BIO_set_mem_buf(b64, bptr, BIO_CLOSE); - - // Write directly to base64-buffer to avoid copy - auto base64_length = static_cast(round(4 * ceil(static_cast(ascii.size()) / 3.0))); - base64.resize(base64_length); - bptr->length = 0; - bptr->max = base64_length + 1; - bptr->data = &base64[0]; - - if(BIO_write(b64, &ascii[0], static_cast(ascii.size())) <= 0 || BIO_flush(b64) <= 0) - base64.clear(); - - // To keep &base64[0] through BIO_free_all(b64) - bptr->length = 0; - bptr->max = 0; - bptr->data = nullptr; - - BIO_free_all(b64); - - return base64; - } - - static std::string decode(const std::string &base64) noexcept { - std::string ascii; - - // Resize ascii, however, the 
size is a up to two bytes too large. - ascii.resize((6 * base64.size()) / 8); - BIO *b64, *bio; - - b64 = BIO_new(BIO_f_base64()); - BIO_set_flags(b64, BIO_FLAGS_BASE64_NO_NL); -// TODO: Remove in 2020 -#if OPENSSL_VERSION_NUMBER <= 0x1000115fL - bio = BIO_new_mem_buf((char *)&base64[0], static_cast(base64.size())); -#else - bio = BIO_new_mem_buf(&base64[0], static_cast(base64.size())); -#endif - bio = BIO_push(b64, bio); - - auto decoded_length = BIO_read(bio, &ascii[0], static_cast(ascii.size())); - if(decoded_length > 0) - ascii.resize(static_cast(decoded_length)); - else - ascii.clear(); - - BIO_free_all(b64); - - return ascii; - } - }; - - /// Return hex string from bytes in input string. - static std::string to_hex_string(const std::string &input) noexcept { - std::stringstream hex_stream; - hex_stream << std::hex << std::internal << std::setfill('0'); - for(auto &byte : input) - hex_stream << std::setw(2) << static_cast(static_cast(byte)); - return hex_stream.str(); - } - - static std::string md5(const std::string &input, std::size_t iterations = 1) noexcept { - std::string hash; - - hash.resize(128 / 8); - MD5(reinterpret_cast(&input[0]), input.size(), reinterpret_cast(&hash[0])); - - for(std::size_t c = 1; c < iterations; ++c) - MD5(reinterpret_cast(&hash[0]), hash.size(), reinterpret_cast(&hash[0])); - - return hash; - } - - static std::string md5(std::istream &stream, std::size_t iterations = 1) noexcept { - MD5_CTX context; - MD5_Init(&context); - std::streamsize read_length; - std::vector buffer(buffer_size); - while((read_length = stream.read(&buffer[0], buffer_size).gcount()) > 0) - MD5_Update(&context, buffer.data(), static_cast(read_length)); - std::string hash; - hash.resize(128 / 8); - MD5_Final(reinterpret_cast(&hash[0]), &context); - - for(std::size_t c = 1; c < iterations; ++c) - MD5(reinterpret_cast(&hash[0]), hash.size(), reinterpret_cast(&hash[0])); - - return hash; - } - - static std::string sha1(const std::string &input, std::size_t 
iterations = 1) noexcept { - std::string hash; - - hash.resize(160 / 8); - SHA1(reinterpret_cast(&input[0]), input.size(), reinterpret_cast(&hash[0])); - - for(std::size_t c = 1; c < iterations; ++c) - SHA1(reinterpret_cast(&hash[0]), hash.size(), reinterpret_cast(&hash[0])); - - return hash; - } - - static std::string sha1(std::istream &stream, std::size_t iterations = 1) noexcept { - SHA_CTX context; - SHA1_Init(&context); - std::streamsize read_length; - std::vector buffer(buffer_size); - while((read_length = stream.read(&buffer[0], buffer_size).gcount()) > 0) - SHA1_Update(&context, buffer.data(), static_cast(read_length)); - std::string hash; - hash.resize(160 / 8); - SHA1_Final(reinterpret_cast(&hash[0]), &context); - - for(std::size_t c = 1; c < iterations; ++c) - SHA1(reinterpret_cast(&hash[0]), hash.size(), reinterpret_cast(&hash[0])); - - return hash; - } - - static std::string sha256(const std::string &input, std::size_t iterations = 1) noexcept { - std::string hash; - - hash.resize(256 / 8); - SHA256(reinterpret_cast(&input[0]), input.size(), reinterpret_cast(&hash[0])); - - for(std::size_t c = 1; c < iterations; ++c) - SHA256(reinterpret_cast(&hash[0]), hash.size(), reinterpret_cast(&hash[0])); - - return hash; - } - - static std::string sha256(std::istream &stream, std::size_t iterations = 1) noexcept { - SHA256_CTX context; - SHA256_Init(&context); - std::streamsize read_length; - std::vector buffer(buffer_size); - while((read_length = stream.read(&buffer[0], buffer_size).gcount()) > 0) - SHA256_Update(&context, buffer.data(), static_cast(read_length)); - std::string hash; - hash.resize(256 / 8); - SHA256_Final(reinterpret_cast(&hash[0]), &context); - - for(std::size_t c = 1; c < iterations; ++c) - SHA256(reinterpret_cast(&hash[0]), hash.size(), reinterpret_cast(&hash[0])); - - return hash; - } - - static std::string sha512(const std::string &input, std::size_t iterations = 1) noexcept { - std::string hash; - - hash.resize(512 / 8); - 
SHA512(reinterpret_cast(&input[0]), input.size(), reinterpret_cast(&hash[0])); - - for(std::size_t c = 1; c < iterations; ++c) - SHA512(reinterpret_cast(&hash[0]), hash.size(), reinterpret_cast(&hash[0])); - - return hash; - } - - static std::string sha512(std::istream &stream, std::size_t iterations = 1) noexcept { - SHA512_CTX context; - SHA512_Init(&context); - std::streamsize read_length; - std::vector buffer(buffer_size); - while((read_length = stream.read(&buffer[0], buffer_size).gcount()) > 0) - SHA512_Update(&context, buffer.data(), static_cast(read_length)); - std::string hash; - hash.resize(512 / 8); - SHA512_Final(reinterpret_cast(&hash[0]), &context); - - for(std::size_t c = 1; c < iterations; ++c) - SHA512(reinterpret_cast(&hash[0]), hash.size(), reinterpret_cast(&hash[0])); - - return hash; - } - - /// key_size is number of bytes of the returned key. - static std::string pbkdf2(const std::string &password, const std::string &salt, int iterations, int key_size) noexcept { - std::string key; - key.resize(static_cast(key_size)); - PKCS5_PBKDF2_HMAC_SHA1(password.c_str(), (int)password.size(), - reinterpret_cast(salt.c_str()), (int)salt.size(), iterations, - key_size, reinterpret_cast(&key[0])); - return key; - } - }; -} -#endif /* SIMPLE_WEB_CRYPTO_HPP */ diff --git a/src/3rd_party/simple-websocket-server/server_ws.hpp b/src/3rd_party/simple-websocket-server/server_ws.hpp deleted file mode 100644 index 609a86189..000000000 --- a/src/3rd_party/simple-websocket-server/server_ws.hpp +++ /dev/null @@ -1,823 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2014-2017 Ole Christian Eidheim - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit 
persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -#ifndef SERVER_WS_HPP -#define SERVER_WS_HPP - -#include "crypto.hpp" -#include "utility.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef USE_STANDALONE_ASIO -#include -#include -namespace SimpleWeb { - using error_code = std::error_code; - using errc = std::errc; - namespace make_error_code = std; -} // namespace SimpleWeb -#else -#include -#include -namespace SimpleWeb { - namespace asio = boost::asio; - using error_code = boost::system::error_code; - namespace errc = boost::system::errc; - namespace make_error_code = boost::system::errc; -} // namespace SimpleWeb -#endif - -// Late 2017 TODO: remove the following checks and always use std::regex -#ifdef USE_BOOST_REGEX -#include -namespace SimpleWeb { - namespace regex = boost; -} -#else -#include -namespace SimpleWeb { - namespace regex = std; -} -#endif - -namespace SimpleWeb { - template - class SocketServer; - - template - class SocketServerBase { - public: - class Message : public std::istream { - friend class SocketServerBase; - - public: - unsigned char fin_rsv_opcode; - std::size_t size() noexcept { - return length; - } - - /// Convenience function to return std::string. The stream buffer is consumed. 
- std::string string() noexcept { - try { - std::stringstream ss; - ss << rdbuf(); - return ss.str(); - } - catch(...) { - return std::string(); - } - } - - private: - Message() noexcept : std::istream(&streambuf), length(0) {} - Message(unsigned char fin_rsv_opcode, std::size_t length) noexcept : std::istream(&streambuf), fin_rsv_opcode(fin_rsv_opcode), length(length) {} - std::size_t length; - asio::streambuf streambuf; - }; - - /// The buffer is not consumed during send operations. - /// Do not alter while sending. - class SendStream : public std::ostream { - friend class SocketServerBase; - - asio::streambuf streambuf; - - public: - SendStream() noexcept : std::ostream(&streambuf) {} - - /// Returns the size of the buffer - std::size_t size() const noexcept { - return streambuf.size(); - } - }; - - class Connection : public std::enable_shared_from_this { - friend class SocketServerBase; - friend class SocketServer; - - public: - Connection(std::unique_ptr &&socket) noexcept : socket(std::move(socket)), timeout_idle(0), strand(this->socket->get_io_service()), closed(false) {} - - std::string method, path, query_string, http_version; - - CaseInsensitiveMultimap header; - - regex::smatch path_match; - - asio::ip::tcp::endpoint remote_endpoint; - - std::string remote_endpoint_address() noexcept { - try { - return remote_endpoint.address().to_string(); - } - catch(...) { - return std::string(); - } - } - - unsigned short remote_endpoint_port() noexcept { - return remote_endpoint.port(); - } - - private: - template - Connection(std::shared_ptr handler_runner, long timeout_idle, Args &&... 
args) noexcept - : handler_runner(std::move(handler_runner)), socket(new socket_type(std::forward(args)...)), timeout_idle(timeout_idle), strand(socket->get_io_service()), closed(false) {} - - std::shared_ptr handler_runner; - - std::unique_ptr socket; // Socket must be unique_ptr since asio::ssl::stream is not movable - std::mutex socket_close_mutex; - - asio::streambuf read_buffer; - std::shared_ptr fragmented_message; - - long timeout_idle; - std::unique_ptr timer; - std::mutex timer_mutex; - - void close() noexcept { - error_code ec; - std::unique_lock lock(socket_close_mutex); // The following operations seems to be needed to run sequentially - socket->lowest_layer().shutdown(asio::ip::tcp::socket::shutdown_both, ec); - socket->lowest_layer().close(ec); - } - - void set_timeout(long seconds = -1) noexcept { - bool use_timeout_idle = false; - if(seconds == -1) { - use_timeout_idle = true; - seconds = timeout_idle; - } - - std::unique_lock lock(timer_mutex); - - if(seconds == 0) { - timer = nullptr; - return; - } - - timer = std::unique_ptr(new asio::steady_timer(socket->get_io_service())); - timer->expires_from_now(std::chrono::seconds(seconds)); - std::weak_ptr connection_weak(this->shared_from_this()); // To avoid keeping Connection instance alive longer than needed - timer->async_wait([connection_weak, use_timeout_idle](const error_code &ec) { - if(!ec) { - if(auto connection = connection_weak.lock()) { - if(use_timeout_idle) - connection->send_close(1000, "idle timeout"); // 1000=normal closure - else - connection->close(); - } - } - }); - } - - void cancel_timeout() noexcept { - std::unique_lock lock(timer_mutex); - if(timer) { - error_code ec; - timer->cancel(ec); - } - } - - bool generate_handshake(const std::shared_ptr &write_buffer) { - std::ostream handshake(write_buffer.get()); - - auto header_it = header.find("Sec-WebSocket-Key"); - if(header_it == header.end()) - return false; - - static auto ws_magic_string = 
"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; - auto sha1 = Crypto::sha1(header_it->second + ws_magic_string); - - handshake << "HTTP/1.1 101 Web Socket Protocol Handshake\r\n"; - handshake << "Upgrade: websocket\r\n"; - handshake << "Connection: Upgrade\r\n"; - handshake << "Sec-WebSocket-Accept: " << Crypto::Base64::encode(sha1) << "\r\n"; - handshake << "\r\n"; - - return true; - } - - asio::io_service::strand strand; - - class SendData { - public: - SendData(std::shared_ptr header_stream, std::shared_ptr message_stream, - std::function &&callback) noexcept - : header_stream(std::move(header_stream)), message_stream(std::move(message_stream)), callback(std::move(callback)) {} - std::shared_ptr header_stream; - std::shared_ptr message_stream; - std::function callback; - }; - - std::list send_queue; - - void send_from_queue() { - auto self = this->shared_from_this(); - strand.post([self]() { - asio::async_write(*self->socket, self->send_queue.begin()->header_stream->streambuf, self->strand.wrap([self](const error_code &ec, std::size_t /*bytes_transferred*/) { - auto lock = self->handler_runner->continue_lock(); - if(!lock) - return; - if(!ec) { - asio::async_write(*self->socket, self->send_queue.begin()->message_stream->streambuf.data(), self->strand.wrap([self](const error_code &ec, std::size_t /*bytes_transferred*/) { - auto lock = self->handler_runner->continue_lock(); - if(!lock) - return; - auto send_queued = self->send_queue.begin(); - if(send_queued->callback) - send_queued->callback(ec); - if(!ec) { - self->send_queue.erase(send_queued); - if(self->send_queue.size() > 0) - self->send_from_queue(); - } - else - self->send_queue.clear(); - })); - } - else { - auto send_queued = self->send_queue.begin(); - if(send_queued->callback) - send_queued->callback(ec); - self->send_queue.clear(); - } - })); - }); - } - - std::atomic closed; - - void read_remote_endpoint() noexcept { - try { - remote_endpoint = socket->lowest_layer().remote_endpoint(); - } - catch(...) 
{ - } - } - - public: - /// fin_rsv_opcode: 129=one fragment, text, 130=one fragment, binary, 136=close connection. - /// See http://tools.ietf.org/html/rfc6455#section-5.2 for more information - void send(const std::shared_ptr &send_stream, const std::function &callback = nullptr, - unsigned char fin_rsv_opcode = 129) { - cancel_timeout(); - set_timeout(); - - auto header_stream = std::make_shared(); - - std::size_t length = send_stream->size(); - - header_stream->put(static_cast(fin_rsv_opcode)); - // Unmasked (first length byte<128) - if(length >= 126) { - std::size_t num_bytes; - if(length > 0xffff) { - num_bytes = 8; - header_stream->put(127); - } - else { - num_bytes = 2; - header_stream->put(126); - } - - for(std::size_t c = num_bytes - 1; c != static_cast(-1); c--) - header_stream->put((static_cast(length) >> (8 * c)) % 256); - } - else - header_stream->put(static_cast(length)); - - auto self = this->shared_from_this(); - strand.post([self, header_stream, send_stream, callback]() { - self->send_queue.emplace_back(header_stream, send_stream, callback); - if(self->send_queue.size() == 1) - self->send_from_queue(); - }); - } - - void send_close(int status, const std::string &reason = "", const std::function &callback = nullptr) { - // Send close only once (in case close is initiated by server) - if(closed) - return; - closed = true; - - auto send_stream = std::make_shared(); - - send_stream->put((unsigned char)(status >> 8)); - send_stream->put((unsigned char)(status % 256)); - - *send_stream << reason; - - // fin_rsv_opcode=136: message close - send(send_stream, callback, 136); - } - }; - - class Endpoint { - friend class SocketServerBase; - - private: - std::unordered_set> connections; - std::mutex connections_mutex; - - public: - std::function)> on_open; - std::function, std::shared_ptr)> on_message; - std::function, int, const std::string &)> on_close; - std::function, const error_code &)> on_error; - std::function)> on_ping; - std::function)> on_pong; - - 
std::unordered_set> get_connections() noexcept { - std::unique_lock lock(connections_mutex); - auto copy = connections; - return copy; - } - }; - - class Config { - friend class SocketServerBase; - - private: - Config(unsigned short port) noexcept : port(port) {} - - public: - /// Port number to use. Defaults to 80 for HTTP and 443 for HTTPS. - unsigned short port; - /// If io_service is not set, number of threads that the server will use when start() is called. - /// Defaults to 1 thread. - std::size_t thread_pool_size = 1; - /// Timeout on request handling. Defaults to 5 seconds. - long timeout_request = 5; - /// Idle timeout. Defaults to no timeout. - long timeout_idle = 0; - /// Maximum size of incoming messages. Defaults to architecture maximum. - /// Exceeding this limit will result in a message_size error code and the connection will be closed. - std::size_t max_message_size = std::numeric_limits::max(); - /// IPv4 address in dotted decimal form or IPv6 address in hexadecimal notation. - /// If empty, the address will be any address. - std::string address; - /// Set to false to avoid binding the socket to an address that is already in use. Defaults to true. - bool reuse_address = true; - }; - /// Set before calling start(). 
- Config config; - - private: - class regex_orderable : public regex::regex { - std::string str; - - public: - regex_orderable(const char *regex_cstr) : regex::regex(regex_cstr), str(regex_cstr) {} - regex_orderable(const std::string ®ex_str) : regex::regex(regex_str), str(regex_str) {} - bool operator<(const regex_orderable &rhs) const noexcept { - return str < rhs.str; - } - }; - - public: - /// Warning: do not add or remove endpoints after start() is called - std::map endpoint; - - virtual void start() { - if(!io_service) { - io_service = std::make_shared(); - internal_io_service = true; - } - - if(io_service->stopped()) - io_service->reset(); - - asio::ip::tcp::endpoint ep; - if(config.address.size() > 0) - ep = asio::ip::tcp::endpoint(asio::ip::address::from_string(config.address), config.port); - else - ep = asio::ip::tcp::endpoint(asio::ip::tcp::v4(), config.port); - - if(!acceptor) - acceptor = std::unique_ptr(new asio::ip::tcp::acceptor(*io_service)); - acceptor->open(ep.protocol()); - acceptor->set_option(asio::socket_base::reuse_address(config.reuse_address)); - acceptor->bind(ep); - acceptor->listen(); - - accept(); - - if(internal_io_service) { - // If thread_pool_size>1, start m_io_service.run() in (thread_pool_size-1) threads for thread-pooling - threads.clear(); - for(std::size_t c = 1; c < config.thread_pool_size; c++) { - threads.emplace_back([this]() { - io_service->run(); - }); - } - // Main thread - if(config.thread_pool_size > 0) - io_service->run(); - - // Wait for the rest of the threads, if any, to finish as well - for(auto &t : threads) - t.join(); - } - } - - void stop() noexcept { - if(acceptor) { - error_code ec; - acceptor->close(ec); - - for(auto &pair : endpoint) { - std::unique_lock lock(pair.second.connections_mutex); - for(auto &connection : pair.second.connections) - connection->close(); - pair.second.connections.clear(); - } - - if(internal_io_service) - io_service->stop(); - } - } - - virtual ~SocketServerBase() noexcept {} - - 
std::unordered_set> get_connections() noexcept { - std::unordered_set> all_connections; - for(auto &e : endpoint) { - std::unique_lock lock(e.second.connections_mutex); - all_connections.insert(e.second.connections.begin(), e.second.connections.end()); - } - return all_connections; - } - - /** - * Upgrades a request, from for instance Simple-Web-Server, to a WebSocket connection. - * The parameters are moved to the Connection object. - * See also Server::on_upgrade in the Simple-Web-Server project. - * The socket's io_service is used, thus running start() is not needed. - * - * Example use: - * server.on_upgrade=[&socket_server] (auto socket, auto request) { - * auto connection=std::make_shared::Connection>(std::move(socket)); - * connection->method=std::move(request->method); - * connection->path=std::move(request->path); - * connection->query_string=std::move(request->query_string); - * connection->http_version=std::move(request->http_version); - * connection->header=std::move(request->header); - * connection->remote_endpoint=std::move(*request->remote_endpoint); - * socket_server.upgrade(connection); - * } - */ - void upgrade(const std::shared_ptr &connection) { - connection->handler_runner = handler_runner; - connection->timeout_idle = config.timeout_idle; - write_handshake(connection); - } - - /// If you have your own asio::io_service, store its pointer here before running start(). 
- std::shared_ptr io_service; - - protected: - bool internal_io_service = false; - - std::unique_ptr acceptor; - std::vector threads; - - std::shared_ptr handler_runner; - - SocketServerBase(unsigned short port) noexcept : config(port), handler_runner(new ScopeRunner()) {} - - virtual void accept() = 0; - - void read_handshake(const std::shared_ptr &connection) { - connection->read_remote_endpoint(); - - connection->set_timeout(config.timeout_request); - asio::async_read_until(*connection->socket, connection->read_buffer, "\r\n\r\n", [this, connection](const error_code &ec, std::size_t /*bytes_transferred*/) { - connection->cancel_timeout(); - auto lock = connection->handler_runner->continue_lock(); - if(!lock) - return; - if(!ec) { - std::istream stream(&connection->read_buffer); - if(RequestMessage::parse(stream, connection->method, connection->path, connection->query_string, connection->http_version, connection->header)) - write_handshake(connection); - } - }); - } - - void write_handshake(const std::shared_ptr &connection) { - for(auto ®ex_endpoint : endpoint) { - regex::smatch path_match; - if(regex::regex_match(connection->path, path_match, regex_endpoint.first)) { - auto write_buffer = std::make_shared(); - - if(connection->generate_handshake(write_buffer)) { - connection->path_match = std::move(path_match); - connection->set_timeout(config.timeout_request); - asio::async_write(*connection->socket, *write_buffer, [this, connection, write_buffer, ®ex_endpoint](const error_code &ec, std::size_t /*bytes_transferred*/) { - connection->cancel_timeout(); - auto lock = connection->handler_runner->continue_lock(); - if(!lock) - return; - if(!ec) { - connection_open(connection, regex_endpoint.second); - read_message(connection, regex_endpoint.second); - } - else - connection_error(connection, regex_endpoint.second, ec); - }); - } - return; - } - } - } - - void read_message(const std::shared_ptr &connection, Endpoint &ep) const { - 
asio::async_read(*connection->socket, connection->read_buffer, asio::transfer_exactly(2), [this, connection, &ep](const error_code &ec, std::size_t bytes_transferred) { - auto lock = connection->handler_runner->continue_lock(); - if(!lock) - return; - if(!ec) { - if(bytes_transferred == 0) { // TODO: why does this happen sometimes? - read_message(connection, ep); - return; - } - std::istream stream(&connection->read_buffer); - - std::array first_bytes; - stream.read((char *)&first_bytes[0], 2); - - unsigned char fin_rsv_opcode = first_bytes[0]; - - // Close connection if unmasked message from client (protocol error) - if(first_bytes[1] < 128) { - const std::string reason("message from client not masked"); - connection->send_close(1002, reason); - connection_close(connection, ep, 1002, reason); - return; - } - - std::size_t length = (first_bytes[1] & 127); - - if(length == 126) { - // 2 next bytes is the size of content - asio::async_read(*connection->socket, connection->read_buffer, asio::transfer_exactly(2), [this, connection, &ep, fin_rsv_opcode](const error_code &ec, std::size_t /*bytes_transferred*/) { - auto lock = connection->handler_runner->continue_lock(); - if(!lock) - return; - if(!ec) { - std::istream stream(&connection->read_buffer); - - std::array length_bytes; - stream.read((char *)&length_bytes[0], 2); - - std::size_t length = 0; - std::size_t num_bytes = 2; - for(std::size_t c = 0; c < num_bytes; c++) - length += static_cast(length_bytes[c]) << (8 * (num_bytes - 1 - c)); - - read_message_content(connection, length, ep, fin_rsv_opcode); - } - else - connection_error(connection, ep, ec); - }); - } - else if(length == 127) { - // 8 next bytes is the size of content - asio::async_read(*connection->socket, connection->read_buffer, asio::transfer_exactly(8), [this, connection, &ep, fin_rsv_opcode](const error_code &ec, std::size_t /*bytes_transferred*/) { - auto lock = connection->handler_runner->continue_lock(); - if(!lock) - return; - if(!ec) { - 
std::istream stream(&connection->read_buffer); - - std::array length_bytes; - stream.read((char *)&length_bytes[0], 8); - - std::size_t length = 0; - std::size_t num_bytes = 8; - for(std::size_t c = 0; c < num_bytes; c++) - length += static_cast(length_bytes[c]) << (8 * (num_bytes - 1 - c)); - - read_message_content(connection, length, ep, fin_rsv_opcode); - } - else - connection_error(connection, ep, ec); - }); - } - else - read_message_content(connection, length, ep, fin_rsv_opcode); - } - else - connection_error(connection, ep, ec); - }); - } - - void read_message_content(const std::shared_ptr &connection, std::size_t length, Endpoint &ep, unsigned char fin_rsv_opcode) const { - if(length + (connection->fragmented_message ? connection->fragmented_message->length : 0) > config.max_message_size) { - connection_error(connection, ep, make_error_code::make_error_code(errc::message_size)); - const int status = 1009; - const std::string reason = "message too big"; - connection->send_close(status, reason); - connection_close(connection, ep, status, reason); - return; - } - asio::async_read(*connection->socket, connection->read_buffer, asio::transfer_exactly(4 + length), [this, connection, length, &ep, fin_rsv_opcode](const error_code &ec, std::size_t /*bytes_transferred*/) { - auto lock = connection->handler_runner->continue_lock(); - if(!lock) - return; - if(!ec) { - std::istream istream(&connection->read_buffer); - - // Read mask - std::array mask; - istream.read((char *)&mask[0], 4); - - std::shared_ptr message; - - // If fragmented message - if((fin_rsv_opcode & 0x80) == 0 || (fin_rsv_opcode & 0x0f) == 0) { - if(!connection->fragmented_message) { - connection->fragmented_message = std::shared_ptr(new Message(fin_rsv_opcode, length)); - connection->fragmented_message->fin_rsv_opcode |= 0x80; - } - else - connection->fragmented_message->length += length; - message = connection->fragmented_message; - } - else - message = std::shared_ptr(new Message(fin_rsv_opcode, 
length)); - std::ostream ostream(&message->streambuf); - for(std::size_t c = 0; c < length; c++) - ostream.put((unsigned char)(istream.get() ^ mask[c % 4])); - - // If connection close - if((fin_rsv_opcode & 0x0f) == 8) { - connection->cancel_timeout(); - connection->set_timeout(); - - int status = 0; - if(length >= 2) { - unsigned char byte1 = (unsigned char)(message->get()); - unsigned char byte2 = (unsigned char)(message->get()); - status = (static_cast(byte1) << 8) + byte2; - } - - auto reason = message->string(); - connection->send_close(status, reason); - this->connection_close(connection, ep, status, reason); - } - // If ping - else if((fin_rsv_opcode & 0x0f) == 9) { - connection->cancel_timeout(); - connection->set_timeout(); - - // Send pong - auto empty_send_stream = std::make_shared(); - connection->send(empty_send_stream, nullptr, fin_rsv_opcode + 1); - - if(ep.on_ping) - ep.on_ping(connection); - - // Next message - this->read_message(connection, ep); - } - // If pong - else if((fin_rsv_opcode & 0x0f) == 10) { - connection->cancel_timeout(); - connection->set_timeout(); - - if(ep.on_pong) - ep.on_pong(connection); - - // Next message - this->read_message(connection, ep); - } - // If fragmented message and not final fragment - else if((fin_rsv_opcode & 0x80) == 0) { - // Next message - this->read_message(connection, ep); - } - else { - connection->cancel_timeout(); - connection->set_timeout(); - - if(ep.on_message) - ep.on_message(connection, message); - - // Next message - // Only reset fragmented_message for non-control frames (control frames can be in between a fragmented message) - connection->fragmented_message = nullptr; - this->read_message(connection, ep); - } - } - else - this->connection_error(connection, ep, ec); - }); - } - - void connection_open(const std::shared_ptr &connection, Endpoint &ep) const { - connection->cancel_timeout(); - connection->set_timeout(); - - { - std::unique_lock lock(ep.connections_mutex); - 
ep.connections.insert(connection); - } - - if(ep.on_open) - ep.on_open(connection); - } - - void connection_close(const std::shared_ptr &connection, Endpoint &ep, int status, const std::string &reason) const { - connection->cancel_timeout(); - connection->set_timeout(); - - { - std::unique_lock lock(ep.connections_mutex); - ep.connections.erase(connection); - } - - if(ep.on_close) - ep.on_close(connection, status, reason); - } - - void connection_error(const std::shared_ptr &connection, Endpoint &ep, const error_code &ec) const { - connection->cancel_timeout(); - connection->set_timeout(); - - { - std::unique_lock lock(ep.connections_mutex); - ep.connections.erase(connection); - } - - if(ep.on_error) - ep.on_error(connection, ec); - } - }; - - template - class SocketServer : public SocketServerBase {}; - - using WS = asio::ip::tcp::socket; - - template <> - class SocketServer : public SocketServerBase { - public: - SocketServer() noexcept : SocketServerBase(80) {} - - protected: - void accept() override { - std::shared_ptr connection(new Connection(handler_runner, config.timeout_idle, *io_service)); - - acceptor->async_accept(*connection->socket, [this, connection](const error_code &ec) { - auto lock = connection->handler_runner->continue_lock(); - if(!lock) - return; - // Immediately start accepting a new connection (if io_service hasn't been stopped) - if(ec != asio::error::operation_aborted) - accept(); - - if(!ec) { - asio::ip::tcp::no_delay option(true); - connection->socket->set_option(option); - - read_handshake(connection); - } - }); - } - }; -} // namespace SimpleWeb - -#endif /* SERVER_WS_HPP */ diff --git a/src/3rd_party/simple-websocket-server/status_code.hpp b/src/3rd_party/simple-websocket-server/status_code.hpp deleted file mode 100644 index 81bc5c8db..000000000 --- a/src/3rd_party/simple-websocket-server/status_code.hpp +++ /dev/null @@ -1,191 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2014-2017 Ole Christian Eidheim - * - * 
Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- */ - -#ifndef SIMPLE_WEB_STATUS_CODE_HPP -#define SIMPLE_WEB_STATUS_CODE_HPP - -#include -#include -#include -#include - -namespace SimpleWeb { - enum class StatusCode { - unknown = 0, - information_continue = 100, - information_switching_protocols, - information_processing, - success_ok = 200, - success_created, - success_accepted, - success_non_authoritative_information, - success_no_content, - success_reset_content, - success_partial_content, - success_multi_status, - success_already_reported, - success_im_used = 226, - redirection_multiple_choices = 300, - redirection_moved_permanently, - redirection_found, - redirection_see_other, - redirection_not_modified, - redirection_use_proxy, - redirection_switch_proxy, - redirection_temporary_redirect, - redirection_permanent_redirect, - client_error_bad_request = 400, - client_error_unauthorized, - client_error_payment_required, - client_error_forbidden, - client_error_not_found, - client_error_method_not_allowed, - client_error_not_acceptable, - client_error_proxy_authentication_required, - client_error_request_timeout, - client_error_conflict, - client_error_gone, - client_error_length_required, - client_error_precondition_failed, - client_error_payload_too_large, - client_error_uri_too_long, - client_error_unsupported_media_type, - client_error_range_not_satisfiable, - client_error_expectation_failed, - client_error_im_a_teapot, - client_error_misdirection_required = 421, - client_error_unprocessable_entity, - client_error_locked, - client_error_failed_dependency, - client_error_upgrade_required = 426, - client_error_precondition_required = 428, - client_error_too_many_requests, - client_error_request_header_fields_too_large = 431, - client_error_unavailable_for_legal_reasons = 451, - server_error_internal_server_error = 500, - server_error_not_implemented, - server_error_bad_gateway, - server_error_service_unavailable, - server_error_gateway_timeout, - server_error_http_version_not_supported, - 
server_error_variant_also_negotiates, - server_error_insufficient_storage, - server_error_loop_detected, - server_error_not_extended = 510, - server_error_network_authentication_required - }; - - inline const std::map &status_code_strings() { - static const std::map status_code_strings = { - {StatusCode::unknown, ""}, - {StatusCode::information_continue, "100 Continue"}, - {StatusCode::information_switching_protocols, "101 Switching Protocols"}, - {StatusCode::information_processing, "102 Processing"}, - {StatusCode::success_ok, "200 OK"}, - {StatusCode::success_created, "201 Created"}, - {StatusCode::success_accepted, "202 Accepted"}, - {StatusCode::success_non_authoritative_information, "203 Non-Authoritative Information"}, - {StatusCode::success_no_content, "204 No Content"}, - {StatusCode::success_reset_content, "205 Reset Content"}, - {StatusCode::success_partial_content, "206 Partial Content"}, - {StatusCode::success_multi_status, "207 Multi-Status"}, - {StatusCode::success_already_reported, "208 Already Reported"}, - {StatusCode::success_im_used, "226 IM Used"}, - {StatusCode::redirection_multiple_choices, "300 Multiple Choices"}, - {StatusCode::redirection_moved_permanently, "301 Moved Permanently"}, - {StatusCode::redirection_found, "302 Found"}, - {StatusCode::redirection_see_other, "303 See Other"}, - {StatusCode::redirection_not_modified, "304 Not Modified"}, - {StatusCode::redirection_use_proxy, "305 Use Proxy"}, - {StatusCode::redirection_switch_proxy, "306 Switch Proxy"}, - {StatusCode::redirection_temporary_redirect, "307 Temporary Redirect"}, - {StatusCode::redirection_permanent_redirect, "308 Permanent Redirect"}, - {StatusCode::client_error_bad_request, "400 Bad Request"}, - {StatusCode::client_error_unauthorized, "401 Unauthorized"}, - {StatusCode::client_error_payment_required, "402 Payment Required"}, - {StatusCode::client_error_forbidden, "403 Forbidden"}, - {StatusCode::client_error_not_found, "404 Not Found"}, - 
{StatusCode::client_error_method_not_allowed, "405 Method Not Allowed"}, - {StatusCode::client_error_not_acceptable, "406 Not Acceptable"}, - {StatusCode::client_error_proxy_authentication_required, "407 Proxy Authentication Required"}, - {StatusCode::client_error_request_timeout, "408 Request Timeout"}, - {StatusCode::client_error_conflict, "409 Conflict"}, - {StatusCode::client_error_gone, "410 Gone"}, - {StatusCode::client_error_length_required, "411 Length Required"}, - {StatusCode::client_error_precondition_failed, "412 Precondition Failed"}, - {StatusCode::client_error_payload_too_large, "413 Payload Too Large"}, - {StatusCode::client_error_uri_too_long, "414 URI Too Long"}, - {StatusCode::client_error_unsupported_media_type, "415 Unsupported Media Type"}, - {StatusCode::client_error_range_not_satisfiable, "416 Range Not Satisfiable"}, - {StatusCode::client_error_expectation_failed, "417 Expectation Failed"}, - {StatusCode::client_error_im_a_teapot, "418 I'm a teapot"}, - {StatusCode::client_error_misdirection_required, "421 Misdirected Request"}, - {StatusCode::client_error_unprocessable_entity, "422 Unprocessable Entity"}, - {StatusCode::client_error_locked, "423 Locked"}, - {StatusCode::client_error_failed_dependency, "424 Failed Dependency"}, - {StatusCode::client_error_upgrade_required, "426 Upgrade Required"}, - {StatusCode::client_error_precondition_required, "428 Precondition Required"}, - {StatusCode::client_error_too_many_requests, "429 Too Many Requests"}, - {StatusCode::client_error_request_header_fields_too_large, "431 Request Header Fields Too Large"}, - {StatusCode::client_error_unavailable_for_legal_reasons, "451 Unavailable For Legal Reasons"}, - {StatusCode::server_error_internal_server_error, "500 Internal Server Error"}, - {StatusCode::server_error_not_implemented, "501 Not Implemented"}, - {StatusCode::server_error_bad_gateway, "502 Bad Gateway"}, - {StatusCode::server_error_service_unavailable, "503 Service Unavailable"}, - 
{StatusCode::server_error_gateway_timeout, "504 Gateway Timeout"}, - {StatusCode::server_error_http_version_not_supported, "505 HTTP Version Not Supported"}, - {StatusCode::server_error_variant_also_negotiates, "506 Variant Also Negotiates"}, - {StatusCode::server_error_insufficient_storage, "507 Insufficient Storage"}, - {StatusCode::server_error_loop_detected, "508 Loop Detected"}, - {StatusCode::server_error_not_extended, "510 Not Extended"}, - {StatusCode::server_error_network_authentication_required, "511 Network Authentication Required"}}; - return status_code_strings; - } - - inline StatusCode status_code(const std::string &status_code_string) noexcept { - class StringToStatusCode : public std::unordered_map { - public: - StringToStatusCode() { - for(auto &status_code : status_code_strings()) - emplace(status_code.second, status_code.first); - } - }; - static StringToStatusCode string_to_status_code; - auto pos = string_to_status_code.find(status_code_string); - if(pos == string_to_status_code.end()) - return StatusCode::unknown; - return pos->second; - } - - inline const std::string &status_code(StatusCode status_code_enum) noexcept { - auto pos = status_code_strings().find(status_code_enum); - if(pos == status_code_strings().end()) { - static std::string empty_string; - return empty_string; - } - return pos->second; - } -} // namespace SimpleWeb - -#endif // SIMPLE_WEB_STATUS_CODE_HPP diff --git a/src/3rd_party/simple-websocket-server/utility.hpp b/src/3rd_party/simple-websocket-server/utility.hpp deleted file mode 100644 index d2abcf53e..000000000 --- a/src/3rd_party/simple-websocket-server/utility.hpp +++ /dev/null @@ -1,381 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2014-2017 Ole Christian Eidheim - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the 
rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -#ifndef SIMPLE_WEB_UTILITY_HPP -#define SIMPLE_WEB_UTILITY_HPP - -#include "status_code.hpp" -#include -#include -#include -#include -#include - -namespace SimpleWeb { - inline bool case_insensitive_equal(const std::string &str1, const std::string &str2) noexcept { - return str1.size() == str2.size() && - std::equal(str1.begin(), str1.end(), str2.begin(), [](char a, char b) { - return tolower(a) == tolower(b); - }); - } - class CaseInsensitiveEqual { - public: - bool operator()(const std::string &str1, const std::string &str2) const noexcept { - return case_insensitive_equal(str1, str2); - } - }; - // Based on https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x/2595226#2595226 - class CaseInsensitiveHash { - public: - std::size_t operator()(const std::string &str) const noexcept { - std::size_t h = 0; - std::hash hash; - for(auto c : str) - h ^= hash(tolower(c)) + 0x9e3779b9 + (h << 6) + (h >> 2); - return h; - } - }; - - using CaseInsensitiveMultimap = std::unordered_multimap; - - /// Percent encoding and decoding - class Percent { - public: - /// Returns percent-encoded string - static 
std::string encode(const std::string &value) noexcept { - static auto hex_chars = "0123456789ABCDEF"; - - std::string result; - result.reserve(value.size()); // Minimum size of result - - for(auto &chr : value) { - if(chr == ' ') - result += '+'; - else if(chr == '!' || chr == '#' || chr == '$' || (chr >= '&' && chr <= ',') || (chr >= '/' && chr <= ';') || chr == '=' || chr == '?' || chr == '@' || chr == '[' || chr == ']') - result += std::string("%") + hex_chars[chr >> 4] + hex_chars[chr & 15]; - else - result += chr; - } - - return result; - } - - /// Returns percent-decoded string - static std::string decode(const std::string &value) noexcept { - std::string result; - result.reserve(value.size() / 3 + (value.size() % 3)); // Minimum size of result - - for(std::size_t i = 0; i < value.size(); ++i) { - auto &chr = value[i]; - if(chr == '%' && i + 2 < value.size()) { - auto hex = value.substr(i + 1, 2); - auto decoded_chr = static_cast(std::strtol(hex.c_str(), nullptr, 16)); - result += decoded_chr; - i += 2; - } - else if(chr == '+') - result += ' '; - else - result += chr; - } - - return result; - } - }; - - /// Query string creation and parsing - class QueryString { - public: - /// Returns query string created from given field names and values - static std::string create(const CaseInsensitiveMultimap &fields) noexcept { - std::string result; - - bool first = true; - for(auto &field : fields) { - result += (!first ? "&" : "") + field.first + '=' + Percent::encode(field.second); - first = false; - } - - return result; - } - - /// Returns query keys with percent-decoded values. 
- static CaseInsensitiveMultimap parse(const std::string &query_string) noexcept { - CaseInsensitiveMultimap result; - - if(query_string.empty()) - return result; - - std::size_t name_pos = 0; - auto name_end_pos = std::string::npos; - auto value_pos = std::string::npos; - for(std::size_t c = 0; c < query_string.size(); ++c) { - if(query_string[c] == '&') { - auto name = query_string.substr(name_pos, (name_end_pos == std::string::npos ? c : name_end_pos) - name_pos); - if(!name.empty()) { - auto value = value_pos == std::string::npos ? std::string() : query_string.substr(value_pos, c - value_pos); - result.emplace(std::move(name), Percent::decode(value)); - } - name_pos = c + 1; - name_end_pos = std::string::npos; - value_pos = std::string::npos; - } - else if(query_string[c] == '=') { - name_end_pos = c; - value_pos = c + 1; - } - } - if(name_pos < query_string.size()) { - auto name = query_string.substr(name_pos, name_end_pos - name_pos); - if(!name.empty()) { - auto value = value_pos >= query_string.size() ? std::string() : query_string.substr(value_pos); - result.emplace(std::move(name), Percent::decode(value)); - } - } - - return result; - } - }; - - class HttpHeader { - public: - /// Parse header fields - static CaseInsensitiveMultimap parse(std::istream &stream) noexcept { - CaseInsensitiveMultimap result; - std::string line; - getline(stream, line); - std::size_t param_end; - while((param_end = line.find(':')) != std::string::npos) { - std::size_t value_start = param_end + 1; - if(value_start < line.size()) { - if(line[value_start] == ' ') - value_start++; - if(value_start < line.size()) - result.emplace(line.substr(0, param_end), line.substr(value_start, line.size() - value_start - 1)); - } - - getline(stream, line); - } - return result; - } - - class FieldValue { - public: - class SemicolonSeparatedAttributes { - public: - /// Parse Set-Cookie or Content-Disposition header field value. Attribute values are percent-decoded. 
- static CaseInsensitiveMultimap parse(const std::string &str) { - CaseInsensitiveMultimap result; - - std::size_t name_start_pos = std::string::npos; - std::size_t name_end_pos = std::string::npos; - std::size_t value_start_pos = std::string::npos; - for(std::size_t c = 0; c < str.size(); ++c) { - if(name_start_pos == std::string::npos) { - if(str[c] != ' ' && str[c] != ';') - name_start_pos = c; - } - else { - if(name_end_pos == std::string::npos) { - if(str[c] == ';') { - result.emplace(str.substr(name_start_pos, c - name_start_pos), std::string()); - name_start_pos = std::string::npos; - } - else if(str[c] == '=') - name_end_pos = c; - } - else { - if(value_start_pos == std::string::npos) { - if(str[c] == '"' && c + 1 < str.size()) - value_start_pos = c + 1; - else - value_start_pos = c; - } - else if(str[c] == '"' || str[c] == ';') { - result.emplace(str.substr(name_start_pos, name_end_pos - name_start_pos), Percent::decode(str.substr(value_start_pos, c - value_start_pos))); - name_start_pos = std::string::npos; - name_end_pos = std::string::npos; - value_start_pos = std::string::npos; - } - } - } - } - if(name_start_pos != std::string::npos) { - if(name_end_pos == std::string::npos) - result.emplace(str.substr(name_start_pos), std::string()); - else if(value_start_pos != std::string::npos) { - if(str.back() == '"') - result.emplace(str.substr(name_start_pos, name_end_pos - name_start_pos), Percent::decode(str.substr(value_start_pos, str.size() - 1))); - else - result.emplace(str.substr(name_start_pos, name_end_pos - name_start_pos), Percent::decode(str.substr(value_start_pos))); - } - } - - return result; - } - }; - }; - }; // namespace SimpleWeb - - class RequestMessage { - public: - /// Parse request line and header fields - static bool parse(std::istream &stream, std::string &method, std::string &path, std::string &query_string, std::string &version, CaseInsensitiveMultimap &header) noexcept { - header.clear(); - std::string line; - getline(stream, line); 
- std::size_t method_end; - if((method_end = line.find(' ')) != std::string::npos) { - method = line.substr(0, method_end); - - std::size_t query_start = std::string::npos; - std::size_t path_and_query_string_end = std::string::npos; - for(std::size_t i = method_end + 1; i < line.size(); ++i) { - if(line[i] == '?' && (i + 1) < line.size()) - query_start = i + 1; - else if(line[i] == ' ') { - path_and_query_string_end = i; - break; - } - } - if(path_and_query_string_end != std::string::npos) { - if(query_start != std::string::npos) { - path = line.substr(method_end + 1, query_start - method_end - 2); - query_string = line.substr(query_start, path_and_query_string_end - query_start); - } - else - path = line.substr(method_end + 1, path_and_query_string_end - method_end - 1); - - std::size_t protocol_end; - if((protocol_end = line.find('/', path_and_query_string_end + 1)) != std::string::npos) { - if(line.compare(path_and_query_string_end + 1, protocol_end - path_and_query_string_end - 1, "HTTP") != 0) - return false; - version = line.substr(protocol_end + 1, line.size() - protocol_end - 2); - } - else - return false; - - header = HttpHeader::parse(stream); - } - else - return false; - } - else - return false; - return true; - } - }; - - class ResponseMessage { - public: - /// Parse status line and header fields - static bool parse(std::istream &stream, std::string &version, std::string &status_code, CaseInsensitiveMultimap &header) noexcept { - header.clear(); - std::string line; - getline(stream, line); - std::size_t version_end = line.find(' '); - if(version_end != std::string::npos) { - if(5 < line.size()) - version = line.substr(5, version_end - 5); - else - return false; - if((version_end + 1) < line.size()) - status_code = line.substr(version_end + 1, line.size() - (version_end + 1) - 1); - else - return false; - - header = HttpHeader::parse(stream); - } - else - return false; - return true; - } - }; -} // namespace SimpleWeb - -#ifdef __SSE2__ -#include 
-namespace SimpleWeb { - inline void spin_loop_pause() noexcept { _mm_pause(); } -} // namespace SimpleWeb -// TODO: need verification that the following checks are correct: -#elif defined(_MSC_VER) && _MSC_VER >= 1800 && (defined(_M_X64) || defined(_M_IX86)) -#include -namespace SimpleWeb { - inline void spin_loop_pause() noexcept { _mm_pause(); } -} // namespace SimpleWeb -#else -namespace SimpleWeb { - inline void spin_loop_pause() noexcept {} -} // namespace SimpleWeb -#endif - -namespace SimpleWeb { - /// Makes it possible to for instance cancel Asio handlers without stopping asio::io_service - class ScopeRunner { - /// Scope count that is set to -1 if scopes are to be canceled - std::atomic count; - - public: - class SharedLock { - friend class ScopeRunner; - std::atomic &count; - SharedLock(std::atomic &count) noexcept : count(count) {} - SharedLock &operator=(const SharedLock &) = delete; - SharedLock(const SharedLock &) = delete; - - public: - ~SharedLock() noexcept { - count.fetch_sub(1); - } - }; - - ScopeRunner() noexcept : count(0) {} - - /// Returns nullptr if scope should be exited, or a shared lock otherwise - std::unique_ptr continue_lock() noexcept { - long expected = count; - while(expected >= 0 && !count.compare_exchange_weak(expected, expected + 1)) - spin_loop_pause(); - - if(expected < 0) - return nullptr; - else - return std::unique_ptr(new SharedLock(count)); - } - - /// Blocks until all shared locks are released, then prevents future shared locks - void stop() noexcept { - long expected = 0; - while(!count.compare_exchange_weak(expected, -1)) { - if(expected < 0) - return; - expected = 0; - spin_loop_pause(); - } - } - }; -} // namespace SimpleWeb - -#endif // SIMPLE_WEB_UTILITY_HPP diff --git a/src/command/marian_server.cpp b/src/command/marian_server.cpp index e4074bd1b..2c3649407 100644 --- a/src/command/marian_server.cpp +++ b/src/command/marian_server.cpp @@ -22,10 +22,10 @@ int main(int argc, char **argv) { auto &translate = 
server.endpoint["^/translate/?$"]; translate.on_message = [&task](Ptr connection, - Ptr message) { + Ptr message) { // Get input text auto inputText = message->string(); - auto sendStream = std::make_shared(); + auto sendStream = std::make_shared(); // Translate timer::Timer timer; From 9ae1951fe2bf312fb573b2b18b6f2167de242e0c Mon Sep 17 00:00:00 2001 From: Nikolay Bogoychev Date: Thu, 14 May 2020 15:55:27 +0100 Subject: [PATCH 25/62] Batched gemm (#633) * Use cblas_sgemm_batch when available * Merge with master, add comments and describe contribution --- CHANGELOG.md | 1 + src/tensors/cpu/prod.cpp | 61 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 715f83df3..65a1f9071 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] ### Added +- Use *cblas_sgemm_batch* in stead of a for loop of *cblas_sgemm* on CPU as the batched_gemm implementation - Supporting relative paths in shortlist and sqlite options - Training and scoring from STDIN - Support for reading from TSV files from STDIN and other sources during training diff --git a/src/tensors/cpu/prod.cpp b/src/tensors/cpu/prod.cpp index ac13ccee8..9bdedd545 100755 --- a/src/tensors/cpu/prod.cpp +++ b/src/tensors/cpu/prod.cpp @@ -134,6 +134,66 @@ void ProdBatched(marian::Tensor C, auto strideC = n * m; auto batchC = std::max(batchA, batchB); +#if MKL_FOUND + CBLAS_TRANSPOSE transA_forarr = CblasNoTrans; + CBLAS_TRANSPOSE transB_forarr = CblasNoTrans; + + if(transA) + transA_forarr = CblasTrans; + + if(transB) + transB_forarr = CblasTrans; + + /* cblas_sgemm_batch allows us to group all the small GEMMs that are done in a for loop with sgemm and compute + * them in only one MKL call. 
For the API documentation refer to + * https://software.intel.com/content/www/us/en/develop/documentation/mkl-developer-reference-c/top/blas-and-sparse-blas-routines/blas-like-extensions/cblas-gemm-batch.html + * The API supports dependencies, where you can specify one "group" of GEMMs to be computed after another. (This controlled by the group_count parameter). + * In our case, the operations are not dependent on one another so we hardcode one group. The rest of the arguments (with the exception of group_size) are + * the same as the ones that cblas_sgemm expects, with the difference that we are supposed to provide an array pointer (One element per group). + * Weirdly enough, we are required to to provide all of the integer arguments as the MKL_INT datatype + */ + + static const constexpr size_t group_count = 1; // We have one group + const std::vector transa_arr(group_count, transA_forarr); + const std::vector transb_arr(group_count, transB_forarr); + const std::vector m_arr(group_count, (MKL_INT)m); + const std::vector n_arr(group_count, (MKL_INT)n); + const std::vector k_arr(group_count, (MKL_INT)k); + const std::vector alpha_arr(group_count, alpha); + const std::vector beta_arr(group_count, beta); + const std::vector lda_arr(group_count, (MKL_INT)lda); + const std::vector ldb_arr(group_count, (MKL_INT)ldb); + const std::vector ldc_arr(group_count, (MKL_INT)ldc); + const std::vector group_size(group_count, (MKL_INT)batchC); // Group size specifies number of GEMM operations per group (Which is batchC) + + std::vector a_array(batchC, nullptr); + std::vector b_array(batchC, nullptr); + std::vector c_array(batchC, nullptr); + + // This loop initializes the array pointers in the same way as the for loop + // in the normal sgemm version a few lines below + for(size_t i = 0; i < batchC; ++i) { + a_array[i] = A->data() + (i % batchA) * strideA; + b_array[i] = B->data() + (i % batchB) * strideB; + c_array[i] = C->data() + i * strideC; + } + cblas_sgemm_batch 
(CblasRowMajor, + &transa_arr[0], + &transb_arr[0], + &m_arr[0], + &n_arr[0], + &k_arr[0], + &alpha_arr[0], + &a_array[0], + &lda_arr[0], + &b_array[0], + &ldb_arr[0], + &beta_arr[0], + &c_array[0], + &ldc_arr[0], + group_count, + &group_size[0]); +#else for(size_t i = 0; i < batchC; ++i) { sgemm(transA, transB, @@ -149,6 +209,7 @@ void ProdBatched(marian::Tensor C, C->data() + i * strideC, (int)ldc); } +#endif #else C; A; B; transA; transB; beta; scalar; ABORT("You need to compile with MKL in order to use the CPU version"); From 1603d2fe2a653fadd5342d009f3f71d550ee8a60 Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Thu, 14 May 2020 08:00:41 -0700 Subject: [PATCH 26/62] update version --- CHANGELOG.md | 2 +- VERSION | 2 +- regression-tests | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 65a1f9071..bdf07d461 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. 
## [Unreleased] ### Added -- Use *cblas_sgemm_batch* in stead of a for loop of *cblas_sgemm* on CPU as the batched_gemm implementation +- Use *cblas_sgemm_batch* instead of a for loop of *cblas_sgemm* on CPU as the batched_gemm implementation - Supporting relative paths in shortlist and sqlite options - Training and scoring from STDIN - Support for reading from TSV files from STDIN and other sources during training diff --git a/VERSION b/VERSION index dd63e963b..32d68684f 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -v1.9.9 +v1.9.10 diff --git a/regression-tests b/regression-tests index d1db7ea10..67281c736 160000 --- a/regression-tests +++ b/regression-tests @@ -1 +1 @@ -Subproject commit d1db7ea10071252fa669c034c9c99acf159c8920 +Subproject commit 67281c736fcffb074e35665fe6c52be9a4cf5ca8 From ae1dd47878e9de406a5823c8691e9fa4a56a495a Mon Sep 17 00:00:00 2001 From: Roman Grundkiewicz Date: Sun, 17 May 2020 11:34:18 +0100 Subject: [PATCH 27/62] Update submodule regression-tests --- regression-tests | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/regression-tests b/regression-tests index 67281c736..0f8cabf13 160000 --- a/regression-tests +++ b/regression-tests @@ -1 +1 @@ -Subproject commit 67281c736fcffb074e35665fe6c52be9a4cf5ca8 +Subproject commit 0f8cabf13ec362d50544d33490024e00c3a763be From 9cd162307456d66a235d4e314492444b5849b96d Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Mon, 6 Apr 2020 12:49:23 +0100 Subject: [PATCH 28/62] Bug fix: better handling of SIGTERM for graceful shutdown during training. Prior to this bug fix, BatchGenerator::fetchBatches, which runs in a separate thread, would ignore SIGTERM during training (training uses a custom signal handler for SIGTERM, which simply sets a global flag, to enable graceful shutdown (i.e., save models and current state of training before shutting down). 
The changes in this commit also facilitate custom handling of other signals in the future by providing a general signal handler for all signals with a signal number below 32 (setSignalFlag) and a generic flag checking function (getSignalFlag(sig)) for checking such flags. --- CHANGELOG.md | 2 ++ src/CMakeLists.txt | 2 +- src/command/marian_train.cpp | 3 ++- src/common/signal_handling.cpp | 21 +++++++++++++++++ src/common/signal_handling.h | 27 +++++++++++++++++++++ src/data/batch_generator.h | 23 +++++++++++++----- src/training/scheduler.cpp | 43 ---------------------------------- src/training/scheduler.h | 11 ++++----- src/training/training.h | 6 +++++ 9 files changed, 80 insertions(+), 58 deletions(-) create mode 100644 src/common/signal_handling.cpp create mode 100644 src/common/signal_handling.h delete mode 100644 src/training/scheduler.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ab28029c..bdf07d461 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Changed compile flags -Ofast to -O3 and remove --ffinite-math - Moved old graph groups to depracated folder - Make cublas and cusparse handle inits lazy to save memory when unused +- Improved handling for graceful shutdown upon receiving SIGTERM. + SIGTERM now also interrupts batch prefetching, which runs in a separate thread. 
## [1.9.0] - 2020-03-10 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4e9dd39dc..b78431d6a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -25,6 +25,7 @@ add_library(marian STATIC common/filesystem.cpp common/file_stream.cpp common/file_utils.cpp + common/signal_handling.cpp common/types.cpp data/alignment.cpp @@ -91,7 +92,6 @@ add_library(marian STATIC training/graph_group_singleton.cpp training/validator.cpp training/communicator.cpp - training/scheduler.cpp # this is only compiled to catch build errors, but not linked microsoft/quicksand.cpp diff --git a/src/command/marian_train.cpp b/src/command/marian_train.cpp index d1978fab4..46bd05e84 100644 --- a/src/command/marian_train.cpp +++ b/src/command/marian_train.cpp @@ -1,6 +1,7 @@ #include #include "marian.h" +#include "common/signal_handling.h" #include "training/graph_group_async.h" #include "training/graph_group_singleton.h" #include "training/graph_group_sync.h" @@ -51,5 +52,5 @@ int main(int argc, char** argv) { // returns for timeout -s SIGTERM ...., because exiting after SIGTERM // is not technically a fatal error (which is what the 128+x convention usually // stands for). - return getSigtermFlag() ? (128 + SIGTERM) : 0; + return getSignalFlag(SIGTERM) ? (128 + SIGTERM) : 0; } diff --git a/src/common/signal_handling.cpp b/src/common/signal_handling.cpp new file mode 100644 index 000000000..a18d1e669 --- /dev/null +++ b/src/common/signal_handling.cpp @@ -0,0 +1,21 @@ +#include "common/logging.h" +#include "signal_handling.h" + +// We use signal() here instead of the usual strong recommendation for +// using sigaction, which apparently is not available for Windows (cf. +// https://stackoverflow.com/questions/231912/what-is-the-difference-between-sigaction-and-signal). + +namespace marian{ +volatile std::sig_atomic_t sigflags_{0}; + +bool getSignalFlag(const int sig) { + // sig_atomic_t has 32 bits. We don't accommodate signals beyond that. 
+ ABORT_IF(sig >= 32, "Signal {} out of range (must be < 32).", sig); + return sigflags_ & (1< + +// SIGNAL HANDLING + +// The Marian signal handlers set global flags that thread can +// consider when a signal is received. This can be used for a graceful +// shutdown instead of a hard abandonment, e.g. after receiving +// SIGTERM during training. + +// When SIGTERM is received, the global (static member) flag sigterm_ +// (false by default) is set to true by signalHandler(). When sigterm_ +// is true, keepGoing() returns false, and the current state of +// training models is saved prior to exiting. This functionality is +// helpful when training on clusters with time limits on compute +// slots, e.g., on s clusters managed by slurm. Slurm can be asked to +// sending a (custom) warning signal to a process at a given point in +// time prior to the hard "time's up". +// +// Correspondingly, fetchBatches in the batch generator checks the flag +// frequently and quits after the overall process receives a SIGTERM. + + +namespace marian { +bool getSignalFlag(int sig); // return true if sig was received, false otherwise +void setSignalFlag(int sig); // set custom handler (set flag) for sig +} diff --git a/src/data/batch_generator.h b/src/data/batch_generator.h index f16a7a81c..1a26baa04 100644 --- a/src/data/batch_generator.h +++ b/src/data/batch_generator.h @@ -1,6 +1,7 @@ #pragma once #include "common/options.h" +#include "common/signal_handling.h" #include "data/batch_stats.h" #include "data/rng_engine.h" #include "training/training_state.h" @@ -132,8 +133,14 @@ class BatchGenerator : public RNGEngine { if(current_ != data_->end()) ++current_; } + + std::deque tempBatches; + size_t sets = 0; while(current_ != data_->end() && maxiBatch->size() < maxSize) { // loop over data + if (getSignalFlag(SIGTERM)) { // received SIGTERM, abandon ship ... 
+ return tempBatches; + } maxiBatch->push(*current_); sets = current_->size(); // do not consume more than required for the maxi batch as this causes @@ -149,8 +156,6 @@ class BatchGenerator : public RNGEngine { size_t currentWords = 0; std::vector lengths(sets, 0); // records maximum length observed within current batch - std::deque tempBatches; - // process all loaded sentences in order of increasing length // @TODO: we could just use a vector and do a sort() here; would make the cost more explicit const size_t mbWords = options_->get("mini-batch-words", 0); @@ -158,7 +163,13 @@ class BatchGenerator : public RNGEngine { BatchStats::const_iterator cachedStatsIter; if (stats_) cachedStatsIter = stats_->begin(); + while(!maxiBatch->empty()) { // while there are sentences in the queue + + if (getSignalFlag(SIGTERM)) { // received SIGTERM, abandon ship ... + return tempBatches; + } + // push item onto batch batchVector.push_back(maxiBatch->top()); maxiBatch->pop(); // fetch next-shortest @@ -242,13 +253,13 @@ class BatchGenerator : public RNGEngine { ABORT_IF(!futureBufferedBatches_.valid(), "Attempted to wait for futureBufferedBatches_ when none pending.\n" "This error often occurs when Marian tries to restore the training data iterator, but the corpus has been changed or replaced.\n" "If you have changed the training corpus, add --no-restore-corpus to the training command and run it again."); + bufferedBatches_ = std::move(futureBufferedBatches_.get()); - // if bg thread returns an empty swath, we hit the end of the epoch - if (bufferedBatches_.empty()) { + if (bufferedBatches_.empty() // i.e., end of Epoch + || getSignalFlag(SIGTERM)) { // process received SIGTERM, abandon ship ... 
return nullptr; } - // and kick off the next bg operation - fetchBatchesAsync(); + fetchBatchesAsync(); // pre-fetch next slew of batches in separate thread } auto batch = bufferedBatches_.front(); bufferedBatches_.pop_front(); diff --git a/src/training/scheduler.cpp b/src/training/scheduler.cpp deleted file mode 100644 index 4c30cb04e..000000000 --- a/src/training/scheduler.cpp +++ /dev/null @@ -1,43 +0,0 @@ -#include "scheduler.h" -#include -#include - -namespace marian { - -// SIGNAL HANDLING, see scheduler.cpp for definitions -// Currently, only the following is handled by a custom signal handler: -// SIGTERM: When SIGTERM is received, the global (static member) flag sigterm_ (false by default) is set to true -// by signalHandler(). When sigterm_ is true, keepGoing() returns false, and the current state of training models -// is saved prior to exiting. -// This functionality is helpful when training on clusters with time limits on compute slots, e.g., on s -// clusters managed by slurm. Slurm can be asked to sending a (custom) warning signal to a process at a given -// point in time prior to the hard "time's up". - -bool sigterm_{false}; // flag signalling that SIGTERM has been received false by default, set to true by signalHandler(SIGTERM) - -void signalHandler(int sig) { - // Note: sys_siglist[sig] or stdsignal() describe the effect (e.g., - // 'Terminated' rather than provide the signal name (which are #define(s) - // in signal.h), so we have to do custom log messages here. - switch (sig) { - case SIGTERM: // save models and exit - LOG(info, "[training] Scheduler received signal SIGTERM"); // @TODO: figure out if this is safe. The logs are global and thread-safe, so should be OK? - sigterm_ = true; - break; - default: - ABORT("No action defined for signal {}", sig); - } -} - -// installs signalHandler() for select signals (currently only SIGTERM) -void installSignalHandlers() { - // TODO: use sigaction instead of signal, - // cf. 
https://stackoverflow.com/questions/231912/what-is-the-difference-between-sigaction-and-signal - signal(SIGTERM, signalHandler); -} - -bool getSigtermFlag() { - return sigterm_; -} - -} diff --git a/src/training/scheduler.h b/src/training/scheduler.h index 8c8701cac..2ec9f1ab0 100755 --- a/src/training/scheduler.h +++ b/src/training/scheduler.h @@ -1,6 +1,7 @@ #pragma once #include "common/options.h" +#include "common/signal_handling.h" #include "training/training_state.h" #include "training/validator.h" #include "training/communicator.h" @@ -8,9 +9,6 @@ namespace marian { -bool getSigtermFlag(); -void installSignalHandlers(); - class Scheduler : public TrainingObserver { private: Ptr options_; @@ -154,11 +152,10 @@ class Scheduler : public TrainingObserver { : options_(options), state_(state) { ABORT_IF(state_->factor != 1, "state.factor unexpectedly not 1 at this point??"); updateLearningRate(*state); - installSignalHandlers(); } bool keepGoing() { - if(getSigtermFlag()) // received signal SIGERM => exit gracefully + if(getSignalFlag(SIGTERM)) // received signal SIGERM => exit gracefully return false; // stop if it reached the maximum number of epochs @@ -192,7 +189,7 @@ class Scheduler : public TrainingObserver { void started() { LOG(info, "Training started"); } void finished() { - if (getSigtermFlag()) + if (getSignalFlag(SIGTERM)) LOG(info, "Training interrupted (SIGTERM)."); else LOG(info, "Training finished"); @@ -225,7 +222,7 @@ class Scheduler : public TrainingObserver { bool isFinal = false) { // Do not validate if already validated (for instance, after the model is // loaded) or if validation is scheduled for another update, or when signal SIGTERM was received - if(getSigtermFlag() // SIGTERM was received + if(getSignalFlag(SIGTERM) // SIGTERM was received || state_->validated // already validated (in resumed training, for example) || (!state_->enteredNewPeriodOf(options_->get("valid-freq")) && !isFinal)) // not now return; diff --git 
a/src/training/training.h b/src/training/training.h index 5a2be7635..c68602ec9 100644 --- a/src/training/training.h +++ b/src/training/training.h @@ -77,6 +77,12 @@ class Train : public ModelTask { bool restored = !options_->get("no-restore-corpus") && batchGenerator->restore(trainState); + // Install custom handler for SIGTERM, to allow for a graceful + // shutdown that saves the current state of training before exiting. + // This signal handler simply sets a flag that can be checked from + // everywhere (getSignalFLAG(SIGTERM); #include common/signal_handling.h) + signal(SIGTERM,setSignalFlag); + // -- main training loop scheduler->started(); while(scheduler->keepGoing()) { From 63006db5ac9359d3074a1b5753765a0417081950 Mon Sep 17 00:00:00 2001 From: Young Jin Kim Date: Wed, 25 Mar 2020 02:52:17 +0000 Subject: [PATCH 29/62] Merged PR 11831: Change the weight matrix quantization to use 7-bit min/max quantization to avoid overflow 1. Change the weight matrix quantization to use 7-bit min/max quantization -> This resolves all the overflow issue, because weight and activations are quantized by min/max range. 2. Clip fp16 quantization to avoid overflow 3. Fix windows build errors (cmake options, vcproj file) 4. int8 pack model (encoder -> fp16) --- .../cpu/fbgemm/expression_graph_packable.h | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/tensors/cpu/fbgemm/expression_graph_packable.h b/src/tensors/cpu/fbgemm/expression_graph_packable.h index f5b05c302..e45f8274f 100644 --- a/src/tensors/cpu/fbgemm/expression_graph_packable.h +++ b/src/tensors/cpu/fbgemm/expression_graph_packable.h @@ -10,7 +10,7 @@ namespace marian { // This requires some more changes, but we temporarily do this just by name ("_W") of the weights. // And, this introduces a low level packed_gemm.h apis interact with high level graph class. // So, we make a subclass of ExpressionGraph and put those immature codes in this class. 
-// We will improve this in the near future. +// We will improve this in the near future. class ExpressionGraphPackable : public ExpressionGraph { public: ExpressionGraphPackable() @@ -36,10 +36,11 @@ class ExpressionGraphPackable : public ExpressionGraph { // save as packed format // @TODO Hardcoded to find packable weights - // int8 - all the weights used for affine op and dot op - // fp16 - all the weights used for affine op + // int8 - quantize decoder only for better quality, all the weights used for affine op and dot op (int8) + // fp16 - all the weights used for affine op (fp16) if ((gemmElementType == Type::packed8avx2 || gemmElementType == Type::packed8avx512) - && (pName.find("_W") == pName.length() - 3 || pName.find("_W") == pName.length() - 2)) { + && (pName.find("_W") == pName.length() - 3 || pName.find("_W") == pName.length() - 2) + && pName.find("encoder") == std::string::npos) { #if USE_FBGEMM using namespace marian::cpu::variant; // packing information - size @@ -84,8 +85,10 @@ class ExpressionGraphPackable : public ExpressionGraph { #else ABORT("Packed type {} only supported when compiled with -DUSE_FBGEMM=on", gemmElementType); #endif - // fp16 quantization option - } else if (gemmElementType == Type::packed16 && pName.find("_W") == pName.length() - 3) { + // fp16 quantization option + encoders for int8 quantized models + } else if ((gemmElementType == Type::packed16 && pName.find("_W") == pName.length() - 3) + || ((gemmElementType == Type::packed8avx2 || gemmElementType == Type::packed8avx512) + && (pName.find("_W") == pName.length() - 3 || pName.find("_W") == pName.length() - 2))) { #if USE_FBGEMM using namespace marian::cpu::variant; @@ -153,4 +156,4 @@ class ExpressionGraphPackable : public ExpressionGraph { } }; -} // namespace marian \ No newline at end of file +} // namespace marian From 128e1fc19afbe93bbe1d80139e2f5853c6e9060f Mon Sep 17 00:00:00 2001 From: Young Jin Kim Date: Fri, 27 Mar 2020 21:44:31 +0000 Subject: [PATCH 30/62] Merged 
PR 12243: For int8 quantized model, use int8 quantization for encoders as well For int8 quantized model, use int8 quantization for encoders as well. The quality difference between fp16 encoder and int8 encoder is small, but they have quite a large speed difference. --- src/tensors/cpu/fbgemm/expression_graph_packable.h | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/tensors/cpu/fbgemm/expression_graph_packable.h b/src/tensors/cpu/fbgemm/expression_graph_packable.h index e45f8274f..cbb459ed5 100644 --- a/src/tensors/cpu/fbgemm/expression_graph_packable.h +++ b/src/tensors/cpu/fbgemm/expression_graph_packable.h @@ -36,11 +36,10 @@ class ExpressionGraphPackable : public ExpressionGraph { // save as packed format // @TODO Hardcoded to find packable weights - // int8 - quantize decoder only for better quality, all the weights used for affine op and dot op (int8) - // fp16 - all the weights used for affine op (fp16) + // int8 - all the weights used for affine op and dot op + // fp16 - all the weights used for affine op if ((gemmElementType == Type::packed8avx2 || gemmElementType == Type::packed8avx512) - && (pName.find("_W") == pName.length() - 3 || pName.find("_W") == pName.length() - 2) - && pName.find("encoder") == std::string::npos) { + && (pName.find("_W") == pName.length() - 3 || pName.find("_W") == pName.length() - 2)) { #if USE_FBGEMM using namespace marian::cpu::variant; // packing information - size @@ -85,10 +84,8 @@ class ExpressionGraphPackable : public ExpressionGraph { #else ABORT("Packed type {} only supported when compiled with -DUSE_FBGEMM=on", gemmElementType); #endif - // fp16 quantization option + encoders for int8 quantized models - } else if ((gemmElementType == Type::packed16 && pName.find("_W") == pName.length() - 3) - || ((gemmElementType == Type::packed8avx2 || gemmElementType == Type::packed8avx512) - && (pName.find("_W") == pName.length() - 3 || pName.find("_W") == pName.length() - 2))) { + // fp16 
quantization option + } else if (gemmElementType == Type::packed16 && pName.find("_W") == pName.length() - 3) { #if USE_FBGEMM using namespace marian::cpu::variant; From 98dff9d26ba0091848f83d91ca25960074977a77 Mon Sep 17 00:00:00 2001 From: Roman Grundkiewicz Date: Fri, 10 Apr 2020 21:01:56 +0100 Subject: [PATCH 31/62] Support tab-separated inputs (#617) * Add basic support for TSV inputs * Fix mini-batch-fit for TSV inputs * Abort if shuffling data from stdin * Fix terminating training with data from STDIN * Allow creating vocabs from TSV files * Add comments; clean creation of vocabs from TSV files * Guess --tsv-size based on the model type * Add shortcut for STDIN inputs * Rename --tsv-size to --tsv-fields * Allow only one 'stdin' in --train-sets * Properly create separate vocabularies from a TSV file * Clearer logging message * Add error message for wrong number of valid sets if --tsv is used * Use --no-shuffle instead of --shuffle in the error message * Fix continuing training from STDIN * Update CHANGELOG * Support both 'stdin' and '-' * Guess --tsv-fields from dim-vocabs if special:model.yml available * Update error messages * Move variable outside the loop * Refactorize utils::splitTsv; add unit tests * Support '-' as stdin; refactorize; add comments * Abort if excessive field(s) in the TSV input * Add a TODO on passing one vocab with fully-tied embeddings * Remove the unit test with excessive tab-separated fields --- CHANGELOG.md | 2 + src/CMakeLists.txt | 1 + src/common/config_parser.cpp | 15 +++++++ src/training/graph_group.h | 83 +++++++++++++++++++++++++++++++++++- 4 files changed, 100 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bdf07d461..5978775e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. 
## [1.9.0] - 2020-03-10 ### Added +- Training and scoring from STDIN +- Support for tab-separated inputs, added ptions --tsv and --tsv-fields - An option to print cached variables from CMake - Add support for compiling on Mac (and clang) - An option for resetting stalled validation metrics diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b78431d6a..79886ed0d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -26,6 +26,7 @@ add_library(marian STATIC common/file_stream.cpp common/file_utils.cpp common/signal_handling.cpp + common/file_utils.cpp common/types.cpp data/alignment.cpp diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index 2f56d8870..9224666f1 100755 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -892,6 +892,21 @@ Ptr ConfigParser::parseOptions(int argc, char** argv, bool doValidate){ cli_.updateConfig(config, cli::OptionPriority::CommandLine, "A shortcut for STDIN failed."); } + // Option shortcuts for input from STDIN for trainer and scorer + if(mode_ == cli::mode::training || mode_ == cli::mode::scoring) { + auto trainSets = get>("train-sets"); + YAML::Node config; + // Assume the input will come from STDIN if --tsv is set but no --train-sets are given + if(get("tsv") && trainSets.empty()) { + config["train-sets"].push_back("stdin"); + // Assume the input is in TSV format if --train-sets is set to "stdin" + } else if(trainSets.size() == 1 && (trainSets[0] == "stdin" || trainSets[0] == "-")) { + config["tsv"] = true; + } + if(!config.IsNull()) + cli_.updateConfig(config, cli::OptionPriority::CommandLine, "A shortcut for STDIN failed."); + } + if(doValidate) { ConfigValidator(config_).validateOptions(mode_); } diff --git a/src/training/graph_group.h b/src/training/graph_group.h index 012f78ef9..83873edab 100644 --- a/src/training/graph_group.h +++ b/src/training/graph_group.h @@ -55,7 +55,88 @@ class GraphGroup { Ptr collectStats(Ptr graph, Ptr model, const std::vector>& vocabs, - double 
multiplier = 1.); + double multiplier = 1.) { + auto stats = New(); + + size_t numFiles = options_->get("tsv", false) + ? options_->get("tsv-fields") + : options_->get>("train-sets").size(); + + // Initialize first batch to step size + size_t first = options_->get("mini-batch-fit-step"); + + // Increase batch size and sentence length by this step size + size_t step = options_->get("mini-batch-fit-step"); + + size_t maxLength = options_->get("max-length"); + maxLength = (size_t)(std::ceil(maxLength / (float)step) * step); + + // this should be only one class label per line on input, hence restricting length to 1 + std::vector localMaxes(numFiles, maxLength); + auto inputTypes = options_->get>("input-types", {}); + for(int i = 0; i < inputTypes.size(); ++i) + if(inputTypes[i] == "class") + localMaxes[i] = 1; + + size_t maxBatch = 512; + bool fits = true; + while(fits) { + std::vector lengths(numFiles, first); + for(int j = 0; j < lengths.size(); ++j) // apply length restrictions + lengths[j] = std::min(lengths[j], localMaxes[j]); + + auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, maxBatch, options_); + auto cost = model->build(graph, batch); + fits = graph->fits(); + if(fits) + maxBatch *= 2; + } + + // Do a binary search for maxmimum batch size that fits into given workspace memory + // for a tested sentence length. 
+ for(size_t i = step; i <= maxLength; i += step) { + size_t start = 1; + size_t end = maxBatch; + + std::vector lengths(numFiles, i); + for(int j = 0; j < lengths.size(); ++j) // apply length restrictions + lengths[j] = std::min(lengths[j], localMaxes[j]); + fits = true; + + do { + size_t current = (start + end) / 2; + auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, current, options_); + auto cost = model->build(graph, batch); + fits = graph->fits(); + + LOG(debug, "[batching] length: {} - size: {} - fits: {}", lengths[0], current, fits); + + if(fits) { + stats->add(batch, multiplier); + start = current + 1; + } else { + end = current - 1; + } + } while(end - start > step); + + maxBatch = start; + } + return stats; + } + + void setTypicalTrgBatchWords(size_t typicalTrgBatchWords) { // needed for dynamic MB scaling + typicalTrgBatchWords_ = typicalTrgBatchWords; + } +}; + +/** + * Base class for multi-node versions of GraphGroups. + */ +class MultiNodeGraphGroupBase : public GraphGroup { + using Base = GraphGroup; + +protected: + Ptr mpi_; // all MPI-like communication goes through this void setTypicalTrgBatchWords(size_t typicalTrgBatchWords); }; From b06531dff96b75f5478e68cedf040f9fa2fc0895 Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Fri, 10 Apr 2020 13:53:21 -0700 Subject: [PATCH 32/62] actually save the merge file --- src/training/graph_group.h | 83 +------------------------------------- 1 file changed, 1 insertion(+), 82 deletions(-) diff --git a/src/training/graph_group.h b/src/training/graph_group.h index 83873edab..012f78ef9 100644 --- a/src/training/graph_group.h +++ b/src/training/graph_group.h @@ -55,88 +55,7 @@ class GraphGroup { Ptr collectStats(Ptr graph, Ptr model, const std::vector>& vocabs, - double multiplier = 1.) { - auto stats = New(); - - size_t numFiles = options_->get("tsv", false) - ? 
options_->get("tsv-fields") - : options_->get>("train-sets").size(); - - // Initialize first batch to step size - size_t first = options_->get("mini-batch-fit-step"); - - // Increase batch size and sentence length by this step size - size_t step = options_->get("mini-batch-fit-step"); - - size_t maxLength = options_->get("max-length"); - maxLength = (size_t)(std::ceil(maxLength / (float)step) * step); - - // this should be only one class label per line on input, hence restricting length to 1 - std::vector localMaxes(numFiles, maxLength); - auto inputTypes = options_->get>("input-types", {}); - for(int i = 0; i < inputTypes.size(); ++i) - if(inputTypes[i] == "class") - localMaxes[i] = 1; - - size_t maxBatch = 512; - bool fits = true; - while(fits) { - std::vector lengths(numFiles, first); - for(int j = 0; j < lengths.size(); ++j) // apply length restrictions - lengths[j] = std::min(lengths[j], localMaxes[j]); - - auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, maxBatch, options_); - auto cost = model->build(graph, batch); - fits = graph->fits(); - if(fits) - maxBatch *= 2; - } - - // Do a binary search for maxmimum batch size that fits into given workspace memory - // for a tested sentence length. 
- for(size_t i = step; i <= maxLength; i += step) { - size_t start = 1; - size_t end = maxBatch; - - std::vector lengths(numFiles, i); - for(int j = 0; j < lengths.size(); ++j) // apply length restrictions - lengths[j] = std::min(lengths[j], localMaxes[j]); - fits = true; - - do { - size_t current = (start + end) / 2; - auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, current, options_); - auto cost = model->build(graph, batch); - fits = graph->fits(); - - LOG(debug, "[batching] length: {} - size: {} - fits: {}", lengths[0], current, fits); - - if(fits) { - stats->add(batch, multiplier); - start = current + 1; - } else { - end = current - 1; - } - } while(end - start > step); - - maxBatch = start; - } - return stats; - } - - void setTypicalTrgBatchWords(size_t typicalTrgBatchWords) { // needed for dynamic MB scaling - typicalTrgBatchWords_ = typicalTrgBatchWords; - } -}; - -/** - * Base class for multi-node versions of GraphGroups. - */ -class MultiNodeGraphGroupBase : public GraphGroup { - using Base = GraphGroup; - -protected: - Ptr mpi_; // all MPI-like communication goes through this + double multiplier = 1.); void setTypicalTrgBatchWords(size_t typicalTrgBatchWords); }; From 5ce67c6f31c58f47b5a0a53534e74d66acd91bee Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Fri, 10 Apr 2020 15:27:34 -0700 Subject: [PATCH 33/62] use float values for catch::Approx --- src/tests/units/attention_tests.cpp | 2 +- src/tests/units/operator_tests.cpp | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/tests/units/attention_tests.cpp b/src/tests/units/attention_tests.cpp index fe11bf2f0..4fbed7b52 100644 --- a/src/tests/units/attention_tests.cpp +++ b/src/tests/units/attention_tests.cpp @@ -23,7 +23,7 @@ void tests(DeviceType type, Type floatType = Type::float32) { } #endif - auto floatApprox = [](T x, T y) { return x == Approx(y).margin(0.001f); }; + auto floatApprox = [](T x, T y) { return x == 
Approx(y).epsilon(0.01f).scale(1.f); }; Config::seed = 1234; diff --git a/src/tests/units/operator_tests.cpp b/src/tests/units/operator_tests.cpp index 581cd05c7..06d862328 100644 --- a/src/tests/units/operator_tests.cpp +++ b/src/tests/units/operator_tests.cpp @@ -22,7 +22,7 @@ void tests(DeviceType device, Type floatType = Type::float32) { } #endif - auto floatApprox = [](T x, T y) -> bool { return x == Approx(y).margin(0.001f); }; + auto floatApprox = [](T x, T y) -> bool { return x == Approx(y).epsilon(0.01f).scale(1.f); }; auto floatEqual = [](T x, T y) -> bool { return x == y; }; Config::seed = 1234; @@ -794,8 +794,8 @@ TEST_CASE("Expression graph supports basic math operations (cpu)", "[operator]") #ifdef CUDA_FOUND TEST_CASE("Compare aggregate operator", "[graph]") { - auto floatApprox = [](float x, float y) -> bool { return x == Approx(y).margin(0.001f); }; - + auto floatApprox = [](float x, float y) -> bool { return x == Approx(y).epsilon(0.01f).scale(1.f); }; + Config::seed = 1234; std::vector initc; @@ -817,7 +817,7 @@ TEST_CASE("Compare aggregate operator", "[graph]") { SECTION("initializing with zero (cpu)") { std::vector values1; std::vector values2; - + auto graph1 = New(); graph1->setDevice({0, DeviceType::cpu}); graph1->reserveWorkspaceMB(40); @@ -825,7 +825,7 @@ TEST_CASE("Compare aggregate operator", "[graph]") { auto graph2 = New(); graph2->setDevice({0, DeviceType::gpu}); graph2->reserveWorkspaceMB(40); - + auto chl1 = graph1->param("1x10x512x2048", {1, 10, 512, 2048}, inits::fromVector(initc)); auto adj1 = graph1->param("1x1x512x2048", {1, 1, 512, 2048}, inits::fromVector(inita)); auto prod1 = scalar_product(chl1, adj1, -1); @@ -844,4 +844,4 @@ TEST_CASE("Compare aggregate operator", "[graph]") { } #endif - #endif \ No newline at end of file + #endif From 709522c788b63cd22fe95d950c90519b331b5e58 Mon Sep 17 00:00:00 2001 From: Roman Grundkiewicz Date: Sat, 11 Apr 2020 16:04:20 +0100 Subject: [PATCH 34/62] Fix TSV training with mini-batch-fit 
after the last merge --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5978775e9..1cd653b6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - In concatenation make sure that we do not multiply 0 with nan (which results in nan) - Change Approx.epsilon(0.01) to Approx.margin(0.001) in unit tests. Tolerance is now absolute and not relative. We assumed incorrectly that epsilon is absolute tolerance. +- Training and scoring from STDIN +- Support for tab-separated inputs, added options --tsv and --tsv-fields ### Changed - Move Simple-WebSocket-Server to submodule @@ -34,8 +36,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [1.9.0] - 2020-03-10 ### Added -- Training and scoring from STDIN -- Support for tab-separated inputs, added ptions --tsv and --tsv-fields - An option to print cached variables from CMake - Add support for compiling on Mac (and clang) - An option for resetting stalled validation metrics From 17167dd9fb455f0fb334c3ca681275cc733f1f88 Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Sat, 11 Apr 2020 09:45:57 -0700 Subject: [PATCH 35/62] Fix 0 * nan behavior due to using -O3 instead of -OFast (#630) * fix 0 * nan behavior in concatention * bump patch * change epsilon to margin --- CHANGELOG.md | 8 +++++++- src/tests/units/attention_tests.cpp | 2 +- src/tests/units/operator_tests.cpp | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1cd653b6e..06329224d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Change Approx.epsilon(0.01) to Approx.margin(0.001) in unit tests. Tolerance is now absolute and not relative. We assumed incorrectly that epsilon is absolute tolerance. 
- Training and scoring from STDIN -- Support for tab-separated inputs, added options --tsv and --tsv-fields +- Support for reading from TSV files from STDIN and other sources during training + and translation with options --tsv and --tsv-fields n. + +### Fixed +- In concatenation make sure that we do not multiply 0 with nan (which results in nan) +- Change Approx.epsilon(0.01) to Approx.margin(0.001) in unit tests. Tolerance is now + absolute and not relative. We assumed incorrectly that epsilon is absolute tolerance. ### Changed - Move Simple-WebSocket-Server to submodule diff --git a/src/tests/units/attention_tests.cpp b/src/tests/units/attention_tests.cpp index 4fbed7b52..fe11bf2f0 100644 --- a/src/tests/units/attention_tests.cpp +++ b/src/tests/units/attention_tests.cpp @@ -23,7 +23,7 @@ void tests(DeviceType type, Type floatType = Type::float32) { } #endif - auto floatApprox = [](T x, T y) { return x == Approx(y).epsilon(0.01f).scale(1.f); }; + auto floatApprox = [](T x, T y) { return x == Approx(y).margin(0.001f); }; Config::seed = 1234; diff --git a/src/tests/units/operator_tests.cpp b/src/tests/units/operator_tests.cpp index 06d862328..58acd4b1f 100644 --- a/src/tests/units/operator_tests.cpp +++ b/src/tests/units/operator_tests.cpp @@ -22,7 +22,7 @@ void tests(DeviceType device, Type floatType = Type::float32) { } #endif - auto floatApprox = [](T x, T y) -> bool { return x == Approx(y).epsilon(0.01f).scale(1.f); }; + auto floatApprox = [](T x, T y) -> bool { return x == Approx(y).margin(0.001f); }; auto floatEqual = [](T x, T y) -> bool { return x == y; }; Config::seed = 1234; From 65c9c449a7179574a5a7a3580327597a540704c3 Mon Sep 17 00:00:00 2001 From: Roman Grundkiewicz Date: Sun, 12 Apr 2020 18:56:11 +0100 Subject: [PATCH 36/62] Support relative paths in shortlist and sqlite options (#612) * Refactorize processPaths * Fix relative paths for shortlist and sqlite options * Rename InterpolateEnvVars to interpolateEnvVars * Update CHANGELOG --- CHANGELOG.md 
| 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 06329224d..d19b43800 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,8 +21,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - In concatenation make sure that we do not multiply 0 with nan (which results in nan) - Change Approx.epsilon(0.01) to Approx.margin(0.001) in unit tests. Tolerance is now absolute and not relative. We assumed incorrectly that epsilon is absolute tolerance. +- Supporting relative paths in shortlist and sqlite options - Training and scoring from STDIN -- Support for reading from TSV files from STDIN and other sources during training +- Support for reading from TSV files from STDIN and other sources during training and translation with options --tsv and --tsv-fields n. ### Fixed From 1312c18a21833dde303dc9fdfd3e638a7249c0e2 Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Mon, 13 Apr 2020 17:31:06 -0700 Subject: [PATCH 37/62] update changelog and version --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d19b43800..45c267cf7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. and translation with options --tsv and --tsv-fields n. ### Fixed +- Make mini-batch scaling depend on mini-batch-words and not on mini-batch-words-ref - In concatenation make sure that we do not multiply 0 with nan (which results in nan) - Change Approx.epsilon(0.01) to Approx.margin(0.001) in unit tests. Tolerance is now absolute and not relative. We assumed incorrectly that epsilon is absolute tolerance. 
From 2f2a00b524b5a98817d0a8d422efd3b776646f96 Mon Sep 17 00:00:00 2001 From: Roman Grundkiewicz Date: Mon, 27 Apr 2020 10:34:10 +0100 Subject: [PATCH 38/62] Update Simple-WebSocket-Server and move it to submodules (#639) * Fix server build with current boost, move simple-websocket-server to submodule * Change submodule to marian-nmt/Simple-WebSocket-Server * Update submodule simple-websocket-server Co-authored-by: Gleb Tv --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 45c267cf7..4ad7f937f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. and translation with options --tsv and --tsv-fields n. ### Fixed +- Fix building server with Boost 1.72 - Make mini-batch scaling depend on mini-batch-words and not on mini-batch-words-ref - In concatenation make sure that we do not multiply 0 with nan (which results in nan) - Change Approx.epsilon(0.01) to Approx.margin(0.001) in unit tests. Tolerance is now From d2d35639c415c915092622cb6b0e01d27a8155d8 Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Fri, 22 May 2020 15:46:13 +0100 Subject: [PATCH 39/62] Post-rebase fixes. --- CHANGELOG.md | 15 ++------------- src/CMakeLists.txt | 1 - src/command/marian_train.cpp | 4 ++-- src/common/config_parser.cpp | 15 --------------- .../cpu/fbgemm/expression_graph_packable.h | 4 ++-- src/tests/units/operator_tests.cpp | 10 +++++----- 6 files changed, 11 insertions(+), 38 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ad7f937f..7926c17ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,17 +21,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - In concatenation make sure that we do not multiply 0 with nan (which results in nan) - Change Approx.epsilon(0.01) to Approx.margin(0.001) in unit tests. Tolerance is now absolute and not relative. 
We assumed incorrectly that epsilon is absolute tolerance. -- Supporting relative paths in shortlist and sqlite options -- Training and scoring from STDIN -- Support for reading from TSV files from STDIN and other sources during training - and translation with options --tsv and --tsv-fields n. - -### Fixed -- Fix building server with Boost 1.72 -- Make mini-batch scaling depend on mini-batch-words and not on mini-batch-words-ref -- In concatenation make sure that we do not multiply 0 with nan (which results in nan) -- Change Approx.epsilon(0.01) to Approx.margin(0.001) in unit tests. Tolerance is now - absolute and not relative. We assumed incorrectly that epsilon is absolute tolerance. +- Improved handling for graceful shutdown upon receiving SIGTERM. + SIGTERM now also interrupts batch prefetching, which runs in a separate thread. ### Changed - Move Simple-WebSocket-Server to submodule @@ -39,8 +30,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Changed compile flags -Ofast to -O3 and remove --ffinite-math - Moved old graph groups to depracated folder - Make cublas and cusparse handle inits lazy to save memory when unused -- Improved handling for graceful shutdown upon receiving SIGTERM. - SIGTERM now also interrupts batch prefetching, which runs in a separate thread. ## [1.9.0] - 2020-03-10 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 79886ed0d..b78431d6a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -26,7 +26,6 @@ add_library(marian STATIC common/file_stream.cpp common/file_utils.cpp common/signal_handling.cpp - common/file_utils.cpp common/types.cpp data/alignment.cpp diff --git a/src/command/marian_train.cpp b/src/command/marian_train.cpp index 46bd05e84..fad6656bb 100644 --- a/src/command/marian_train.cpp +++ b/src/command/marian_train.cpp @@ -48,9 +48,9 @@ int main(int argc, char** argv) { // for bash in http://tldp.org/LDP/abs/html/exitcodes.html. 
This allows parent // scripts to determine if training terminated naturally or via SIGTERM. // Whith this approach we can accommodate additional signals in the future. - // An alternative would be to return 124, which is what the timeout command + // An alternative would be to exit with code 124, which is what the timeout command // returns for timeout -s SIGTERM ...., because exiting after SIGTERM // is not technically a fatal error (which is what the 128+x convention usually // stands for). - return getSignalFlag(SIGTERM) ? (128 + SIGTERM) : 0; + exit(getSignalFlag(SIGTERM) ? (128 + SIGTERM) : EXIT_SUCCESS); } diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index 9224666f1..2f56d8870 100755 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -892,21 +892,6 @@ Ptr ConfigParser::parseOptions(int argc, char** argv, bool doValidate){ cli_.updateConfig(config, cli::OptionPriority::CommandLine, "A shortcut for STDIN failed."); } - // Option shortcuts for input from STDIN for trainer and scorer - if(mode_ == cli::mode::training || mode_ == cli::mode::scoring) { - auto trainSets = get>("train-sets"); - YAML::Node config; - // Assume the input will come from STDIN if --tsv is set but no --train-sets are given - if(get("tsv") && trainSets.empty()) { - config["train-sets"].push_back("stdin"); - // Assume the input is in TSV format if --train-sets is set to "stdin" - } else if(trainSets.size() == 1 && (trainSets[0] == "stdin" || trainSets[0] == "-")) { - config["tsv"] = true; - } - if(!config.IsNull()) - cli_.updateConfig(config, cli::OptionPriority::CommandLine, "A shortcut for STDIN failed."); - } - if(doValidate) { ConfigValidator(config_).validateOptions(mode_); } diff --git a/src/tensors/cpu/fbgemm/expression_graph_packable.h b/src/tensors/cpu/fbgemm/expression_graph_packable.h index cbb459ed5..f5b05c302 100644 --- a/src/tensors/cpu/fbgemm/expression_graph_packable.h +++ b/src/tensors/cpu/fbgemm/expression_graph_packable.h 
@@ -10,7 +10,7 @@ namespace marian { // This requires some more changes, but we temporarily do this just by name ("_W") of the weights. // And, this introduces a low level packed_gemm.h apis interact with high level graph class. // So, we make a subclass of ExpressionGraph and put those immature codes in this class. -// We will improve this in the near future. +// We will improve this in the near future. class ExpressionGraphPackable : public ExpressionGraph { public: ExpressionGraphPackable() @@ -153,4 +153,4 @@ class ExpressionGraphPackable : public ExpressionGraph { } }; -} // namespace marian +} // namespace marian \ No newline at end of file diff --git a/src/tests/units/operator_tests.cpp b/src/tests/units/operator_tests.cpp index 58acd4b1f..581cd05c7 100644 --- a/src/tests/units/operator_tests.cpp +++ b/src/tests/units/operator_tests.cpp @@ -794,8 +794,8 @@ TEST_CASE("Expression graph supports basic math operations (cpu)", "[operator]") #ifdef CUDA_FOUND TEST_CASE("Compare aggregate operator", "[graph]") { - auto floatApprox = [](float x, float y) -> bool { return x == Approx(y).epsilon(0.01f).scale(1.f); }; - + auto floatApprox = [](float x, float y) -> bool { return x == Approx(y).margin(0.001f); }; + Config::seed = 1234; std::vector initc; @@ -817,7 +817,7 @@ TEST_CASE("Compare aggregate operator", "[graph]") { SECTION("initializing with zero (cpu)") { std::vector values1; std::vector values2; - + auto graph1 = New(); graph1->setDevice({0, DeviceType::cpu}); graph1->reserveWorkspaceMB(40); @@ -825,7 +825,7 @@ TEST_CASE("Compare aggregate operator", "[graph]") { auto graph2 = New(); graph2->setDevice({0, DeviceType::gpu}); graph2->reserveWorkspaceMB(40); - + auto chl1 = graph1->param("1x10x512x2048", {1, 10, 512, 2048}, inits::fromVector(initc)); auto adj1 = graph1->param("1x1x512x2048", {1, 1, 512, 2048}, inits::fromVector(inita)); auto prod1 = scalar_product(chl1, adj1, -1); @@ -844,4 +844,4 @@ TEST_CASE("Compare aggregate operator", "[graph]") { } #endif 
- #endif + #endif \ No newline at end of file From 69f8192af84e5cbc031e28e1b0703c6ce0fd736b Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Sun, 26 Jul 2020 09:12:59 +0100 Subject: [PATCH 40/62] Update training.h Insert missing space in line 84, responding to @emjotde's 'nit'. --- src/training/training.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/training/training.h b/src/training/training.h index c68602ec9..9130bc743 100644 --- a/src/training/training.h +++ b/src/training/training.h @@ -81,7 +81,7 @@ class Train : public ModelTask { // shutdown that saves the current state of training before exiting. // This signal handler simply sets a flag that can be checked from // everywhere (getSignalFLAG(SIGTERM); #include common/signal_handling.h) - signal(SIGTERM,setSignalFlag); + signal(SIGTERM, setSignalFlag); // -- main training loop scheduler->started(); From 27cba8fcd309d9d84e603c99fbb785b001c40288 Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Sun, 2 Aug 2020 21:38:36 +0100 Subject: [PATCH 41/62] Fix space after comma. --- src/training/training.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/training/training.h b/src/training/training.h index c68602ec9..9130bc743 100644 --- a/src/training/training.h +++ b/src/training/training.h @@ -81,7 +81,7 @@ class Train : public ModelTask { // shutdown that saves the current state of training before exiting. // This signal handler simply sets a flag that can be checked from // everywhere (getSignalFLAG(SIGTERM); #include common/signal_handling.h) - signal(SIGTERM,setSignalFlag); + signal(SIGTERM, setSignalFlag); // -- main training loop scheduler->started(); From 286b980155d7986db8df40779a21071e85e3154b Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Tue, 4 Aug 2020 00:07:22 +0100 Subject: [PATCH 42/62] Update batch_generator.h Return empty deque after SIGTERM. This makes no practical difference whatsoever but was requested by the code reviewers. 
--- src/data/batch_generator.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/data/batch_generator.h b/src/data/batch_generator.h index 7801073f4..bc92c3513 100644 --- a/src/data/batch_generator.h +++ b/src/data/batch_generator.h @@ -139,7 +139,7 @@ class BatchGenerator : public RNGEngine { size_t sets = 0; while(current_ != data_->end() && maxiBatch->size() < maxSize) { // loop over data if (getSignalFlag(SIGTERM)) { // received SIGTERM, abandon ship ... - return tempBatches; + return std::deque(); } maxiBatch->push(*current_); sets = current_->size(); @@ -167,7 +167,7 @@ class BatchGenerator : public RNGEngine { while(!maxiBatch->empty()) { // while there are sentences in the queue if (getSignalFlag(SIGTERM)) { // received SIGTERM, abandon ship ... - return tempBatches; + return std::deque(); } // push item onto batch From 0a406d867288d44945f47403009bb2a517123925 Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Tue, 4 Aug 2020 00:41:30 +0100 Subject: [PATCH 43/62] Update batch_generator.h Frank doesn't like curly braces. --- src/data/batch_generator.h | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/data/batch_generator.h b/src/data/batch_generator.h index bc92c3513..f3248e66e 100644 --- a/src/data/batch_generator.h +++ b/src/data/batch_generator.h @@ -138,9 +138,8 @@ class BatchGenerator : public RNGEngine { size_t sets = 0; while(current_ != data_->end() && maxiBatch->size() < maxSize) { // loop over data - if (getSignalFlag(SIGTERM)) { // received SIGTERM, abandon ship ... + if (getSignalFlag(SIGTERM)) // received SIGTERM, abandon ship ... 
return std::deque(); - } maxiBatch->push(*current_); sets = current_->size(); // do not consume more than required for the maxi batch as this causes @@ -165,11 +164,8 @@ class BatchGenerator : public RNGEngine { cachedStatsIter = stats_->begin(); while(!maxiBatch->empty()) { // while there are sentences in the queue - - if (getSignalFlag(SIGTERM)) { // received SIGTERM, abandon ship ... + if (getSignalFlag(SIGTERM)) // received SIGTERM, abandon ship ... return std::deque(); - } - // push item onto batch batchVector.push_back(maxiBatch->top()); maxiBatch->pop(); // fetch next-shortest From 521560b4b264c475fc1f12278afae173b4a33030 Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Tue, 4 Aug 2020 01:58:19 +0100 Subject: [PATCH 44/62] Update signal_handling.h Edited comments. --- src/common/signal_handling.h | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/common/signal_handling.h b/src/common/signal_handling.h index 0de7db1f8..69e6588e5 100644 --- a/src/common/signal_handling.h +++ b/src/common/signal_handling.h @@ -3,16 +3,16 @@ // SIGNAL HANDLING -// The Marian signal handlers set global flags that thread can +// The Marian signal handlers set global flags that threads can // consider when a signal is received. This can be used for a graceful // shutdown instead of a hard abandonment, e.g. after receiving // SIGTERM during training. -// When SIGTERM is received, the global (static member) flag sigterm_ -// (false by default) is set to true by signalHandler(). When sigterm_ -// is true, keepGoing() returns false, and the current state of -// training models is saved prior to exiting. This functionality is -// helpful when training on clusters with time limits on compute +// When SIGTERM is received, the global bit flag for SIGTERM +// (false by default) is set to true by setSignalFlag. 
Threads involved in training +// (batch generator, scheduler) periodicly check this flag and gracefully exit; +// the current state of training models is then saved prior to exiting the program. +// This functionality is helpful when training on clusters with time limits on compute // slots, e.g., on s clusters managed by slurm. Slurm can be asked to // sending a (custom) warning signal to a process at a given point in // time prior to the hard "time's up". @@ -20,8 +20,7 @@ // Correspondingly, fetchBatches in the batch generator checks the flag // frequently and quits after the overall process receives a SIGTERM. - namespace marian { bool getSignalFlag(int sig); // return true if sig was received, false otherwise -void setSignalFlag(int sig); // set custom handler (set flag) for sig +void setSignalFlag(int sig); // custom handler (set flag) for sig } From d4102cbfe4ff3676bbaaf3cc3f4d670c3c835413 Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Wed, 5 Aug 2020 16:33:30 +0100 Subject: [PATCH 45/62] Update comments in signal_handling.h --- src/common/signal_handling.h | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/common/signal_handling.h b/src/common/signal_handling.h index 69e6588e5..46cad49c2 100644 --- a/src/common/signal_handling.h +++ b/src/common/signal_handling.h @@ -3,24 +3,24 @@ // SIGNAL HANDLING -// The Marian signal handlers set global flags that threads can -// consider when a signal is received. This can be used for a graceful -// shutdown instead of a hard abandonment, e.g. after receiving -// SIGTERM during training. - -// When SIGTERM is received, the global bit flag for SIGTERM -// (false by default) is set to true by setSignalFlag. Threads involved in training -// (batch generator, scheduler) periodicly check this flag and gracefully exit; -// the current state of training models is then saved prior to exiting the program. 
-// This functionality is helpful when training on clusters with time limits on compute -// slots, e.g., on s clusters managed by slurm. Slurm can be asked to -// sending a (custom) warning signal to a process at a given point in -// time prior to the hard "time's up". -// -// Correspondingly, fetchBatches in the batch generator checks the flag -// frequently and quits after the overall process receives a SIGTERM. +// The Marian signal handler setSignalFlag is a general purpose signal handler +// that sets a global flag upon receiving a signal (with SIGNAL No. < 32) in line +// with the recommendations for signal handling in the SEI CERT C Coding Standard, specifically +// - SIG30-C: https://wiki.sei.cmu.edu/confluence/display/c/SIG30-C.+Call+only+asynchronous-safe+functions+within+signal+handlers +// - SIG31-C: https://wiki.sei.cmu.edu/confluence/display/c/SIG31-C.+Do+not+access+shared+objects+in+signal+handlers +// Usage: +// - install the signal handler for a specific signal with signal(SIGNAL, setSignalFlag), +// e.g. signal(SIGTERM, setSignalFlag) +// - check the flag wherever appropriate with getSignalFlag(SIGNAL), +// e.g. getSignalFlag(SIGTERM) +// +// This mechanism is currently used in marian training to ensure a graceful shutdown after receiving +// SIGTERM, saving the current state of training before exiting. This behavior is particularly desirable +// when training on clusters with time limits on computeslots, e.g., on certain clusters managed by slurm. +// Slurm can be asked to send a (custom) warning signal to a process at a certain time priopr to the +// hard end of the time slot. 
namespace marian { bool getSignalFlag(int sig); // return true if sig was received, false otherwise void setSignalFlag(int sig); // custom handler (set flag) for sig -} +} // end of namespace marian From 5d80ab445402ce39ecd983c68c0619e4295ed663 Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Wed, 19 Aug 2020 13:39:11 +0100 Subject: [PATCH 46/62] Configurable signal handlers for SIGTERM - Adds custom signal handler option for SIGTERM (--sigterm) to marian training. - default behavior: 'graceful' => graceful exit (save model, then exit) - alternative: 'immediate' => exit immediately without saving model --- src/command/marian_train.cpp | 8 ++---- src/common/config_parser.cpp | 12 +++++++- src/common/signal_handling.cpp | 14 +++++++++- src/common/signal_handling.h | 51 +++++++++++++++++++++------------- src/data/batch_generator.h | 11 ++++---- src/training/scheduler.h | 12 ++++---- src/training/training.h | 21 ++++++++++---- 7 files changed, 86 insertions(+), 43 deletions(-) diff --git a/src/command/marian_train.cpp b/src/command/marian_train.cpp index fad6656bb..f9c6492d9 100644 --- a/src/command/marian_train.cpp +++ b/src/command/marian_train.cpp @@ -43,14 +43,12 @@ int main(int argc, char** argv) { New>(options)->run(); } } - - // If we exit due to SIGTERM, exit with 128 + the signal number, as suggested - // for bash in http://tldp.org/LDP/abs/html/exitcodes.html. This allows parent + // If we exit due to a graceful exit request via SIGTERM, exit with 128 + SIGTERM, + // as suggested for bash in http://tldp.org/LDP/abs/html/exitcodes.html. This allows parent // scripts to determine if training terminated naturally or via SIGTERM. - // Whith this approach we can accommodate additional signals in the future. // An alternative would be to exit with code 124, which is what the timeout command // returns for timeout -s SIGTERM ...., because exiting after SIGTERM // is not technically a fatal error (which is what the 128+x convention usually // stands for). 
- exit(getSignalFlag(SIGTERM) ? (128 + SIGTERM) : EXIT_SUCCESS); + exit(getSignalFlag(SIGTERM) ? 128 + SIGTERM : EXIT_SUCCESS); } diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index f2ac41bca..1d6493586 100755 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -318,13 +318,23 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) { "Dropout for transformer attention (0 = no dropout)"); cli.add("--transformer-dropout-ffn", "Dropout for transformer filter (0 = no dropout)"); + } cli.switchGroup(previous_group); // clang-format on } void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) { - auto previous_group = cli.switchGroup("Training options"); + auto previous_group = cli.switchGroup("Signal Handling"); + // --sigterm is deliberately not a boolean, to allow for a consistent + // pattern of specifying custom signal handling in the future. + // (e.g., dump model but continue training upon SIGUSR1, or report current + // training status upon SIGINFO.) + cli.add("--sigterm", + "What to do with SIGTERM: 'graceful' => save and exit (default); " + "'immediate' => exit immediately.", "graceful"); + + cli.switchGroup("Training options"); // clang-format off cli.add("--cost-type", // @TODO: rename to loss-type "Optimization criterion: ce-mean, ce-mean-words, ce-sum, perplexity", "ce-mean"); diff --git a/src/common/signal_handling.cpp b/src/common/signal_handling.cpp index a18d1e669..889162190 100644 --- a/src/common/signal_handling.cpp +++ b/src/common/signal_handling.cpp @@ -7,15 +7,27 @@ namespace marian{ volatile std::sig_atomic_t sigflags_{0}; +volatile std::sig_atomic_t gracefulExitRequested_{0}; bool getSignalFlag(const int sig) { // sig_atomic_t has 32 bits. We don't accommodate signals beyond that. 
- ABORT_IF(sig >= 32, "Signal {} out of range (must be < 32).", sig); + ABORT_IF(sig >= 32, "Signal out of range (must be < 32, is {}).", sig); return sigflags_ & (1< +#include // SIGNAL HANDLING -// The Marian signal handler setSignalFlag is a general purpose signal handler -// that sets a global flag upon receiving a signal (with SIGNAL No. < 32) in line -// with the recommendations for signal handling in the SEI CERT C Coding Standard, specifically -// - SIG30-C: https://wiki.sei.cmu.edu/confluence/display/c/SIG30-C.+Call+only+asynchronous-safe+functions+within+signal+handlers -// - SIG31-C: https://wiki.sei.cmu.edu/confluence/display/c/SIG31-C.+Do+not+access+shared+objects+in+signal+handlers -// Usage: -// - install the signal handler for a specific signal with signal(SIGNAL, setSignalFlag), -// e.g. signal(SIGTERM, setSignalFlag) -// - check the flag wherever appropriate with getSignalFlag(SIGNAL), -// e.g. getSignalFlag(SIGTERM) -// -// This mechanism is currently used in marian training to ensure a graceful shutdown after receiving -// SIGTERM, saving the current state of training before exiting. This behavior is particularly desirable -// when training on clusters with time limits on computeslots, e.g., on certain clusters managed by slurm. -// Slurm can be asked to send a (custom) warning signal to a process at a certain time priopr to the -// hard end of the time slot. 
+// The signal handlers (and checkers) here are implemented in line with with the recommendations +// for signal handling in the SEI CERT C Coding Standard, specifically +// +// - SIG30-C: +// https://wiki.sei.cmu.edu/confluence/display/c/SIG30-C.+Call+only+asynchronous-safe+functions+within+signal+handlers +// +// - SIG31-C: +// https://wiki.sei.cmu.edu/confluence/display/c/SIG31-C.+Do+not+access+shared+objects+in+signal+handlers +// +// The exact behavior of 'graceful exit' depends on the application; for training, it means 'save model and exit', +// for a server (not implemented yet): 'block new requests but serve pending requests and then exit'. +// +// Graceful exit for training is useful for training on clusters with time limits on jobs. Slurm, for example, can be +// set up to send a custom signal at a set time before the end of the time slot, giving Marian time to save its current +// state before getting killed. namespace marian { -bool getSignalFlag(int sig); // return true if sig was received, false otherwise -void setSignalFlag(int sig); // custom handler (set flag) for sig -} // end of namespace marian + + +/// Request graceful exit (signal handler) +void requestGracefulExit(const int sig); + +/// Check if graceful exit was requested. +bool gracefulExitRequested(); + +/// General purpose signal handler that simply sets a flag when a signal is received. +// (only for SIGNAL No. < 32). 
+void setSignalFlag(const int sig); // custom handler (set flag) for sig + +/// Check if a setSignalFlag was triggered for this signal +bool getSignalFlag(const int sig); + +} // End of namespace marian diff --git a/src/data/batch_generator.h b/src/data/batch_generator.h index f3248e66e..fa296f1a4 100644 --- a/src/data/batch_generator.h +++ b/src/data/batch_generator.h @@ -138,8 +138,8 @@ class BatchGenerator : public RNGEngine { size_t sets = 0; while(current_ != data_->end() && maxiBatch->size() < maxSize) { // loop over data - if (getSignalFlag(SIGTERM)) // received SIGTERM, abandon ship ... - return std::deque(); + if (gracefulExitRequested()) // stop generating batches + return std::deque(); maxiBatch->push(*current_); sets = current_->size(); // do not consume more than required for the maxi batch as this causes @@ -164,7 +164,7 @@ class BatchGenerator : public RNGEngine { cachedStatsIter = stats_->begin(); while(!maxiBatch->empty()) { // while there are sentences in the queue - if (getSignalFlag(SIGTERM)) // received SIGTERM, abandon ship ... + if (gracefulExitRequested()) // stop generating batches return std::deque(); // push item onto batch batchVector.push_back(maxiBatch->top()); @@ -251,10 +251,9 @@ class BatchGenerator : public RNGEngine { "If you have changed the training corpus, add --no-restore-corpus to the training command and run it again."); bufferedBatches_ = std::move(futureBufferedBatches_.get()); - if (bufferedBatches_.empty() // i.e., end of Epoch - || getSignalFlag(SIGTERM)) { // process received SIGTERM, abandon ship ... 
+ // stop generating batches at end of epoch or upon graceful exit request: + if (bufferedBatches_.empty() || gracefulExitRequested()) return nullptr; - } fetchBatchesAsync(); // pre-fetch next slew of batches in separate thread } auto batch = bufferedBatches_.front(); diff --git a/src/training/scheduler.h b/src/training/scheduler.h index 2ec9f1ab0..47cc44853 100755 --- a/src/training/scheduler.h +++ b/src/training/scheduler.h @@ -155,7 +155,7 @@ class Scheduler : public TrainingObserver { } bool keepGoing() { - if(getSignalFlag(SIGTERM)) // received signal SIGERM => exit gracefully + if(gracefulExitRequested()) // via SIGTERM return false; // stop if it reached the maximum number of epochs @@ -189,13 +189,12 @@ class Scheduler : public TrainingObserver { void started() { LOG(info, "Training started"); } void finished() { - if (getSignalFlag(SIGTERM)) - LOG(info, "Training interrupted (SIGTERM)."); + if (gracefulExitRequested()) + LOG(info, "Training interrupted (via signal)."); else LOG(info, "Training finished"); } - void addValidator(Ptr validator) { validators_.push_back(validator); @@ -221,8 +220,9 @@ class Scheduler : public TrainingObserver { void validate(const std::vector>& graphs, bool isFinal = false) { // Do not validate if already validated (for instance, after the model is - // loaded) or if validation is scheduled for another update, or when signal SIGTERM was received - if(getSignalFlag(SIGTERM) // SIGTERM was received + // loaded) or if validation is scheduled for another update, or when a + // graceful shutdown was requested via --sig{term|usr1|usr2}. 
+ if(gracefulExitRequested() // signal requesting graceful exit (save model and exit) was received || state_->validated // already validated (in resumed training, for example) || (!state_->enteredNewPeriodOf(options_->get("valid-freq")) && !isFinal)) // not now return; diff --git a/src/training/training.h b/src/training/training.h index 9130bc743..ceb5bb982 100644 --- a/src/training/training.h +++ b/src/training/training.h @@ -16,6 +16,7 @@ template class Train : public ModelTask { private: Ptr options_; + void installCustomSignalHandlers(); public: Train(Ptr options) : options_(options) {} @@ -77,11 +78,8 @@ class Train : public ModelTask { bool restored = !options_->get("no-restore-corpus") && batchGenerator->restore(trainState); - // Install custom handler for SIGTERM, to allow for a graceful - // shutdown that saves the current state of training before exiting. - // This signal handler simply sets a flag that can be checked from - // everywhere (getSignalFLAG(SIGTERM); #include common/signal_handling.h) - signal(SIGTERM, setSignalFlag); + // We only want custom behavior once training starts. + installCustomSignalHandlers(); // -- main training loop scheduler->started(); @@ -113,4 +111,17 @@ class Train : public ModelTask { finalizeMPI(std::move(mpi)); } }; + +template +void Train::installCustomSignalHandlers() +{ + const std::string sigTermAction = options_->get("sigterm"); + if (sigTermAction == "graceful") { + LOG(debug, "Enabling graceful shutdown for SIGTERM."); + signal(SIGTERM, requestGracefulExit); + } + else if (sigTermAction != "immediate") + ABORT("Unrecognized value '{}' for --sigterm", sigTermAction); +} + } // namespace marian From bdda7bde007e10c86f8995b5fb0c39d73fe8c9f9 Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Sun, 23 Aug 2020 23:33:50 +0100 Subject: [PATCH 47/62] Updated CHANGELOG.md. 
--- CHANGELOG.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 945f87200..a17acce3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,8 +39,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Properly record cmake variables in the cmake build directory instead of the source tree. - Added default "none" for option shuffle in BatchGenerator, so that it works in executables where shuffle is not an option. - Added a few missing header files in shortlist.h and beam_search.h. -- Improved handling for graceful shutdown upon receiving SIGTERM. - SIGTERM now also interrupts batch prefetching, which runs in a separate thread. +- Improved handling for graceful shutdown upon receiving SIGTERM. SIGTERM now also interrupts batch prefetching, which runs in a separate thread. Graceful shutdown can be disabled with --sigterm 'immediate'. ### Changed - Move Simple-WebSocket-Server to submodule From 3c1656df74db0e6c5dfa3f309dfeeb79175f22bc Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Sun, 23 Aug 2020 23:41:07 +0100 Subject: [PATCH 48/62] Fixed comment in scheduler.h. --- src/training/scheduler.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/training/scheduler.h b/src/training/scheduler.h index 47cc44853..2b5460a20 100755 --- a/src/training/scheduler.h +++ b/src/training/scheduler.h @@ -219,10 +219,10 @@ class Scheduler : public TrainingObserver { void validate(const std::vector>& graphs, bool isFinal = false) { - // Do not validate if already validated (for instance, after the model is - // loaded) or if validation is scheduled for another update, or when a - // graceful shutdown was requested via --sig{term|usr1|usr2}. 
- if(gracefulExitRequested() // signal requesting graceful exit (save model and exit) was received + // Do not validate if already validated (for instance, after the model is loaded) + // or if validation is scheduled for another update, or when a graceful shutdown + // was requested. + if(gracefulExitRequested() || state_->validated // already validated (in resumed training, for example) || (!state_->enteredNewPeriodOf(options_->get("valid-freq")) && !isFinal)) // not now return; From 30cf713fcd7b909a4f097754fd944fc9298c6135 Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Mon, 24 Aug 2020 00:30:16 +0100 Subject: [PATCH 49/62] Squashed commit of the following: commit 3c1656df74db0e6c5dfa3f309dfeeb79175f22bc Author: Ulrich Germann Date: Sun Aug 23 23:41:07 2020 +0100 Fixed comment in scheduler.h. commit bdda7bde007e10c86f8995b5fb0c39d73fe8c9f9 Author: Ulrich Germann Date: Sun Aug 23 23:33:50 2020 +0100 Updated CHANGELOG.md. commit 3bcd5f85b5011fc1239a1950c73617523973eb7c Merge: 5d80ab4 a21e48f Author: Ulrich Germann Date: Sun Aug 23 23:29:44 2020 +0100 Merge remote-tracking branch 'origin' into ug-graceful-shutdown commit 5d80ab445402ce39ecd983c68c0619e4295ed663 Author: Ulrich Germann Date: Wed Aug 19 13:39:11 2020 +0100 Configurable signal handlers for SIGTERM - Adds custom signal handler option for SIGTERM (--sigterm) to marian training. 
- default behavior: 'graceful' => graceful exit (save model, then exit) - alternative: 'immediate' => exit immediately without saving model commit 9ab0be51c6cbfead5d2889a859e30decf08ac534 Merge: 0a0b83b d4102cb Author: Ulrich Germann Date: Tue Aug 18 19:39:51 2020 +0100 Merge branch 'ug-graceful-shutdown' of https://github.com/marian-nmt/marian-dev into ug-graceful-shutdown commit 0a0b83b4fb987215096329aae919e77609208d40 Merge: 75459e3 3aed914 Author: Ulrich Germann Date: Tue Aug 18 18:12:17 2020 +0100 Merge branch 'master' into ug-graceful-shutdown commit d4102cbfe4ff3676bbaaf3cc3f4d670c3c835413 Author: Ulrich Germann Date: Wed Aug 5 16:33:30 2020 +0100 Update comments in signal_handling.h commit 521560b4b264c475fc1f12278afae173b4a33030 Author: Ulrich Germann Date: Tue Aug 4 01:58:19 2020 +0100 Update signal_handling.h Edited comments. commit 0a406d867288d44945f47403009bb2a517123925 Author: Ulrich Germann Date: Tue Aug 4 00:41:30 2020 +0100 Update batch_generator.h Frank doesn't like curly braces. commit 286b980155d7986db8df40779a21071e85e3154b Author: Ulrich Germann Date: Tue Aug 4 00:07:22 2020 +0100 Update batch_generator.h Return empty deque after SIGTERM. This makes no practical difference whatsoever but was requested by the code reviewers. commit 75459e373d36deee4131ed505182a3352b325937 Merge: 27cba8f 3dc0795 Author: Ulrich Germann Date: Sun Aug 2 21:41:52 2020 +0100 Merge branch 'ug-graceful-shutdown' of https://github.com/marian-nmt/marian-dev into ug-graceful-shutdown commit 27cba8fcd309d9d84e603c99fbb785b001c40288 Author: Ulrich Germann Date: Sun Aug 2 21:38:36 2020 +0100 Fix space after comma. 
commit fdc519a23c6f82d046ebf95ecbee8690bc8bd9c7 Merge: f2d9f1e c944633 Author: Ulrich Germann Date: Sun Aug 2 21:29:09 2020 +0100 Merge branch 'master' into ug-graceful-shutdown commit 3dc0795230cb0abf83e705b0da203b33312209df Merge: 9a7e3a0 4475787 Author: Roman Grundkiewicz Date: Sun Jul 26 15:01:33 2020 +0100 Merge branch 'master' into ug-graceful-shutdown commit 9a7e3a05189df3cb6efb82dee307c2d5f6b94a55 Merge: 69f8192 b28905a Author: Ulrich Germann Date: Sun Jul 26 09:59:18 2020 +0100 Merge branch 'master' into ug-graceful-shutdown commit 69f8192af84e5cbc031e28e1b0703c6ce0fd736b Author: Ulrich Germann Date: Sun Jul 26 09:12:59 2020 +0100 Update training.h Insert missing space in line 84, responding to @emjotde's 'nit'. commit f2d9f1e7dacf26260e109e439e0f469887bac4dc Merge: d2d3563 ae1dd47 Author: Ulrich Germann Date: Wed Jun 17 16:02:08 2020 +0100 Merge branch 'ug-graceful-shutdown' of https://github.com/marian-nmt/marian-dev into ug-graceful-shutdown commit d2d35639c415c915092622cb6b0e01d27a8155d8 Author: Ulrich Germann Date: Fri May 22 15:46:13 2020 +0100 Post-rebase fixes. 
commit 2f2a00b524b5a98817d0a8d422efd3b776646f96 Author: Roman Grundkiewicz Date: Mon Apr 27 10:34:10 2020 +0100 Update Simple-WebSocket-Server and move it to submodules (#639) * Fix server build with current boost, move simple-websocket-server to submodule * Change submodule to marian-nmt/Simple-WebSocket-Server * Update submodule simple-websocket-server Co-authored-by: Gleb Tv commit 1312c18a21833dde303dc9fdfd3e638a7249c0e2 Author: Marcin Junczys-Dowmunt Date: Mon Apr 13 17:31:06 2020 -0700 update changelog and version commit 65c9c449a7179574a5a7a3580327597a540704c3 Author: Roman Grundkiewicz Date: Sun Apr 12 18:56:11 2020 +0100 Support relative paths in shortlist and sqlite options (#612) * Refactorize processPaths * Fix relative paths for shortlist and sqlite options * Rename InterpolateEnvVars to interpolateEnvVars * Update CHANGELOG commit 17167dd9fb455f0fb334c3ca681275cc733f1f88 Author: Marcin Junczys-Dowmunt Date: Sat Apr 11 09:45:57 2020 -0700 Fix 0 * nan behavior due to using -O3 instead of -OFast (#630) * fix 0 * nan behavior in concatention * bump patch * change epsilon to margin commit 709522c788b63cd22fe95d950c90519b331b5e58 Author: Roman Grundkiewicz Date: Sat Apr 11 16:04:20 2020 +0100 Fix TSV training with mini-batch-fit after the last merge commit 5ce67c6f31c58f47b5a0a53534e74d66acd91bee Author: Marcin Junczys-Dowmunt Date: Fri Apr 10 15:27:34 2020 -0700 use float values for catch::Approx commit b06531dff96b75f5478e68cedf040f9fa2fc0895 Author: Marcin Junczys-Dowmunt Date: Fri Apr 10 13:53:21 2020 -0700 actually save the merge file commit 98dff9d26ba0091848f83d91ca25960074977a77 Author: Roman Grundkiewicz Date: Fri Apr 10 21:01:56 2020 +0100 Support tab-separated inputs (#617) * Add basic support for TSV inputs * Fix mini-batch-fit for TSV inputs * Abort if shuffling data from stdin * Fix terminating training with data from STDIN * Allow creating vocabs from TSV files * Add comments; clean creation of vocabs from TSV files * Guess --tsv-size based 
on the model type * Add shortcut for STDIN inputs * Rename --tsv-size to --tsv-fields * Allow only one 'stdin' in --train-sets * Properly create separate vocabularies from a TSV file * Clearer logging message * Add error message for wrong number of valid sets if --tsv is used * Use --no-shuffle instead of --shuffle in the error message * Fix continuing training from STDIN * Update CHANGELOG * Support both 'stdin' and '-' * Guess --tsv-fields from dim-vocabs if special:model.yml available * Update error messages * Move variable outside the loop * Refactorize utils::splitTsv; add unit tests * Support '-' as stdin; refactorize; add comments * Abort if excessive field(s) in the TSV input * Add a TODO on passing one vocab with fully-tied embeddings * Remove the unit test with excessive tab-separated fields commit 128e1fc19afbe93bbe1d80139e2f5853c6e9060f Author: Young Jin Kim Date: Fri Mar 27 21:44:31 2020 +0000 Merged PR 12243: For int8 quantized model, use int8 quantization for encoders as well For int8 quantized model, use int8 quantization for encoders as well. The quality difference between fp16 encoder and int8 encoder is small, but they have quite amount of speed difference. commit 63006db5ac9359d3074a1b5753765a0417081950 Author: Young Jin Kim Date: Wed Mar 25 02:52:17 2020 +0000 Merged PR 11831: Change the weight matrix quantization to use 7-bit min/max quantization to avoid overflow 1. Change the weight matrix quantization to use 7-bit min/max quantization -> This resolves all the overflow issue, because weight and activations are quantized by min/max range. 2. Clip fp16 quantization to avoid overflow 3. Fix windows build errors (cmake options, vcproj file) 4. int8 pack model (encoder -> fp16) commit 9cd162307456d66a235d4e314492444b5849b96d Author: Ulrich Germann Date: Mon Apr 6 12:49:23 2020 +0100 Bug fix: better handling of SIGTERM for graceful shutdown during training. 
Prior to this bug fix, BatchGenerator::fetchBatches, which runs in a separate thread, would ignore SIGTERM during training (training uses a custom signal handler for SIGTERM, which simply sets a global flag, to enable graceful shutdown (i.e., save models and current state of training before shutting down). The changes in this commit also facilitate custom handling of other signals in the future by providing a general singal handler for all signals with a signal number below 32 (setSignalFlag) and a generic flag checking function (getSignalFlag(sig)) for checking such flags. commit ae1dd47878e9de406a5823c8691e9fa4a56a495a Author: Roman Grundkiewicz Date: Sun May 17 11:34:18 2020 +0100 Update submodule regression-tests commit 1603d2fe2a653fadd5342d009f3f71d550ee8a60 Author: Marcin Junczys-Dowmunt Date: Thu May 14 08:00:41 2020 -0700 update version commit 9ae1951fe2bf312fb573b2b18b6f2167de242e0c Author: Nikolay Bogoychev Date: Thu May 14 15:55:27 2020 +0100 Batched gemm (#633) * Use cblas_sgemm_batch when available * Merge with master, add comments and describe contribution commit 3f7b459d18e5fdc12e44122bef9b8807ec0554ac Author: Roman Grundkiewicz Date: Mon Apr 27 10:34:10 2020 +0100 Update Simple-WebSocket-Server and move it to submodules (#639) * Fix server build with current boost, move simple-websocket-server to submodule * Change submodule to marian-nmt/Simple-WebSocket-Server * Update submodule simple-websocket-server Co-authored-by: Gleb Tv commit 342db58b7f25430c21563d414180f8620123cb59 Author: Roman Grundkiewicz Date: Sun Apr 26 16:43:36 2020 +0100 Update submodule regression-tests commit 59dad14ed1a1657b4d1cda9756aa08ae2bea70e5 Author: Kenneth Heafield Date: Thu Apr 16 11:15:42 2020 +0100 python3 shebang from #620 (#621) * python3 shebang from #620 * Add changelog entry for python3 change commit ce94fe989243d7aa1d5f445e1181ea2fcbc7c7a2 Author: Marcin Junczys-Dowmunt Date: Mon Apr 13 17:31:06 2020 -0700 update changelog and version commit 
bc8b6fa162b0840387e195cc3073680bbf854862 Author: Martin Junczys-Dowmunt Date: Tue Apr 14 00:28:44 2020 +0000 Merged PR 12442: cherry pick a few improvements/fixes from Frank's branch Cherry pick a few improvements/fixes from Frank's branch * Adds Frank's fix for label-based mini-batch sizing from Frank's current experimental branch. * Also copies minor improvements and a few comments. commit 34bc47cd3df7ae74013604abcbd4dea5017fe261 Author: Roman Grundkiewicz Date: Sun Apr 12 19:14:03 2020 +0100 Dump version commit 7bf486ad61232b7d0294f8d8a12eca72547e0e97 Author: Roman Grundkiewicz Date: Sun Apr 12 18:58:33 2020 +0100 Fix Iris example on CPU (#623) commit 733cb505bc7353635ee02fdddc7eb9b6465d976b Author: Roman Grundkiewicz Date: Sun Apr 12 18:56:11 2020 +0100 Support relative paths in shortlist and sqlite options (#612) * Refactorize processPaths * Fix relative paths for shortlist and sqlite options * Rename InterpolateEnvVars to interpolateEnvVars * Update CHANGELOG commit 93a27dcdd25d7da126ebda5290c579be7bf68974 Author: Roman Grundkiewicz Date: Sat Apr 11 18:47:17 2020 +0100 Update submodule regression-tests commit 0ba438c463b32831eeae3901b28da0d0cc5bf146 Author: Marcin Junczys-Dowmunt Date: Sat Apr 11 09:45:57 2020 -0700 Fix 0 * nan behavior due to using -O3 instead of -OFast (#630) * fix 0 * nan behavior in concatention * bump patch * change epsilon to margin commit c18fc71e8cb3a7e8bed1c5a84b66fd22d2b6843e Author: Marcin Junczys-Dowmunt Date: Sat Apr 11 09:23:56 2020 -0700 fix 0 * nan behavior in concatention commit 855c94a55daa547041a9bc6dfbe9667022aa5ec5 Author: Roman Grundkiewicz Date: Sat Apr 11 16:06:34 2020 +0100 Update submodule regression-tests commit 4d12ffa96c9335e715a39cfa9017d48d2b8a39a3 Author: Roman Grundkiewicz Date: Sat Apr 11 16:04:20 2020 +0100 Fix TSV training with mini-batch-fit after the last merge commit 09904e0f023c7b4c7334655dd7990e60a5c140f7 Author: Marcin Junczys-Dowmunt Date: Fri Apr 10 15:27:34 2020 -0700 use float values for 
catch::Approx commit 71cc43a2ff19f15f9cfe5260536881321aba036c Author: Marcin Junczys-Dowmunt Date: Fri Apr 10 13:53:21 2020 -0700 actually save the merge file commit c95676e081da4d488560b38fb0d01bba47272b66 Author: Marcin Junczys-Dowmunt Date: Fri Apr 10 13:50:22 2020 -0700 bump version commit 71e0f0b33fd60cf6a12df148f0675d153e41fd23 Author: Roman Grundkiewicz Date: Fri Apr 10 21:01:56 2020 +0100 Support tab-separated inputs (#617) * Add basic support for TSV inputs * Fix mini-batch-fit for TSV inputs * Abort if shuffling data from stdin * Fix terminating training with data from STDIN * Allow creating vocabs from TSV files * Add comments; clean creation of vocabs from TSV files * Guess --tsv-size based on the model type * Add shortcut for STDIN inputs * Rename --tsv-size to --tsv-fields * Allow only one 'stdin' in --train-sets * Properly create separate vocabularies from a TSV file * Clearer logging message * Add error message for wrong number of valid sets if --tsv is used * Use --no-shuffle instead of --shuffle in the error message * Fix continuing training from STDIN * Update CHANGELOG * Support both 'stdin' and '-' * Guess --tsv-fields from dim-vocabs if special:model.yml available * Update error messages * Move variable outside the loop * Refactorize utils::splitTsv; add unit tests * Support '-' as stdin; refactorize; add comments * Abort if excessive field(s) in the TSV input * Add a TODO on passing one vocab with fully-tied embeddings * Remove the unit test with excessive tab-separated fields commit d0fa14e2640814a02ec8c99ed028c9d3b50744c6 Author: Young Jin Kim Date: Fri Mar 27 21:44:31 2020 +0000 Merged PR 12243: For int8 quantized model, use int8 quantization for encoders as well For int8 quantized model, use int8 quantization for encoders as well. The quality difference between fp16 encoder and int8 encoder is small, but they have quite amount of speed difference. 
commit 68581a6a4aa64c7e878636fb33873e79a8be202c Author: Young Jin Kim Date: Wed Mar 25 02:52:17 2020 +0000 Merged PR 11831: Change the weight matrix quantization to use 7-bit min/max quantization to avoid overflow 1. Change the weight matrix quantization to use 7-bit min/max quantization -> This resolves all the overflow issue, because weight and activations are quantized by min/max range. 2. Clip fp16 quantization to avoid overflow 3. Fix windows build errors (cmake options, vcproj file) 4. int8 pack model (encoder -> fp16) commit dd065420cbdf98244c3b4383f33411ad2067bc65 Author: Marcin Junczys-Dowmunt Date: Sat Mar 14 09:53:54 2020 -0700 bump version commit 66711b515769fede145335d270597a883d097285 Author: Martin Junczys-Dowmunt Date: Sat Mar 14 00:07:37 2020 +0000 Merged PR 11929: Move around code to make later comparison with FP16 code easier This does not introduce any new functionality, just moves code around, so that future PRs are easier to compare. Moving old GraphGroup code to training/deprecated. Once it is clear there is nothing in there that's worth saving, this will be deleted. Replace -Ofast with -O3 and make sure ffinite-math is turned off. commit 2586af7c1628826a806c4d61a47c0fbc8bd0f599 Author: Ulrich Germann Date: Mon Apr 6 12:49:23 2020 +0100 Bug fix: better handling of SIGTERM for graceful shutdown during training. Prior to this bug fix, BatchGenerator::fetchBatches, which runs in a separate thread, would ignore SIGTERM during training (training uses a custom signal handler for SIGTERM, which simply sets a global flag, to enable graceful shutdown (i.e., save models and current state of training before shutting down). The changes in this commit also facilitate custom handling of other signals in the future by providing a general singal handler for all signals with a signal number below 32 (setSignalFlag) and a generic flag checking function (getSignalFlag(sig)) for checking such flags. 
commit 8a44759609355fdead4c5cc3e04546636c5b9e1b Merge: 95c65bb 653b13d Author: Ulrich Germann Date: Fri Apr 3 12:37:13 2020 +0100 Merge branch 'ug-graceful-shutdown' of https://github.com/marian-nmt/marian-dev into ug-graceful-shutdown commit 653b13d687cf0ece3823b11fd07accf1574edbc4 Author: Ulrich Germann Date: Fri Nov 22 22:59:38 2019 +0000 Added explanatory comment about exiting marian_train with non-zero status after SIGTERM. commit 73bdb1f7a9612410a55db67879db417ad2b64ac7 Author: Ulrich Germann Date: Tue Nov 19 22:49:25 2019 +0000 Return exit code 15 (SIGTERM) after SIGTERM. When marian receives signal SIGTERM and exits gracefully (save model & exit), it should then exit with a non-zero exit code, to signal to any parent process that it did not exit "naturally". --- CHANGELOG.md | 1 + src/CMakeLists.txt | 2 +- src/command/marian_train.cpp | 11 ++++----- src/common/config_parser.cpp | 12 +++++++++- src/common/signal_handling.cpp | 33 ++++++++++++++++++++++++++ src/common/signal_handling.h | 39 ++++++++++++++++++++++++++++++ src/data/batch_generator.h | 9 +++++-- src/training/scheduler.cpp | 43 ---------------------------------- src/training/scheduler.h | 19 +++++++-------- src/training/training.h | 17 ++++++++++++++ 10 files changed, 122 insertions(+), 64 deletions(-) create mode 100644 src/common/signal_handling.cpp create mode 100644 src/common/signal_handling.h delete mode 100644 src/training/scheduler.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index cc948bac1..a17acce3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Properly record cmake variables in the cmake build directory instead of the source tree. - Added default "none" for option shuffle in BatchGenerator, so that it works in executables where shuffle is not an option. - Added a few missing header files in shortlist.h and beam_search.h. +- Improved handling for graceful shutdown upon receiving SIGTERM. 
SIGTERM now also interrupts batch prefetching, which runs in a separate thread. Graceful shutdown can be disabled with --sigterm 'immediate'. ### Changed - Move Simple-WebSocket-Server to submodule diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6cb6bea42..f95941ecd 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -25,6 +25,7 @@ add_library(marian STATIC common/filesystem.cpp common/file_stream.cpp common/file_utils.cpp + common/signal_handling.cpp common/types.cpp data/alignment.cpp @@ -99,7 +100,6 @@ add_library(marian STATIC training/graph_group_singleton.cpp training/validator.cpp training/communicator.cpp - training/scheduler.cpp # this is only compiled to catch build errors, but not linked microsoft/quicksand.cpp diff --git a/src/command/marian_train.cpp b/src/command/marian_train.cpp index d1978fab4..f9c6492d9 100644 --- a/src/command/marian_train.cpp +++ b/src/command/marian_train.cpp @@ -1,6 +1,7 @@ #include #include "marian.h" +#include "common/signal_handling.h" #include "training/graph_group_async.h" #include "training/graph_group_singleton.h" #include "training/graph_group_sync.h" @@ -42,14 +43,12 @@ int main(int argc, char** argv) { New>(options)->run(); } } - - // If we exit due to SIGTERM, exit with 128 + the signal number, as suggested - // for bash in http://tldp.org/LDP/abs/html/exitcodes.html. This allows parent + // If we exit due to a graceful exit request via SIGTERM, exit with 128 + SIGTERM, + // as suggested for bash in http://tldp.org/LDP/abs/html/exitcodes.html. This allows parent // scripts to determine if training terminated naturally or via SIGTERM. - // Whith this approach we can accommodate additional signals in the future. 
- // An alternative would be to return 124, which is what the timeout command + // An alternative would be to exit with code 124, which is what the timeout command // returns for timeout -s SIGTERM ...., because exiting after SIGTERM // is not technically a fatal error (which is what the 128+x convention usually // stands for). - return getSigtermFlag() ? (128 + SIGTERM) : 0; + exit(getSignalFlag(SIGTERM) ? 128 + SIGTERM : EXIT_SUCCESS); } diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index a44a00826..7f825eea5 100755 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -324,13 +324,23 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) { "Dropout for transformer attention (0 = no dropout)"); cli.add("--transformer-dropout-ffn", "Dropout for transformer filter (0 = no dropout)"); + } cli.switchGroup(previous_group); // clang-format on } void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) { - auto previous_group = cli.switchGroup("Training options"); + auto previous_group = cli.switchGroup("Signal Handling"); + // --sigterm is deliberately not a boolean, to allow for a consistent + // pattern of specifying custom signal handling in the future. + // (e.g., dump model but continue training upon SIGUSR1, or report current + // training status upon SIGINFO.) 
+ cli.add("--sigterm", + "What to do with SIGTERM: 'graceful' => save and exit (default); " + "'immediate' => exit immediately.", "graceful"); + + cli.switchGroup("Training options"); // clang-format off cli.add("--cost-type", // @TODO: rename to loss-type "Optimization criterion: ce-mean, ce-mean-words, ce-sum, perplexity", "ce-mean"); diff --git a/src/common/signal_handling.cpp b/src/common/signal_handling.cpp new file mode 100644 index 000000000..889162190 --- /dev/null +++ b/src/common/signal_handling.cpp @@ -0,0 +1,33 @@ +#include "common/logging.h" +#include "signal_handling.h" + +// We use signal() here instead of the usual strong recommendation for +// using sigaction, which apparently is not available for Windows (cf. +// https://stackoverflow.com/questions/231912/what-is-the-difference-between-sigaction-and-signal). + +namespace marian{ +volatile std::sig_atomic_t sigflags_{0}; +volatile std::sig_atomic_t gracefulExitRequested_{0}; + +bool getSignalFlag(const int sig) { + // sig_atomic_t has 32 bits. We don't accommodate signals beyond that. + ABORT_IF(sig >= 32, "Signal out of range (must be < 32, is {}).", sig); + return sigflags_ & (1< +#include + +// SIGNAL HANDLING + +// The signal handlers (and checkers) here are implemented in line with with the recommendations +// for signal handling in the SEI CERT C Coding Standard, specifically +// +// - SIG30-C: +// https://wiki.sei.cmu.edu/confluence/display/c/SIG30-C.+Call+only+asynchronous-safe+functions+within+signal+handlers +// +// - SIG31-C: +// https://wiki.sei.cmu.edu/confluence/display/c/SIG31-C.+Do+not+access+shared+objects+in+signal+handlers +// +// The exact behavior of 'graceful exit' depends on the application; for training, it means 'save model and exit', +// for a server (not implemented yet): 'block new requests but serve pending requests and then exit'. +// +// Graceful exit for training is useful for training on clusters with time limits on jobs. 
Slurm, for example, can be +// set up to send a custom signal at a set time before the end of the time slot, giving Marian time to save its current +// state before getting killed. + +namespace marian { + + +/// Request graceful exit (signal handler) +void requestGracefulExit(const int sig); + +/// Check if graceful exit was requested. +bool gracefulExitRequested(); + +/// General purpose signal handler that simply sets a flag when a signal is received. +// (only for SIGNAL No. < 32). +void setSignalFlag(const int sig); // custom handler (set flag) for sig + +/// Check if a setSignalFlag was triggered for this signal +bool getSignalFlag(const int sig); + +} // End of namespace marian diff --git a/src/data/batch_generator.h b/src/data/batch_generator.h index 88d3efb92..54fdf38da 100644 --- a/src/data/batch_generator.h +++ b/src/data/batch_generator.h @@ -1,6 +1,7 @@ #pragma once #include "common/options.h" +#include "common/signal_handling.h" #include "data/batch_stats.h" #include "data/rng_engine.h" #include "training/training_state.h" @@ -136,6 +137,8 @@ class BatchGenerator : public RNGEngine { } size_t sets = 0; while(current_ != data_->end() && maxiBatch->size() < maxSize) { // loop over data + if (gracefulExitRequested()) // stop generating batches + return std::deque(); maxiBatch->push(*current_); sets = current_->size(); // do not consume more than required for the maxi batch as this causes @@ -161,6 +164,8 @@ class BatchGenerator : public RNGEngine { if (stats_) cachedStatsIter = stats_->begin(); while(!maxiBatch->empty()) { // while there are sentences in the queue + if (gracefulExitRequested()) // stop generating batches + return std::deque(); // push item onto batch batchVector.push_back(maxiBatch->top()); maxiBatch->pop(); // fetch next-shortest @@ -249,7 +254,7 @@ class BatchGenerator : public RNGEngine { "If you have changed the training corpus, add --no-restore-corpus to the training command and run it again."); bufferedBatches_ = 
std::move(futureBufferedBatches_.get()); // if bg thread returns an empty swath, we hit the end of the epoch - if (bufferedBatches_.empty()) { + if (bufferedBatches_.empty() || gracefulExitRequested()) { return nullptr; } // and kick off the next bg operation @@ -257,7 +262,7 @@ class BatchGenerator : public RNGEngine { } else { // don't spawn any threads, i.e. batch fetching is blocking. bufferedBatches_ = fetchBatches(); // if bufferedBatches is empty we hit the end of the epoch - if (bufferedBatches_.empty()) { + if (bufferedBatches_.empty() || gracefulExitRequested()) { return nullptr; } } diff --git a/src/training/scheduler.cpp b/src/training/scheduler.cpp deleted file mode 100644 index 4c30cb04e..000000000 --- a/src/training/scheduler.cpp +++ /dev/null @@ -1,43 +0,0 @@ -#include "scheduler.h" -#include -#include - -namespace marian { - -// SIGNAL HANDLING, see scheduler.cpp for definitions -// Currently, only the following is handled by a custom signal handler: -// SIGTERM: When SIGTERM is received, the global (static member) flag sigterm_ (false by default) is set to true -// by signalHandler(). When sigterm_ is true, keepGoing() returns false, and the current state of training models -// is saved prior to exiting. -// This functionality is helpful when training on clusters with time limits on compute slots, e.g., on s -// clusters managed by slurm. Slurm can be asked to sending a (custom) warning signal to a process at a given -// point in time prior to the hard "time's up". - -bool sigterm_{false}; // flag signalling that SIGTERM has been received false by default, set to true by signalHandler(SIGTERM) - -void signalHandler(int sig) { - // Note: sys_siglist[sig] or stdsignal() describe the effect (e.g., - // 'Terminated' rather than provide the signal name (which are #define(s) - // in signal.h), so we have to do custom log messages here. 
- switch (sig) { - case SIGTERM: // save models and exit - LOG(info, "[training] Scheduler received signal SIGTERM"); // @TODO: figure out if this is safe. The logs are global and thread-safe, so should be OK? - sigterm_ = true; - break; - default: - ABORT("No action defined for signal {}", sig); - } -} - -// installs signalHandler() for select signals (currently only SIGTERM) -void installSignalHandlers() { - // TODO: use sigaction instead of signal, - // cf. https://stackoverflow.com/questions/231912/what-is-the-difference-between-sigaction-and-signal - signal(SIGTERM, signalHandler); -} - -bool getSigtermFlag() { - return sigterm_; -} - -} diff --git a/src/training/scheduler.h b/src/training/scheduler.h index 8c8701cac..2b5460a20 100755 --- a/src/training/scheduler.h +++ b/src/training/scheduler.h @@ -1,6 +1,7 @@ #pragma once #include "common/options.h" +#include "common/signal_handling.h" #include "training/training_state.h" #include "training/validator.h" #include "training/communicator.h" @@ -8,9 +9,6 @@ namespace marian { -bool getSigtermFlag(); -void installSignalHandlers(); - class Scheduler : public TrainingObserver { private: Ptr options_; @@ -154,11 +152,10 @@ class Scheduler : public TrainingObserver { : options_(options), state_(state) { ABORT_IF(state_->factor != 1, "state.factor unexpectedly not 1 at this point??"); updateLearningRate(*state); - installSignalHandlers(); } bool keepGoing() { - if(getSigtermFlag()) // received signal SIGERM => exit gracefully + if(gracefulExitRequested()) // via SIGTERM return false; // stop if it reached the maximum number of epochs @@ -192,13 +189,12 @@ class Scheduler : public TrainingObserver { void started() { LOG(info, "Training started"); } void finished() { - if (getSigtermFlag()) - LOG(info, "Training interrupted (SIGTERM)."); + if (gracefulExitRequested()) + LOG(info, "Training interrupted (via signal)."); else LOG(info, "Training finished"); } - void addValidator(Ptr validator) { 
validators_.push_back(validator); @@ -223,9 +219,10 @@ class Scheduler : public TrainingObserver { void validate(const std::vector>& graphs, bool isFinal = false) { - // Do not validate if already validated (for instance, after the model is - // loaded) or if validation is scheduled for another update, or when signal SIGTERM was received - if(getSigtermFlag() // SIGTERM was received + // Do not validate if already validated (for instance, after the model is loaded) + // or if validation is scheduled for another update, or when a graceful shutdown + // was requested. + if(gracefulExitRequested() || state_->validated // already validated (in resumed training, for example) || (!state_->enteredNewPeriodOf(options_->get("valid-freq")) && !isFinal)) // not now return; diff --git a/src/training/training.h b/src/training/training.h index 5a2be7635..ceb5bb982 100644 --- a/src/training/training.h +++ b/src/training/training.h @@ -16,6 +16,7 @@ template class Train : public ModelTask { private: Ptr options_; + void installCustomSignalHandlers(); public: Train(Ptr options) : options_(options) {} @@ -77,6 +78,9 @@ class Train : public ModelTask { bool restored = !options_->get("no-restore-corpus") && batchGenerator->restore(trainState); + // We only want custom behavior once training starts. 
+ installCustomSignalHandlers(); + // -- main training loop scheduler->started(); while(scheduler->keepGoing()) { @@ -107,4 +111,17 @@ class Train : public ModelTask { finalizeMPI(std::move(mpi)); } }; + +template +void Train::installCustomSignalHandlers() +{ + const std::string sigTermAction = options_->get("sigterm"); + if (sigTermAction == "graceful") { + LOG(debug, "Enabling graceful shutdown for SIGTERM."); + signal(SIGTERM, requestGracefulExit); + } + else if (sigTermAction != "immediate") + ABORT("Unrecognized value '{}' for --sigterm", sigTermAction); +} + } // namespace marian From 6a753d688267f1f8603ab573befb6b2dde32c9de Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Wed, 26 Aug 2020 21:09:17 +0100 Subject: [PATCH 50/62] Update config_parser.cpp Cosmetic fixes as per code review. - removed superfluous empty line - but default value for --sigterm on a new line --- src/common/config_parser.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index 7f825eea5..c2f3efd91 100755 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -324,7 +324,6 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) { "Dropout for transformer attention (0 = no dropout)"); cli.add("--transformer-dropout-ffn", "Dropout for transformer filter (0 = no dropout)"); - } cli.switchGroup(previous_group); // clang-format on @@ -338,7 +337,8 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) { // training status upon SIGINFO.) 
cli.add("--sigterm", "What to do with SIGTERM: 'graceful' => save and exit (default); " - "'immediate' => exit immediately.", "graceful"); + "'immediate' => exit immediately.", + "graceful"); cli.switchGroup("Training options"); // clang-format off From deafe6a0ad79e55690e66b11b0da15b3e4bbe94a Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Wed, 26 Aug 2020 21:32:46 +0100 Subject: [PATCH 51/62] Update config_parser.cpp Move option --segterm to General Options for training, as suggested by @snukky. --- src/common/config_parser.cpp | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index c2f3efd91..45526dcff 100755 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -143,6 +143,16 @@ void ConfigParser::addOptionsGeneral(cli::CLIWrapper& cli) { cli.add("--dump-config", "Dump current (modified) configuration to stdout and exit. Possible values: full, minimal, expand") ->implicit_val("full"); + if(mode_ == cli::mode::training) { + // --sigterm is deliberately not a boolean, to allow for a consistent + // pattern of specifying custom signal handling in the future. + // (e.g., dump model but continue training upon SIGUSR1, or report current + // training status upon SIGINFO.) + cli.add("--sigterm", + "What to do with SIGTERM: 'graceful' => save and exit (default); " + "'immediate' => exit immediately.", + "graceful"); + } // clang-format on } @@ -330,17 +340,7 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) { } void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) { - auto previous_group = cli.switchGroup("Signal Handling"); - // --sigterm is deliberately not a boolean, to allow for a consistent - // pattern of specifying custom signal handling in the future. - // (e.g., dump model but continue training upon SIGUSR1, or report current - // training status upon SIGINFO.) 
- cli.add("--sigterm", - "What to do with SIGTERM: 'graceful' => save and exit (default); " - "'immediate' => exit immediately.", - "graceful"); - - cli.switchGroup("Training options"); + auto previous_group = cli.switchGroup("Training options"); // clang-format off cli.add("--cost-type", // @TODO: rename to loss-type "Optimization criterion: ce-mean, ce-mean-words, ce-sum, perplexity", "ce-mean"); From 8e41f2226fbca50800836124b7594404cbfef650 Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Wed, 26 Aug 2020 21:36:58 +0100 Subject: [PATCH 52/62] Update signal_handling.cpp Remove log message from signal handler. Note that the log still records that training was interrupted by a signal both in the log message in Scheduler::finished() and in the exit code at the end of marian_train.cpp. --- src/common/signal_handling.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/common/signal_handling.cpp b/src/common/signal_handling.cpp index 889162190..8d3dd71e3 100644 --- a/src/common/signal_handling.cpp +++ b/src/common/signal_handling.cpp @@ -18,7 +18,6 @@ bool getSignalFlag(const int sig) { void requestGracefulExit(int sig) { setSignalFlag(sig); // keep track of triggering signal gracefulExitRequested_ = 1; // set flag to exit gracefully - LOG(debug, "Graceful exit requested via signal {}.", sig); } bool gracefulExitRequested() { From 69b10cfe0f0f1751e0b73a51168ae507088c83af Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Wed, 26 Aug 2020 21:46:24 +0100 Subject: [PATCH 53/62] Update config_parser.cpp Fix formatting of --sigterm option specification. 
--- src/common/config_parser.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index 45526dcff..f90736d29 100755 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -149,9 +149,8 @@ void ConfigParser::addOptionsGeneral(cli::CLIWrapper& cli) { // (e.g., dump model but continue training upon SIGUSR1, or report current // training status upon SIGINFO.) cli.add("--sigterm", - "What to do with SIGTERM: 'graceful' => save and exit (default); " - "'immediate' => exit immediately.", - "graceful"); + "What to do with SIGTERM: 'graceful' => save and exit (default); 'immediate' => exit immediately.", + "graceful"); } // clang-format on } From 1a5fbbc2fa2b20f9066a18bdb30bec403b559077 Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Thu, 27 Aug 2020 22:32:23 +0100 Subject: [PATCH 54/62] Fix trailing whitespace. --- src/common/config_parser.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index f90736d29..41388d693 100755 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -149,7 +149,7 @@ void ConfigParser::addOptionsGeneral(cli::CLIWrapper& cli) { // (e.g., dump model but continue training upon SIGUSR1, or report current // training status upon SIGINFO.) cli.add("--sigterm", - "What to do with SIGTERM: 'graceful' => save and exit (default); 'immediate' => exit immediately.", + "What to do with SIGTERM: 'graceful' => save and exit (default); 'immediate' => exit immediately.", "graceful"); } // clang-format on From 506eb71e8dbbff29f2b43f47e0772bf5f1739c91 Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Thu, 27 Aug 2020 23:41:57 +0100 Subject: [PATCH 55/62] Cosmetic cleanup. 
--- src/common/signal_handling.cpp | 37 ++++++++++++++++++++++++---------- src/common/signal_handling.h | 6 +++--- src/training/training.h | 3 +-- 3 files changed, 30 insertions(+), 16 deletions(-) diff --git a/src/common/signal_handling.cpp b/src/common/signal_handling.cpp index 8d3dd71e3..3dca42a07 100644 --- a/src/common/signal_handling.cpp +++ b/src/common/signal_handling.cpp @@ -1,17 +1,37 @@ #include "common/logging.h" #include "signal_handling.h" -// We use signal() here instead of the usual strong recommendation for -// using sigaction, which apparently is not available for Windows (cf. -// https://stackoverflow.com/questions/231912/what-is-the-difference-between-sigaction-and-signal). +// The simplest (and recommended) way to handle signals is to simply set a flag +// in the signal handler and check that flag later. +// +// We provide setSignalFlag as the most generic signal handler. +// This handler which uses a single sig_atomic_t as a bit field. +// On Linux, sig_atomic_t is equivalent to a signed int, theoretically +// providing 32 binary flags; in practice, most likely signals for which we may +// want to install signal handlers are +// - SIGTERM (15): which by default signals the request for a graceful exit +// (see also: https://qph.fs.quoracdn.net/main-qimg-1180ef2465c309928b02481f02580c6a) +// - SIGUSR1,SIGUSR2 (10,12): signals specifically reserved for custom use +// - SIGINT (2): interrupt from the console +// Just to be safe, we accommodate signals up to signal No. 30. +constexpr int maxSignalForSetSetSignalFlag{30}; + +// Make sure sig_atomic_t is large enough as a bit field for our purposes. +// That said, I'm not aware of any platform where this would be a problem. 
+static_assert(SIG_ATOMIC_MAX > (1U<= 32, "Signal out of range (must be < 32, is {}).", sig); + ABORT_IF(sig > maxSignalForSetSignalFlag, + "Signal out of range (must be < {}, is {}).", maxSignalForSetSignalFlag, sig); return sigflags_ & (1< -void Train::installCustomSignalHandlers() -{ +void Train::installCustomSignalHandlers(){ const std::string sigTermAction = options_->get("sigterm"); if (sigTermAction == "graceful") { LOG(debug, "Enabling graceful shutdown for SIGTERM."); From 139052a1b08b5daf00c3536dbd0b95c393344d79 Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Fri, 28 Aug 2020 01:01:22 +0100 Subject: [PATCH 56/62] Comment code. --- src/common/signal_handling.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/common/signal_handling.cpp b/src/common/signal_handling.cpp index 3dca42a07..19412cd0b 100644 --- a/src/common/signal_handling.cpp +++ b/src/common/signal_handling.cpp @@ -26,12 +26,22 @@ volatile std::sig_atomic_t sigflags_{0}; volatile std::sig_atomic_t gracefulExitRequested_{0}; void setSignalFlag(int sig) { + // sigflags_ is an int type serving as a bit filed for flags corresponding + // to signals (lower or equeal to maxSignalForSetSignalFlag). We set the + // flag by a binary or (|=) of the bit field and an int value with exactly + // one bit set (s^sig). sigflags_ |= (1< maxSignalForSetSignalFlag, "Signal out of range (must be < {}, is {}).", maxSignalForSetSignalFlag, sig); + // Do bitwise AND between sigflags_ and an int value that has exactly one bit set that + // corresponds to the signal in question. If the bit is set (see setSignalFlag above), + // the bitwise AND will return a non-zero integer, if it is not set, the result will + // be zero. Implicit type conversion from int to bool will convert this into a boolean + // value: true if the signal flag has been set, false otherwise. 
return sigflags_ & (1< Date: Fri, 28 Aug 2020 01:02:31 +0100 Subject: [PATCH 57/62] Change values for --sigterm in response to code review. --- src/common/config_parser.cpp | 4 ++-- src/training/training.h | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index 41388d693..e698ddcd4 100755 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -149,8 +149,8 @@ void ConfigParser::addOptionsGeneral(cli::CLIWrapper& cli) { // (e.g., dump model but continue training upon SIGUSR1, or report current // training status upon SIGINFO.) cli.add("--sigterm", - "What to do with SIGTERM: 'graceful' => save and exit (default); 'immediate' => exit immediately.", - "graceful"); + "What to do with SIGTERM: save-and-exit or exit-immediately.", + "save-and-exit"); } // clang-format on } diff --git a/src/training/training.h b/src/training/training.h index ff371b77d..01445cf8f 100644 --- a/src/training/training.h +++ b/src/training/training.h @@ -115,11 +115,11 @@ class Train : public ModelTask { template void Train::installCustomSignalHandlers(){ const std::string sigTermAction = options_->get("sigterm"); - if (sigTermAction == "graceful") { - LOG(debug, "Enabling graceful shutdown for SIGTERM."); + if (sigTermAction == "save-and-exit") { + LOG(debug, "Will save before exiting upon SIGTERM."); signal(SIGTERM, requestGracefulExit); } - else if (sigTermAction != "immediate") + else if (sigTermAction != "exit-immediately") ABORT("Unrecognized value '{}' for --sigterm", sigTermAction); } From c7ad4f166b24772200e93b3af4c43cc74abdb692 Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Fri, 28 Aug 2020 01:17:21 +0100 Subject: [PATCH 58/62] Bug fix in variable name. 
--- src/common/signal_handling.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/common/signal_handling.cpp b/src/common/signal_handling.cpp index 19412cd0b..f412585c5 100644 --- a/src/common/signal_handling.cpp +++ b/src/common/signal_handling.cpp @@ -14,11 +14,11 @@ // - SIGUSR1,SIGUSR2 (10,12): signals specifically reserved for custom use // - SIGINT (2): interrupt from the console // Just to be safe, we accommodate signals up to signal No. 30. -constexpr int maxSignalForSetSetSignalFlag{30}; +constexpr int maxSignalForSetSignalFlag{30}; // Make sure sig_atomic_t is large enough as a bit field for our purposes. // That said, I'm not aware of any platform where this would be a problem. -static_assert(SIG_ATOMIC_MAX > (1U< (1U< Date: Fri, 28 Aug 2020 14:21:59 +0100 Subject: [PATCH 59/62] Update signal_handling.cpp. - Update comments - Explicit conversion from int to bool in functions that return bool. --- src/common/signal_handling.cpp | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/common/signal_handling.cpp b/src/common/signal_handling.cpp index f412585c5..53fc594f1 100644 --- a/src/common/signal_handling.cpp +++ b/src/common/signal_handling.cpp @@ -4,16 +4,22 @@ // The simplest (and recommended) way to handle signals is to simply set a flag // in the signal handler and check that flag later. // -// We provide setSignalFlag as the most generic signal handler. -// This handler which uses a single sig_atomic_t as a bit field. -// On Linux, sig_atomic_t is equivalent to a signed int, theoretically -// providing 32 binary flags; in practice, most likely signals for which we may +// We provide setSignalFlag as the most generic signal handler. This handler uses a +// single sig_atomic_t as a bit field. 
On Linux, sig_atomic_t is equivalent to a signed int, +// theoretically providing 32 binary flags; in practice, most likely signals for which we may // want to install signal handlers are -// - SIGTERM (15): which by default signals the request for a graceful exit +// - SIGTERM (15): which by default signals the request for a graceful shutdown // (see also: https://qph.fs.quoracdn.net/main-qimg-1180ef2465c309928b02481f02580c6a) -// - SIGUSR1,SIGUSR2 (10,12): signals specifically reserved for custom use +// - SIGUSR1 (10): intended for custom use, default action in Linux is termination +// - SIGUSR2 (12): intended for custom use, default action in Linux is termination // - SIGINT (2): interrupt from the console // Just to be safe, we accommodate signals up to signal No. 30. + +// In addition, we also provide requestGracefulExit() and gracefulExitRequested() as a signal +// handler/checker for graceful shutdown requests (what exactly that means, depends on the +// application; for training, it means save-and-exit, for a server, it might mean block new +// requests, serve bending requests, then exit) that can be installed for arbitrary signals +// (SIGUSR1). constexpr int maxSignalForSetSignalFlag{30}; // Make sure sig_atomic_t is large enough as a bit field for our purposes. @@ -40,9 +46,8 @@ bool getSignalFlag(const int sig) { // Do bitwise AND between sigflags_ and an int value that has exactly one bit set that // corresponds to the signal in question. If the bit is set (see setSignalFlag above), // the bitwise AND will return a non-zero integer, if it is not set, the result will - // be zero. Implicit type conversion from int to bool will convert this into a boolean - // value: true if the signal flag has been set, false otherwise. - return sigflags_ & (1< Date: Fri, 28 Aug 2020 19:27:53 +0100 Subject: [PATCH 60/62] Rename 'graceful exit' to 'save and exit'. 
--- src/common/signal_handling.cpp | 22 +++++++++------------- src/common/signal_handling.h | 4 ++-- src/data/batch_generator.h | 8 ++++---- src/training/scheduler.h | 6 +++--- src/training/training.h | 2 +- 5 files changed, 19 insertions(+), 23 deletions(-) diff --git a/src/common/signal_handling.cpp b/src/common/signal_handling.cpp index 53fc594f1..dd748b88b 100644 --- a/src/common/signal_handling.cpp +++ b/src/common/signal_handling.cpp @@ -4,22 +4,18 @@ // The simplest (and recommended) way to handle signals is to simply set a flag // in the signal handler and check that flag later. // -// We provide setSignalFlag as the most generic signal handler. This handler uses a -// single sig_atomic_t as a bit field. On Linux, sig_atomic_t is equivalent to a signed int, +// We provide setSignalFlag as the most generic signal handler. This handler uses a +// single sig_atomic_t as a bit field. On Linux, sig_atomic_t is equivalent to a signed int, // theoretically providing 32 binary flags; in practice, most likely signals for which we may // want to install signal handlers are // - SIGTERM (15): which by default signals the request for a graceful shutdown -// (see also: https://qph.fs.quoracdn.net/main-qimg-1180ef2465c309928b02481f02580c6a) // - SIGUSR1 (10): intended for custom use, default action in Linux is termination // - SIGUSR2 (12): intended for custom use, default action in Linux is termination // - SIGINT (2): interrupt from the console // Just to be safe, we accommodate signals up to signal No. 30. -// In addition, we also provide requestGracefulExit() and gracefulExitRequested() as a signal -// handler/checker for graceful shutdown requests (what exactly that means, depends on the -// application; for training, it means save-and-exit, for a server, it might mean block new -// requests, serve bending requests, then exit) that can be installed for arbitrary signals -// (SIGUSR1). 
+// In addition, we also provide requestSaveAndExit() and saveAndExit() as a signal +// handler/checker for graceful shutdown requests during training. constexpr int maxSignalForSetSignalFlag{30}; // Make sure sig_atomic_t is large enough as a bit field for our purposes. @@ -29,7 +25,7 @@ static_assert(SIG_ATOMIC_MAX > (1U<end() && maxiBatch->size() < maxSize) { // loop over data - if (gracefulExitRequested()) // stop generating batches + if (saveAndExit()) // stop generating batches return std::deque(); maxiBatch->push(*current_); sets = current_->size(); @@ -164,7 +164,7 @@ class BatchGenerator : public RNGEngine { if (stats_) cachedStatsIter = stats_->begin(); while(!maxiBatch->empty()) { // while there are sentences in the queue - if (gracefulExitRequested()) // stop generating batches + if (saveAndExit()) // stop generating batches return std::deque(); // push item onto batch batchVector.push_back(maxiBatch->top()); @@ -254,7 +254,7 @@ class BatchGenerator : public RNGEngine { "If you have changed the training corpus, add --no-restore-corpus to the training command and run it again."); bufferedBatches_ = std::move(futureBufferedBatches_.get()); // if bg thread returns an empty swath, we hit the end of the epoch - if (bufferedBatches_.empty() || gracefulExitRequested()) { + if (bufferedBatches_.empty() || saveAndExit()) { return nullptr; } // and kick off the next bg operation @@ -262,7 +262,7 @@ class BatchGenerator : public RNGEngine { } else { // don't spawn any threads, i.e. batch fetching is blocking. 
bufferedBatches_ = fetchBatches(); // if bufferedBatches is empty we hit the end of the epoch - if (bufferedBatches_.empty() || gracefulExitRequested()) { + if (bufferedBatches_.empty() || saveAndExit()) { return nullptr; } } diff --git a/src/training/scheduler.h b/src/training/scheduler.h index 2b5460a20..8e405ff30 100755 --- a/src/training/scheduler.h +++ b/src/training/scheduler.h @@ -155,7 +155,7 @@ class Scheduler : public TrainingObserver { } bool keepGoing() { - if(gracefulExitRequested()) // via SIGTERM + if(saveAndExit()) // via SIGTERM return false; // stop if it reached the maximum number of epochs @@ -189,7 +189,7 @@ class Scheduler : public TrainingObserver { void started() { LOG(info, "Training started"); } void finished() { - if (gracefulExitRequested()) + if (saveAndExit()) LOG(info, "Training interrupted (via signal)."); else LOG(info, "Training finished"); @@ -222,7 +222,7 @@ class Scheduler : public TrainingObserver { // Do not validate if already validated (for instance, after the model is loaded) // or if validation is scheduled for another update, or when a graceful shutdown // was requested. 
- if(gracefulExitRequested() + if(saveAndExit() || state_->validated // already validated (in resumed training, for example) || (!state_->enteredNewPeriodOf(options_->get("valid-freq")) && !isFinal)) // not now return; diff --git a/src/training/training.h b/src/training/training.h index 01445cf8f..d2be8b872 100644 --- a/src/training/training.h +++ b/src/training/training.h @@ -117,7 +117,7 @@ void Train::installCustomSignalHandlers(){ const std::string sigTermAction = options_->get("sigterm"); if (sigTermAction == "save-and-exit") { LOG(debug, "Will save before exiting upon SIGTERM."); - signal(SIGTERM, requestGracefulExit); + signal(SIGTERM, requestSaveAndExit); } else if (sigTermAction != "exit-immediately") ABORT("Unrecognized value '{}' for --sigterm", sigTermAction); From d7044f7b2364d8f7326b1ccfd6f6f720ae2be956 Mon Sep 17 00:00:00 2001 From: Ulrich Germann Date: Fri, 28 Aug 2020 19:36:57 +0100 Subject: [PATCH 61/62] Update CHANGELOG.md. --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a17acce3c..9f5ee08d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,7 +39,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Properly record cmake variables in the cmake build directory instead of the source tree. - Added default "none" for option shuffle in BatchGenerator, so that it works in executables where shuffle is not an option. - Added a few missing header files in shortlist.h and beam_search.h. -- Improved handling for graceful shutdown upon receiving SIGTERM. SIGTERM now also interrupts batch prefetching, which runs in a separate thread. Graceful shutdown can be disabled with --sigterm 'immediate'. +- Improved handling for receiving SIGTERM during training. By default, SIGTERM triggers 'save (now) and exit'. Prior to this fix, batch pre-fetching did not check for this sigal, potentially delaying exit considerably. It now pays attention to that. 
Also, the default behaviour of save-and-exit can now be disabled on the command line with --sigterm exit-immediately. ### Changed - Move Simple-WebSocket-Server to submodule From 6a4d887416ae1f8cb686bca808509e560d0a861b Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Mon, 31 Aug 2020 19:11:48 -0700 Subject: [PATCH 62/62] rename function name to saveAndExitRequested() --- src/common/signal_handling.cpp | 2 +- src/common/signal_handling.h | 2 +- src/data/batch_generator.h | 8 ++++---- src/training/scheduler.h | 6 +++--- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/common/signal_handling.cpp b/src/common/signal_handling.cpp index dd748b88b..8e3fd9133 100644 --- a/src/common/signal_handling.cpp +++ b/src/common/signal_handling.cpp @@ -51,7 +51,7 @@ void requestSaveAndExit(int sig) { saveAndExit_ = 1; // set flag to exit gracefully } -bool saveAndExit() { +bool saveAndExitRequested() { return saveAndExit_ == 1; } diff --git a/src/common/signal_handling.h b/src/common/signal_handling.h index 170a785e6..25111b2c2 100644 --- a/src/common/signal_handling.h +++ b/src/common/signal_handling.h @@ -27,7 +27,7 @@ namespace marian { void requestSaveAndExit(int sig); /// Check if graceful exit was requested. -bool saveAndExit(); +bool saveAndExitRequested(); /// General purpose signal handler that simply sets a flag when a signal is received. // (only for SIGNAL No. < 32). 
diff --git a/src/data/batch_generator.h b/src/data/batch_generator.h index f7a87d774..d55e765e3 100644 --- a/src/data/batch_generator.h +++ b/src/data/batch_generator.h @@ -137,7 +137,7 @@ class BatchGenerator : public RNGEngine { } size_t sets = 0; while(current_ != data_->end() && maxiBatch->size() < maxSize) { // loop over data - if (saveAndExit()) // stop generating batches + if (saveAndExitRequested()) // stop generating batches return std::deque(); maxiBatch->push(*current_); sets = current_->size(); @@ -164,7 +164,7 @@ class BatchGenerator : public RNGEngine { if (stats_) cachedStatsIter = stats_->begin(); while(!maxiBatch->empty()) { // while there are sentences in the queue - if (saveAndExit()) // stop generating batches + if (saveAndExitRequested()) // stop generating batches return std::deque(); // push item onto batch batchVector.push_back(maxiBatch->top()); @@ -254,7 +254,7 @@ class BatchGenerator : public RNGEngine { "If you have changed the training corpus, add --no-restore-corpus to the training command and run it again."); bufferedBatches_ = std::move(futureBufferedBatches_.get()); // if bg thread returns an empty swath, we hit the end of the epoch - if (bufferedBatches_.empty() || saveAndExit()) { + if (bufferedBatches_.empty() || saveAndExitRequested()) { return nullptr; } // and kick off the next bg operation @@ -262,7 +262,7 @@ class BatchGenerator : public RNGEngine { } else { // don't spawn any threads, i.e. batch fetching is blocking. 
bufferedBatches_ = fetchBatches(); // if bufferedBatches is empty we hit the end of the epoch - if (bufferedBatches_.empty() || saveAndExit()) { + if (bufferedBatches_.empty() || saveAndExitRequested()) { return nullptr; } } diff --git a/src/training/scheduler.h b/src/training/scheduler.h index 8e405ff30..a3828cd32 100755 --- a/src/training/scheduler.h +++ b/src/training/scheduler.h @@ -155,7 +155,7 @@ class Scheduler : public TrainingObserver { } bool keepGoing() { - if(saveAndExit()) // via SIGTERM + if(saveAndExitRequested()) // via SIGTERM return false; // stop if it reached the maximum number of epochs @@ -189,7 +189,7 @@ class Scheduler : public TrainingObserver { void started() { LOG(info, "Training started"); } void finished() { - if (saveAndExit()) + if (saveAndExitRequested()) LOG(info, "Training interrupted (via signal)."); else LOG(info, "Training finished"); @@ -222,7 +222,7 @@ class Scheduler : public TrainingObserver { // Do not validate if already validated (for instance, after the model is loaded) // or if validation is scheduled for another update, or when a graceful shutdown // was requested. - if(saveAndExit() + if(saveAndExitRequested() || state_->validated // already validated (in resumed training, for example) || (!state_->enteredNewPeriodOf(options_->get("valid-freq")) && !isFinal)) // not now return;