From c7ba58f07c6583614f74d906bbce796e24042522 Mon Sep 17 00:00:00 2001 From: Brett Viren Date: Mon, 1 Nov 2021 16:12:39 -0400 Subject: [PATCH 01/46] Initial draft of a DFT interface and FFTW implementation --- aux/inc/WireCellAux/FftwDFT.h | 42 ++++++++++++ aux/src/FftwDFT.cxx | 113 +++++++++++++++++++++++++++++++ iface/README.org | 59 ++++++++++++++++ iface/inc/WireCellIface/IDFT.h | 56 +++++++++++++++ iface/src/IDFT.cxx | 25 +++++++ util/inc/WireCellUtil/Waveform.h | 4 ++ util/src/Waveform.cxx | 16 +++++ 7 files changed, 315 insertions(+) create mode 100644 aux/inc/WireCellAux/FftwDFT.h create mode 100644 aux/src/FftwDFT.cxx create mode 100644 iface/inc/WireCellIface/IDFT.h create mode 100644 iface/src/IDFT.cxx diff --git a/aux/inc/WireCellAux/FftwDFT.h b/aux/inc/WireCellAux/FftwDFT.h new file mode 100644 index 000000000..ad265eef7 --- /dev/null +++ b/aux/inc/WireCellAux/FftwDFT.h @@ -0,0 +1,42 @@ +#ifndef WIRECELLAUX_FFTWDFT +#define WIRECELLAUX_FFTWDFT + +#include "WireCellIface/IDFT.h" + +namespace WireCell::Aux { + + /** + FftwDFT provides IDFT based on FFTW3. 
+ */ + class FftwDFT : public IDFT { + public: + + FftwDFT(); + virtual ~FftwDFT(); + + // 1d + + virtual + void fwd1d(const complex_t* in, complex_t* out, + int stride) const = 0; + + virtual + void inv1d(const complex_t* in, complex_t* out, + int stride) const = 0; + + // batched 1D ("1b") - rely on base implementation + + // 2d + + virtual + void fwd2d(const complex_t* in, complex_t* out, + int stride, int nstrides) const = 0; + virtual + void inv2d(const complex_t* in, complex_t* out, + int stride, int nstrides) const = 0; + + + }; +} + +#endif diff --git a/aux/src/FftwDFT.cxx b/aux/src/FftwDFT.cxx new file mode 100644 index 000000000..85f6b0820 --- /dev/null +++ b/aux/src/FftwDFT.cxx @@ -0,0 +1,113 @@ +#include "WireCellAux/FftwDFT.h" +#include +#include +#include +#include + + +using namespace WireCell; + +using plan_key_t = int64_t; +using plan_type = fftwf_plan; +using plan_map_t = std::unordered_map; +using plan_val_t = fftwf_complex; + +static +plan_key_t make_key(bool inverse, const void * src, void * dst, int n0, int n1) +{ + bool inplace = (dst==src); + bool aligned = ( (reinterpret_cast(src)&15) | (reinterpret_cast(dst)&15) ) == 0; + int64_t key = ( ( (((int64_t)n0) << 30)|(n1<<3 ) | (inverse<<2) | (inplace<<1) | aligned ) << 1 ) + 1; + return key; +} + +static +plan_type get_plan(std::shared_mutex& mutex, plan_map_t& plans, plan_key_t key) +{ + std::shared_lock lock(mutex); + auto it = plans.find(key); + if (it == plans.end()) { + return NULL; + } + return it->second; +} + + +template +void doit(std::shared_mutex& mutex, plan_map_t& plans, + int fwdrev, plan_val_t* src, plan_val_t* dst, int stride, int nstrides, + planner_function make_plan) +{ + auto key = make_key(fwdrev == FFTW_BACKWARD, src, dst, stride, nstrides); + auto plan = get_plan(mutex, plans, key); + if (!plan) { + std::unique_lock lock(mutex); + // Check again in case another thread snakes us. 
+ auto it = plans.find(key); + if (it == plans.end()) { + plan = make_plan(); + plans[key] = plan; + } + else { + plan = it->second; + } + } + fftwf_execute_dft(plan, src, dst); +} + + +static +plan_val_t* pval_cast( const IDFT::complex_t * p) +{ + return const_cast( reinterpret_cast(p) ); +} + + +void Aux::FftwDFT::fwd1d(const complex_t* in, complex_t* out, int stride) const +{ + static std::shared_mutex mutex; + static plan_map_t plans; + static const int dir = FFTW_FORWARD; + auto src = pval_cast(in); + auto dst = pval_cast(out); + doit(mutex, plans, dir, src, dst, stride, 0, [&]( ) { + return fftwf_plan_dft_1d(stride, src, dst, dir, FFTW_ESTIMATE|FFTW_PRESERVE_INPUT); + }); +} +void Aux::FftwDFT::inv1d(const complex_t* in, complex_t* out, int stride) const +{ + static std::shared_mutex mutex; + static plan_map_t plans; + static const int dir = FFTW_BACKWARD; + auto src = pval_cast(in); + auto dst = pval_cast(out); + doit(mutex, plans, dir, src, dst, stride, 0, [&]( ) { + return fftwf_plan_dft_1d(stride, src, dst, dir, FFTW_ESTIMATE|FFTW_PRESERVE_INPUT); + }); +} + + +void Aux::FftwDFT::fwd2d(const complex_t* in, complex_t* out, int stride, int nstrides) const +{ + static std::shared_mutex mutex; + static plan_map_t plans; + static const int dir = FFTW_FORWARD; + auto src = pval_cast(in); + auto dst = pval_cast(out); + doit(mutex, plans, dir, src, dst, stride, nstrides, [&]( ) { + return fftwf_plan_dft_2d(stride, nstrides, src, dst, dir, FFTW_ESTIMATE|FFTW_PRESERVE_INPUT); + }); +} + + +void Aux::FftwDFT::inv2d(const complex_t* in, complex_t* out, int stride, int nstrides) const +{ + static std::shared_mutex mutex; + static plan_map_t plans; + static const int dir = FFTW_BACKWARD; + auto src = pval_cast(in); + auto dst = pval_cast(out); + doit(mutex, plans, dir, src, dst, stride, nstrides, [&]( ) { + return fftwf_plan_dft_2d(stride, nstrides, src, dst, dir, FFTW_ESTIMATE|FFTW_PRESERVE_INPUT); + }); +} diff --git a/iface/README.org b/iface/README.org index 
37e27b412..cf76f4926 100644 --- a/iface/README.org +++ b/iface/README.org @@ -25,3 +25,62 @@ the overall WCT dependency tree. Discussion is warranted in these cases. See the user manual for more info. https://wirecell.bnl.gov/ +* Interfaces + +** IDFT + +The ~IDFT~ class provides interface to methods to perform discrete +Fourier transforms on arrays of complex single precision floating +point values. + +The interface defines a number of methods which take a general naming +convention like: +#+begin_example +void (...); +#+end_example + +The "direction" of the transform is one of + +- fwd :: the DFT is from interval to frequency, no normalization. +- inv :: the DFT is from frequency to interval, 1/n normalization. + +The "domain" determines the dimension of array and how it is transformed + +- 1d :: a 1D array is transformed +- 1b :: a batch of equal-length 1D arrays are transformed +- 2d :: a 2D array is transformed (along both dimensions) + +The shape of 2D arrays (~1b~ or ~2d~ methods) are given in terms of two +numbers: ~stride~ and ~nstrides~. The number ~stride~ counts the number of +contiguous array elements along one dimension and ~nstrides~ counts the +number non-contiguous elements logically along the opposite dimension. +In the case of "row-major" aka "C" memory ordering of 2D arrays, the +number ~stride~ counts the number of elements in one "row" and ~nstrides~ +counts the number of rows, aka, the number of elements in one column. + +The ~1b~ transforms operate along a contiguous array of length ~stride~. +By default, these transforms are implemented in terms of ~nstrides~ +calls to the ~1d~ DFT interface method. The implementation may override +the ~1b~ default methods for example to exploit some kind of "batch +optimization". + +*** Limitations + +- The potential speed up when the input to a forward or output from + reverse is real valued is not possible to implement with ~IDFT~. 
It + requires the caller to take particular care in array sizes and would + double the number of methods. + +- To satisfy the low-level pointer to memory interface from higher + level objects see the ~Waveform.h~ and ~Array.h~ headers in + ~WireCellUtil~. In particular, see functions there to lift real to + complex or perform memory transforms. + +- Interface to higher order transforms, such as convolutions, are not + provided. See ~Aux::DFT~ for implementations in terms of an ~IDFT~. + +** ... + +Any interfaces not listed above, please see their header file in +[[file:inc/WireCellIface/][inc/WireCellIface/]] for more information. + diff --git a/iface/inc/WireCellIface/IDFT.h b/iface/inc/WireCellIface/IDFT.h new file mode 100644 index 000000000..dd1d39e8c --- /dev/null +++ b/iface/inc/WireCellIface/IDFT.h @@ -0,0 +1,56 @@ +/** + Interface to perform discrete single-precision Fourier transforms. +*/ + +#ifndef WIRECELL_IDFT +#define WIRECELL_IDFT + +#include "WireCellUtil/IComponent.h" +#include + +namespace WireCell { + + class IDFT : public IComponent { + public: + virtual ~IDFT(); + + /// The type for the signal in each bin. + using scalar_t = float; + + /// The type for the spectrum in each bin. 
+ using complex_t = std::complex; + + // 1D + + virtual + void fwd1d(const complex_t* in, complex_t* out, + int stride) const = 0; + + virtual + void inv1d(const complex_t* in, complex_t* out, + int stride) const = 0; + + // batched 1D ("1b") + + virtual + void fwd1b(const complex_t* in, complex_t* out, + int stride, int nstrides) const; + virtual + void inv1b(const complex_t* in, complex_t* out, + int stride, int nstrides) const; + + + // 2D, transform both dimensions + + virtual + void fwd2d(const complex_t* in, complex_t* out, + int stride, int nstrides) const = 0; + virtual + void inv2d(const complex_t* in, complex_t* out, + int stride, int nstrides) const = 0; + + }; +} + + +#endif diff --git a/iface/src/IDFT.cxx b/iface/src/IDFT.cxx new file mode 100644 index 000000000..2f9ee3543 --- /dev/null +++ b/iface/src/IDFT.cxx @@ -0,0 +1,25 @@ +#include "WireCellIface/IDFT.h" + +using namespace WireCell; + +IDFT::~IDFT() {} + +// Trivial default "batched" implementations. If your concrete +// implementation provides some kind of "batch optimization", such as +// with some GPU FFTs, override these methods! 
+ +void IDFT::fwd1b(const complex_t* in, complex_t* out, + int stride, int nstrides) const +{ + for (int istride=0; istride diff --git a/util/src/Waveform.cxx b/util/src/Waveform.cxx index 6cb356f91..21ce70fa3 100644 --- a/util/src/Waveform.cxx +++ b/util/src/Waveform.cxx @@ -69,6 +69,22 @@ Waveform::realseq_t WireCell::Waveform::phase(const Waveform::compseq_t& seq) return c2r(seq, [](Waveform::complex_t c) { return std::arg(c); }); } + +Waveform::compseq_t Waveform::complex(const Waveform::realseq_t& real) +{ + Waveform::realseq_t imag(real.size(), 0); + return Waveform::complex(real, imag); +} + +Waveform::compseq_t Waveform::complex(const Waveform::realseq_t& real, const Waveform::realseq_t& imag) +{ + Waveform::compseq_t ret(real.size()); + std::transform(real.begin(), real.end(), imag.begin(), ret.begin(), + [](real_t re, real_t im) { return Waveform::complex_t(re,im); } ); + return ret; +} + + Waveform::real_t WireCell::Waveform::median(Waveform::realseq_t& wave) { return percentile(wave, 0.5); } Waveform::real_t WireCell::Waveform::median_binned(Waveform::realseq_t& wave) { return percentile_binned(wave, 0.5); } From 59be25c7eda82233a7672748fefc3c4d059be833 Mon Sep 17 00:00:00 2001 From: Brett Viren Date: Tue, 2 Nov 2021 12:10:17 -0400 Subject: [PATCH 02/46] Tell Boost to shut up with the internal deprecation warnings --- util/inc/WireCellUtil/IndexedGraph.h | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/util/inc/WireCellUtil/IndexedGraph.h b/util/inc/WireCellUtil/IndexedGraph.h index 6c6eb54d3..35c8513ed 100644 --- a/util/inc/WireCellUtil/IndexedGraph.h +++ b/util/inc/WireCellUtil/IndexedGraph.h @@ -13,6 +13,18 @@ #ifndef WIRECELL_INDEXEDGRAPH #define WIRECELL_INDEXEDGRAPH +// fixme: watchme: Boost started to deprecate some internal header +// inclusion which is not, as best as I can tell, any of our problem. 
+// The message is: +// +// ../../../../../opt/boost-1-76-0/include/boost/config/pragma_message.hpp:24:34: note: ‘#pragma message: This header is deprecated. Use instead.’ +// +// This arises from a deeply nested #include well beyond anything +// which is obvious here. +// +// If/when this is cleaned up in Boost, remove this comment and the +// next line. +#define BOOST_ALLOW_DEPRECATED_HEADERS 1 #include #include #include From a9787bb9d858c2665429b2ae095576097cbade8f Mon Sep 17 00:00:00 2001 From: Brett Viren Date: Tue, 2 Nov 2021 12:50:57 -0400 Subject: [PATCH 03/46] Tell Boost to shut up with the internal deprecation warnings --- util/inc/WireCellUtil/String.h | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/util/inc/WireCellUtil/String.h b/util/inc/WireCellUtil/String.h index ac06e6eb1..d065a5964 100644 --- a/util/inc/WireCellUtil/String.h +++ b/util/inc/WireCellUtil/String.h @@ -1,6 +1,18 @@ #ifndef WIRECELLUTIL_STRING #define WIRECELLUTIL_STRING +// fixme: watchme: Boost started to deprecate some internal header +// inclusion which is not, as best as I can tell, any of our problem. +// The message is: +// +// ../../../../../opt/boost-1-76-0/include/boost/config/pragma_message.hpp:24:34: note: ‘#pragma message: This header is deprecated. Use instead.’ +// +// This arises from a deeply nested #include well beyond anything +// which is obvious here. +// +// If/when this is cleaned up in Boost, remove this comment and the +// next line. 
+#define BOOST_ALLOW_DEPRECATED_HEADERS 1 #include #include From 08815ba37658625f40c46c2b2f2a605466a12d95 Mon Sep 17 00:00:00 2001 From: Brett Viren Date: Tue, 2 Nov 2021 12:51:22 -0400 Subject: [PATCH 04/46] Add stack trace to exception what() --- util/inc/WireCellUtil/Exceptions.h | 23 +++++++++++++++++++++-- util/test/test_exceptions.cxx | 3 ++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/util/inc/WireCellUtil/Exceptions.h b/util/inc/WireCellUtil/Exceptions.h index 955ab62c9..54feb39cb 100644 --- a/util/inc/WireCellUtil/Exceptions.h +++ b/util/inc/WireCellUtil/Exceptions.h @@ -19,17 +19,36 @@ #define WIRECELL_EXCEPTIONS #include +#include #include #include -#define THROW(e) BOOST_THROW_EXCEPTION(e) +using stack_traced_t = boost::error_info; +// template +// void throw_with_trace(const E& e) { +// BOOST_THROW_EXCEPTION(boost::enable_error_info(e) << stack_traced_t(boost::stacktrace::stacktrace())); +// } +// #define THROW(e) throw_with_trace(e) +#define THROW(e) BOOST_THROW_EXCEPTION(boost::enable_error_info(e) << stack_traced_t(boost::stacktrace::stacktrace())) +//#define THROW(e) BOOST_THROW_EXCEPTION(e) #define errstr(e) boost::diagnostic_information(e) + namespace WireCell { + // Get the stacktrace as an object. You must test for non-nullptr. + // Or, just rely on e.what(). + inline + const boost::stacktrace::stacktrace* stacktrace(const std::exception& e) { + return boost::get_error_info(e); + } + + /// The base wire cell exception. struct Exception : virtual public std::exception, virtual boost::exception { - char const *what() const throw() { return diagnostic_information_what(*this); } + char const *what() const throw() { + return diagnostic_information_what(*this); + } }; /// Thrown when a wrong value has been encountered. 
diff --git a/util/test/test_exceptions.cxx b/util/test/test_exceptions.cxx index 11d4553fb..a79b44321 100644 --- a/util/test/test_exceptions.cxx +++ b/util/test/test_exceptions.cxx @@ -22,6 +22,7 @@ int main() THROW(ValueError() << errmsg{format("some error with value=%d msg=\"%s\"", value, omg)}); } catch (ValueError& e) { - cerr << "caught ValueError: " << errstr(e) << endl; + cerr << "Caught:\n"; + cerr << e.what() << "\n"; } } From a7b6f97bb91066d9a2fd42e3d48859e1776a3bde Mon Sep 17 00:00:00 2001 From: Brett Viren Date: Tue, 2 Nov 2021 12:52:11 -0400 Subject: [PATCH 05/46] Throw instead of returning garbage on garbage input. This can only be an improvement as the garbage return value was being ignored by all callers. --- util/src/Waveform.cxx | 5 ++-- util/test/test_issue24.cxx | 55 +++++++++++++++++++++++++++++++++++--- 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/util/src/Waveform.cxx b/util/src/Waveform.cxx index 21ce70fa3..7817740e5 100644 --- a/util/src/Waveform.cxx +++ b/util/src/Waveform.cxx @@ -1,4 +1,5 @@ #include "WireCellUtil/Waveform.h" +#include "WireCellUtil/Exceptions.h" #include @@ -92,11 +93,11 @@ Waveform::real_t WireCell::Waveform::median_binned(Waveform::realseq_t& wave) { Waveform::real_t WireCell::Waveform::percentile(Waveform::realseq_t& wave, real_t percentage) { if (percentage < 0.0 or percentage > 1.0) { - return -9999; + THROW(ValueError() << errmsg{"percentage out of range"}); } const size_t siz = wave.size(); if (siz == 0) { - return -9999; + THROW(ValueError() << errmsg{"empty waveform"}); } if (siz == 1) { return wave[0]; diff --git a/util/test/test_issue24.cxx b/util/test/test_issue24.cxx index 7f67778be..d7ce8c9f5 100644 --- a/util/test/test_issue24.cxx +++ b/util/test/test_issue24.cxx @@ -1,25 +1,72 @@ #include "WireCellUtil/Waveform.h" +#include "WireCellUtil/Exceptions.h" + #include using namespace std; using namespace WireCell::Waveform; +using namespace WireCell; int main() { int nsamples = 10; - while 
(nsamples >= 0) { + while (nsamples > 0) { realseq_t wave(nsamples, 0); median(wave); --nsamples; } + cerr << "Testing error handling\n"; realseq_t wave; - assert(-9999 == median(wave)); + bool okay = false; + try { + median(wave); + } + catch (ValueError& err) { + okay = true; + cerr << "Caught:\n" << err.what() << "\nOKAY\n"; + } + catch (std::exception& err) { + cerr << "Why am I here?\n"; + cerr << err.what() << "\n"; + } + if (!okay) { + cerr << "median of empty wave should throw\n"; + } + assert(okay); + cerr << "thrown and caught empty waveform\n"; + wave.push_back(6.9); wave.push_back(9.6); - assert(-9999 == percentile(wave, -0.1)); - assert(-9999 == percentile(wave, 1.1)); + okay = false; + try { + percentile(wave, -0.1); + } + catch (ValueError& err) { + okay = true; + cerr << "Caught:\n" << err.what() << "\nOKAY\n"; + } + if (!okay) { + cerr << "median under percentage should throw\n"; + } + assert(okay); + cerr << "thrown and caught median under percentage\n"; + + okay = false; + try { + percentile(wave, 1.1); + } + catch (ValueError& err) { + okay = true; + cerr << "Caught:\n" << err.what() << "\nOKAY\n"; + } + if (!okay) { + cerr << "median over percentage should throw\n"; + } + assert(okay); + cerr << "thrown and caught median over percentage\n"; + cerr << median(wave) << endl; assert(std::abs(9.6 - median(wave)) < 0.001); wave.push_back(0.0); From dd4dffb55c7f922d1cdc0e26292135affe312cfb Mon Sep 17 00:00:00 2001 From: Brett Viren Date: Tue, 2 Nov 2021 16:04:08 -0400 Subject: [PATCH 06/46] Start on higher-level dft functions --- aux/inc/WireCellAux/DftTools.h | 79 ++++++++++++++++++++ aux/inc/WireCellAux/FftwDFT.h | 8 +- aux/src/DftTools.cxx | 82 +++++++++++++++++++++ aux/src/FftwDFT.cxx | 14 ++++ aux/test/test_dfttools.cxx | 28 +++++++ aux/test/test_idft.cxx | 122 +++++++++++++++++++++++++++++++ util/inc/WireCellUtil/Waveform.h | 3 +- util/test/test_complex.cxx | 36 +++++++++ 8 files changed, 367 insertions(+), 5 deletions(-) create mode 100644 
aux/inc/WireCellAux/DftTools.h create mode 100644 aux/src/DftTools.cxx create mode 100644 aux/test/test_dfttools.cxx create mode 100644 aux/test/test_idft.cxx create mode 100644 util/test/test_complex.cxx diff --git a/aux/inc/WireCellAux/DftTools.h b/aux/inc/WireCellAux/DftTools.h new file mode 100644 index 000000000..24e097a43 --- /dev/null +++ b/aux/inc/WireCellAux/DftTools.h @@ -0,0 +1,79 @@ +/** + High level functions related to DFTs. + + Most take an IDFT::pointer to a DFT implementation and return an + allocated result. Use IDFT directly to control allocation. + + There are std::vector and Eigen array functions. + + Abbreviations: + + - IS is interval space aka time / distance + - FS is frequency space aka frequency / periodicity + + Price to pay for simple API is a lack of optimizations: + + - When a real valued array is invovled, all arrays are full size. + That is, no half-size optimization will be exposed to the caller. + + - These functions tend to make more copies than may be needed if + IDFT is called directly. In addition to real/complex conversion, + using std::vector or Eigen array instead of raw memory leads to + more copies. + */ + +#ifndef WIRECELL_AUX_DFTTOOLS +#define WIRECELL_AUX_DFTTOOLS + +#include "WireCellIface/IDFT.h" +#include +#include + +namespace WireCell::Aux { + + using real_t = IDFT::scalar_t; + using complex_t = IDFT::complex_t; + + // std::vector based functions + + using realvec_t = std::vector; + using compvec_t = std::vector; + + // 1D with vectors + + // Transform a real IS, return same size FS. 
+ compvec_t dft(IDFT::pointer dft, const realvec_t& seq); + + // Transform complex FS to IS and return real part + realvec_t idft(IDFT::pointer dft, const compvec_t& spec); + + compvec_t r2c(const realvec_t& r); + realvec_t c2r(const compvec_t& c); + + + // Eigen array based functions + + /// Real 1D array + using array_xf = Eigen::ArrayXf; + + /// Complex 1D array + using array_xc = Eigen::ArrayXcf; + + /// A real, 2D array + using array_xxf = Eigen::ArrayXXf; + + /// A complex, 2D array + using array_xxc = Eigen::ArrayXXcf; + + // 2D with Eigen arrays + + // Transform a real IS, return same size FS. + array_xxc dft(IDFT::pointer dft, const array_xxf& arr); + + // Transform complex FS to IS and return real part + array_xxf idft(IDFT::pointer dft, const array_xxc& arr); + + +} + +#endif diff --git a/aux/inc/WireCellAux/FftwDFT.h b/aux/inc/WireCellAux/FftwDFT.h index ad265eef7..62cea9a84 100644 --- a/aux/inc/WireCellAux/FftwDFT.h +++ b/aux/inc/WireCellAux/FftwDFT.h @@ -18,11 +18,11 @@ namespace WireCell::Aux { virtual void fwd1d(const complex_t* in, complex_t* out, - int stride) const = 0; + int stride) const; virtual void inv1d(const complex_t* in, complex_t* out, - int stride) const = 0; + int stride) const; // batched 1D ("1b") - rely on base implementation @@ -30,10 +30,10 @@ namespace WireCell::Aux { virtual void fwd2d(const complex_t* in, complex_t* out, - int stride, int nstrides) const = 0; + int stride, int nstrides) const; virtual void inv2d(const complex_t* in, complex_t* out, - int stride, int nstrides) const = 0; + int stride, int nstrides) const; }; diff --git a/aux/src/DftTools.cxx b/aux/src/DftTools.cxx new file mode 100644 index 000000000..b7ec12cb5 --- /dev/null +++ b/aux/src/DftTools.cxx @@ -0,0 +1,82 @@ +#include "WireCellAux/DftTools.h" + +using namespace WireCell; +using namespace WireCell::Aux; + +compvec_t Aux::r2c(const realvec_t& r) +{ + compvec_t cret(r.size()); + std::transform(r.begin(), r.end(), cret.begin(), + [](const real_t& r) { 
return complex_t(r, 0); }); + return cret; +} +realvec_t Aux::c2r(const compvec_t& c) +{ + realvec_t rret(c.size()); + std::transform(c.begin(), c.end(), rret.begin(), + [](const complex_t& c) { return std::real(c); }); + return rret; +} + +// Transform a real IS, return same size FS. +compvec_t Aux::dft(IDFT::pointer dft, const realvec_t& seq) +{ + compvec_t cseq = Aux::r2c(seq); + compvec_t cret(cseq.size()); + dft->fwd1d(cseq.data(), cret.data(), cret.size()); + return cret; +} + +// Transform complex FS to IS and return real part +realvec_t Aux::idft(IDFT::pointer dft, const compvec_t& spec) +{ + compvec_t cret(spec.size()); + dft->inv1d(spec.data(), cret.data(), cret.size()); + return Aux::c2r(cret); +} + +using array_xxf_rm = Eigen::Array; +using array_xxc_rm = Eigen::Array; + + +// Transform a real IS, return same size FS. +array_xxc Aux::dft(IDFT::pointer trans, const array_xxf& arr) +{ + int stride = arr.rows(); + int nstrides = arr.cols(); + array_xxc ret(stride, nstrides); + + if (!arr.IsRowMajor) { + stride = arr.cols(); + nstrides = arr.rows(); + } + + size_t size = stride*nstrides; + compvec_t carr(size); + std::transform(arr.data(), arr.data()+size, carr.begin(), + [](const real_t& r) { return complex_t(r,0); }); + + trans->fwd2d(carr.data(), ret.data(), stride, nstrides); + return ret; +} + +// Transform complex FS to IS and return real part +array_xxf Aux::idft(IDFT::pointer trans, const array_xxc& arr) +{ + int stride = arr.rows(); + int nstrides = arr.cols(); + array_xxf ret(stride, nstrides); + + if (!arr.IsRowMajor) { + stride = arr.cols(); + nstrides = arr.rows(); + } + + size_t size = stride*nstrides; + compvec_t cret(size); + trans->inv2d(arr.data(), cret.data(), stride, nstrides); + + std::transform(cret.begin(), cret.end(), ret.data(), + [](const complex_t& c) { return std::real(c); }); + return ret; +} diff --git a/aux/src/FftwDFT.cxx b/aux/src/FftwDFT.cxx index 85f6b0820..2a655f5d6 100644 --- a/aux/src/FftwDFT.cxx +++ 
b/aux/src/FftwDFT.cxx @@ -1,9 +1,13 @@ #include "WireCellAux/FftwDFT.h" +#include "WireCellUtil/NamedFactory.h" + #include #include #include #include +WIRECELL_FACTORY(FftwDFT, WireCell::Aux::FftwDFT, WireCell::IDFT) + using namespace WireCell; @@ -33,6 +37,8 @@ plan_type get_plan(std::shared_mutex& mutex, plan_map_t& plans, plan_key_t key) } +// #include // debugging + template void doit(std::shared_mutex& mutex, plan_map_t& plans, int fwdrev, plan_val_t* src, plan_val_t* dst, int stride, int nstrides, @@ -45,6 +51,7 @@ void doit(std::shared_mutex& mutex, plan_map_t& plans, // Check again in case another thread snakes us. auto it = plans.find(key); if (it == plans.end()) { + //std::cerr << "make plan for " << key << std::endl; plan = make_plan(); plans[key] = plan; } @@ -111,3 +118,10 @@ void Aux::FftwDFT::inv2d(const complex_t* in, complex_t* out, int stride, int ns return fftwf_plan_dft_2d(stride, nstrides, src, dst, dir, FFTW_ESTIMATE|FFTW_PRESERVE_INPUT); }); } +Aux::FftwDFT::FftwDFT() +{ +} +Aux::FftwDFT::~FftwDFT() +{ +} + diff --git a/aux/test/test_dfttools.cxx b/aux/test/test_dfttools.cxx new file mode 100644 index 000000000..0ef2eaf22 --- /dev/null +++ b/aux/test/test_dfttools.cxx @@ -0,0 +1,28 @@ +#include "WireCellAux/DftTools.h" +#include "WireCellAux/FftwDFT.h" + +#include +#include + +using namespace WireCell; +using namespace WireCell::Aux; + +void test_1d_imp(IDFT::pointer trans) +{ + realvec_t rimp(64, 0); + rimp[1] = 1.0; + auto cimp = dft(trans, rimp); + for (auto c : cimp) { + std::cerr << c << " "; + } + std::cerr << "\n"; +} + +int main() +{ + auto trans = std::make_shared(); + + test_1d_imp(trans); + + return 0; +} diff --git a/aux/test/test_idft.cxx b/aux/test/test_idft.cxx new file mode 100644 index 000000000..97f85b14d --- /dev/null +++ b/aux/test/test_idft.cxx @@ -0,0 +1,122 @@ +// Test IDFT implementations. 
+#include "WireCellUtil/NamedFactory.h" +#include "WireCellUtil/Waveform.h" +#include "WireCellUtil/PluginManager.h" +#include "WireCellIface/IConfigurable.h" +#include "WireCellIface/IDFT.h" + +#include +#include +#include +#include + +using namespace WireCell; + + +static +void test_1d_zero(IDFT::pointer dft, int size = 1024) +{ + std::vector inter(size,0), freq(size,0); + + dft->fwd1d(inter.data(), freq.data(), inter.size()); + dft->inv1d(freq.data(), inter.data(), freq.size()); + + auto tot = Waveform::sum(inter); + assert(std::real(tot) == 0); +} +static +void test_2d_zero(IDFT::pointer dft, int size = 1024) +{ + int stride=size, nstrides=size; + std::vector inter(stride*nstrides,0); + std::vector freq(stride*nstrides,0); + + dft->fwd2d(inter.data(), freq.data(), stride, nstrides); + dft->inv2d(freq.data(), inter.data(), stride, nstrides); + + auto tot = Waveform::sum(inter); + assert(std::real(tot) == 0); +} + +void fwdrev(IDFT::pointer dft, int id, int ntimes, int size) +{ + int stride=size, nstrides=size; + std::vector inter(stride*nstrides,0); + std::vector freq(stride*nstrides,0); + + // std::cerr << "running " << id << std::endl; + + while (ntimes) { + //std::cerr << ntimes << "\n"; + dft->fwd2d(inter.data(), freq.data(), stride, nstrides); + dft->inv2d(freq.data(), inter.data(), stride, nstrides); + + --ntimes; + auto tot = Waveform::sum(inter); + assert(std::real(tot) == 0); + } + //std::cerr << "finished " << id << std::endl; +} + +static +void test_2d_threads(IDFT::pointer dft, int nthreads, int nloops, int size = 1024) +{ + using namespace std::chrono; + + steady_clock::time_point t1 = steady_clock::now(); + + std::vector workers; + + //std::cerr << "Starting workers\n"; + for (int ind=0; ind dt1 = duration_cast>(t2 - t1); + std::cerr << "ndfts: " << nthreads*nloops + << " " << nthreads << " " << nloops + << " " << dt1.count() << std::endl; +} + +int main(int argc, char* argv[]) +{ + // fixme, add CLI parsing to add plugins, config and name another 
+ // dft. For now, just use the one in aux. + PluginManager& pm = PluginManager::instance(); + pm.add("WireCellAux"); + std::string dft_tn = "FftwDFT"; + + // creates + auto idft = Factory::lookup_tn(dft_tn); + assert(idft); + { // configure before use if configurable + auto icfg = Factory::find_maybe_tn(dft_tn); + if (icfg) { + auto cfg = icfg->default_configuration(); + icfg->configure(cfg); + } + } + + test_1d_zero(idft); + test_2d_zero(idft); + + std::vector sizes = {128,256,512,1024}; + for (auto size : sizes) { + int ndouble=3, ntot=2*16384/size; + while (ndouble) { + int nthread = 1< #include -// for FFT +// FIXME: remove the hard-wired Eigen::FFT related in favor of dynamic +// DFT with Aux/DftTools.h. #include #include diff --git a/util/test/test_complex.cxx b/util/test/test_complex.cxx new file mode 100644 index 000000000..877138020 --- /dev/null +++ b/util/test/test_complex.cxx @@ -0,0 +1,36 @@ +#include +#include +#include + +int main() +{ + // note: this compiles but doesn't do what you may expect. + // complex numbers are a 2-array of doubles: [r,i] so the + // reinterpret_cast from complex to double gives an "interleaved" + // array of [r0,i0,r1,i1]. Likewise from double to complex gives + // a "complex" number of [r0 ,r1]. 
+ + using complex_t = std::complex; + using cvec = std::vector; + using dvec = std::vector; + + cvec c1{{0,0}, {1,1}}; + dvec d1={0,1}; + + complex_t* c2 = reinterpret_cast(d1.data()); + cvec c3(c2, c2+2); + + double* d2 = reinterpret_cast(c1.data()); + dvec d3(d2, d2+2); + + for (auto c : c3) { + std::cerr << c << " "; + } + std::cerr << "\n"; + for (auto d : d3) { + std::cerr << d << " "; + } + std::cerr << "\n"; + + return 0; +} From dbcbe9bf681adf836c3adec033162e36b0d95822 Mon Sep 17 00:00:00 2001 From: Brett Viren Date: Fri, 5 Nov 2021 18:33:06 -0400 Subject: [PATCH 07/46] Allow to optionally provide data and metadata in constructor --- aux/inc/WireCellAux/SimpleTensor.h | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/aux/inc/WireCellAux/SimpleTensor.h b/aux/inc/WireCellAux/SimpleTensor.h index 56c3ab905..7b267bdbe 100644 --- a/aux/inc/WireCellAux/SimpleTensor.h +++ b/aux/inc/WireCellAux/SimpleTensor.h @@ -2,7 +2,9 @@ #define WIRECELL_AUX_SIMPLETENSOR #include "WireCellIface/ITensor.h" + #include +#include namespace WireCell { @@ -13,14 +15,26 @@ namespace WireCell { public: typedef ElementType element_t; - SimpleTensor(const shape_t& shape) + // Create simple tensor, allocating space for data. If + // data given it must have at least as many elements as + // implied by shape and that span will be copied into + // allocated memory. 
+ SimpleTensor(const shape_t& shape, + const element_t* data=nullptr, + const Configuration& md = Configuration()) { size_t nbytes = element_size(); - for (const auto& s : shape) { + m_shape = shape; + for (const auto& s : m_shape) { nbytes *= s; } - m_store.resize(nbytes); - m_shape = shape; + if (data) { + const std::byte* bytes = reinterpret_cast(data); + m_store.assign(bytes, bytes+nbytes); + } + else { + m_store.resize(nbytes); + } } virtual ~SimpleTensor() {} From 2ffbb84e97aeaaeebf3c0a9d2c037864136cded7 Mon Sep 17 00:00:00 2001 From: Brett Viren Date: Fri, 5 Nov 2021 18:33:23 -0400 Subject: [PATCH 08/46] More work toward DFT as a service --- aux/inc/WireCellAux/DftTools.h | 80 ++++++---------- aux/inc/WireCellAux/FftwDFT.h | 4 +- aux/inc/WireCellAux/TensorTools.h | 77 +++++++++++++++ aux/src/DftTools.cxx | 103 ++++++++------------ aux/src/FftwDFT.cxx | 25 +++-- aux/test/test_dfttools.cxx | 103 ++++++++++++++++++-- aux/test/test_idft.cxx | 8 +- aux/test/test_tensor_tools.cxx | 102 ++++++++++++++++++++ iface/inc/WireCellIface/IDFT.h | 22 ++++- iface/inc/WireCellIface/ITensor.h | 7 +- iface/src/IDFT.cxx | 4 +- util/test/test_eigen_cast.cxx | 45 +++++++++ util/test/test_eigen_rowcol.cxx | 150 ++++++++++++++++++++++++++++++ 13 files changed, 590 insertions(+), 140 deletions(-) create mode 100644 aux/inc/WireCellAux/TensorTools.h create mode 100644 aux/test/test_tensor_tools.cxx create mode 100644 util/test/test_eigen_cast.cxx create mode 100644 util/test/test_eigen_rowcol.cxx diff --git a/aux/inc/WireCellAux/DftTools.h b/aux/inc/WireCellAux/DftTools.h index 24e097a43..10f2d8cbf 100644 --- a/aux/inc/WireCellAux/DftTools.h +++ b/aux/inc/WireCellAux/DftTools.h @@ -1,25 +1,6 @@ /** - High level functions related to DFTs. - - Most take an IDFT::pointer to a DFT implementation and return an - allocated result. Use IDFT directly to control allocation. - - There are std::vector and Eigen array functions. 
- - Abbreviations: - - - IS is interval space aka time / distance - - FS is frequency space aka frequency / periodicity - - Price to pay for simple API is a lack of optimizations: - - - When a real valued array is invovled, all arrays are full size. - That is, no half-size optimization will be exposed to the caller. - - - These functions tend to make more copies than may be needed if - IDFT is called directly. In addition to real/complex conversion, - using std::vector or Eigen array instead of raw memory leads to - more copies. + This provides std::vector and Eigen::Array typed interface to an + IDFT. */ #ifndef WIRECELL_AUX_DFTTOOLS @@ -31,47 +12,46 @@ namespace WireCell::Aux { - using real_t = IDFT::scalar_t; using complex_t = IDFT::complex_t; // std::vector based functions - using realvec_t = std::vector; - using compvec_t = std::vector; + using dft_vector_t = std::vector; // 1D with vectors - // Transform a real IS, return same size FS. - compvec_t dft(IDFT::pointer dft, const realvec_t& seq); - - // Transform complex FS to IS and return real part - realvec_t idft(IDFT::pointer dft, const compvec_t& spec); - - compvec_t r2c(const realvec_t& r); - realvec_t c2r(const compvec_t& c); + inline dft_vector_t fwd(IDFT::pointer dft, const dft_vector_t& seq) + { + dft_vector_t ret(seq.size()); + dft->fwd1d(seq.data(), ret.data(), ret.size()); + return ret; + } + inline dft_vector_t inv(IDFT::pointer dft, const dft_vector_t& spec) + { + dft_vector_t ret(spec.size()); + dft->inv1d(spec.data(), ret.data(), ret.size()); + return ret; + } // Eigen array based functions - /// Real 1D array - using array_xf = Eigen::ArrayXf; - - /// Complex 1D array - using array_xc = Eigen::ArrayXcf; - - /// A real, 2D array - using array_xxf = Eigen::ArrayXXf; - - /// A complex, 2D array - using array_xxc = Eigen::ArrayXXcf; + /// A complex, 2D array. Use Array::cast() if you need to + /// convert to/from real. 
+ using dft_array_t = Eigen::ArrayXXcf; - // 2D with Eigen arrays - - // Transform a real IS, return same size FS. - array_xxc dft(IDFT::pointer dft, const array_xxf& arr); - - // Transform complex FS to IS and return real part - array_xxf idft(IDFT::pointer dft, const array_xxc& arr); + // 2D with Eigen arrays. Use eg arr.cast() to provde + // from real or arr.cast() to convert result to real. + + // Transform both dimesions. + dft_array_t fwd(IDFT::pointer dft, const dft_array_t& arr); + dft_array_t inv(IDFT::pointer dft, const dft_array_t& arr); + + // Transform one dimesions. For example axis=0 transforms each + // logical row of the Eigen array so that column=0 of each row + // would hold the frequency=0 component of each row's spectrum. + // array_xxc fwd(IDFT::pointer dft, const array_xxc& arr, int axis); + // array_xxc inv(IDFT::pointer dft, const array_xxc& arr, int axis); } diff --git a/aux/inc/WireCellAux/FftwDFT.h b/aux/inc/WireCellAux/FftwDFT.h index 62cea9a84..365190f0c 100644 --- a/aux/inc/WireCellAux/FftwDFT.h +++ b/aux/inc/WireCellAux/FftwDFT.h @@ -30,10 +30,10 @@ namespace WireCell::Aux { virtual void fwd2d(const complex_t* in, complex_t* out, - int stride, int nstrides) const; + int nstrides, int stride) const; virtual void inv2d(const complex_t* in, complex_t* out, - int stride, int nstrides) const; + int nstrides, int stride) const; }; diff --git a/aux/inc/WireCellAux/TensorTools.h b/aux/inc/WireCellAux/TensorTools.h new file mode 100644 index 000000000..6765b457e --- /dev/null +++ b/aux/inc/WireCellAux/TensorTools.h @@ -0,0 +1,77 @@ +#ifndef WIRECELL_AUX_TENSORTOOLS +#define WIRECELL_AUX_TENSORTOOLS + +#include "WireCellIface/ITensor.h" +#include "WireCellIface/IDFT.h" +#include "WireCellUtil/Exceptions.h" + +#include +#include + +namespace WireCell::Aux { + + bool is_row_major(const ITensor::pointer& ten) { + if (ten->order().empty() or ten->order()[0] == 1) { + return true; + } + return false; + } + + template + bool is_type(const 
ITensor::pointer& ten) { + return (ten->element_type() == typeid(scalar_t)); + } + + + // Extract the underlying data array from the tensor as a vector. + // Caution: this ignores storage order hints and 1D or 2D will be + // flattened assuming C-ordering, aka row-major (if 2D). It + // throws ValueError on type mismatch. + template + std::vector asvec(const ITensor::pointer& ten) + { + if (ten->element_type() != typeid(element_type)) { + THROW(ValueError() << errmsg{"element type mismatch"}); + } + const element_type* data = (const element_type*)ten->data(); + const size_t nelems = ten->size()/sizeof(element_type); + return std::vector(data, data+nelems); + } + + // Extract the tensor data as an Eigen array. + template + Eigen::Array // this default is column-wise + asarray(const ITensor::pointer& tens) + { + if (tens->element_type() != typeid(element_type)) { + THROW(ValueError() << errmsg{"element type mismatch"}); + } + using ROWM = Eigen::Array; + using COLM = Eigen::Array; + + auto shape = tens->shape(); + int nrows, ncols; + if (shape.size() == 1) { + nrows = 1; + ncols = shape[0]; + } + else { + nrows = shape[0]; + ncols = shape[1]; + } + + // Eigen::Map is a non-const view of data but a copy happens + // on return. We need to temporarily break const correctness. 
+ const element_type* cdata = reinterpret_cast(tens->data()); + element_type* mdata = const_cast(cdata); + + if (is_row_major(tens)) { + return Eigen::Map(mdata, nrows, ncols); + } + // column-major + return Eigen::Map(mdata, nrows, ncols); + } + +} + +#endif diff --git a/aux/src/DftTools.cxx b/aux/src/DftTools.cxx index b7ec12cb5..5ef53e166 100644 --- a/aux/src/DftTools.cxx +++ b/aux/src/DftTools.cxx @@ -3,80 +3,55 @@ using namespace WireCell; using namespace WireCell::Aux; -compvec_t Aux::r2c(const realvec_t& r) -{ - compvec_t cret(r.size()); - std::transform(r.begin(), r.end(), cret.begin(), - [](const real_t& r) { return complex_t(r, 0); }); - return cret; -} -realvec_t Aux::c2r(const compvec_t& c) -{ - realvec_t rret(c.size()); - std::transform(c.begin(), c.end(), rret.begin(), - [](const complex_t& c) { return std::real(c); }); - return rret; -} - -// Transform a real IS, return same size FS. -compvec_t Aux::dft(IDFT::pointer dft, const realvec_t& seq) -{ - compvec_t cseq = Aux::r2c(seq); - compvec_t cret(cseq.size()); - dft->fwd1d(cseq.data(), cret.data(), cret.size()); - return cret; -} - -// Transform complex FS to IS and return real part -realvec_t Aux::idft(IDFT::pointer dft, const compvec_t& spec) -{ - compvec_t cret(spec.size()); - dft->inv1d(spec.data(), cret.data(), cret.size()); - return Aux::c2r(cret); -} - -using array_xxf_rm = Eigen::Array; -using array_xxc_rm = Eigen::Array; - - -// Transform a real IS, return same size FS. -array_xxc Aux::dft(IDFT::pointer trans, const array_xxf& arr) +/* + Big fat warning to future me: Passing by reference means the input + array may carry the .IsRowMajor optimization for implementing + transpose(). An extra copy would remove that complication but this + interface tries to keep it. 
+ */ + +using ROWM = Eigen::Array; +using COLM = Eigen::Array; + +template +Aux::dft_array_t doit(const Aux::dft_array_t& arr, trans func) { + // Nominally, memory is in column-major order + const Aux::complex_t* in_data = arr.data(); int stride = arr.rows(); int nstrides = arr.cols(); - array_xxc ret(stride, nstrides); - if (!arr.IsRowMajor) { + // except when it isn't + bool flipped = arr.IsRowMajor; + if (flipped) { stride = arr.cols(); nstrides = arr.rows(); } - size_t size = stride*nstrides; - compvec_t carr(size); - std::transform(arr.data(), arr.data()+size, carr.begin(), - [](const real_t& r) { return complex_t(r,0); }); - - trans->fwd2d(carr.data(), ret.data(), stride, nstrides); - return ret; -} - -// Transform complex FS to IS and return real part -array_xxf Aux::idft(IDFT::pointer trans, const array_xxc& arr) -{ - int stride = arr.rows(); - int nstrides = arr.cols(); - array_xxf ret(stride, nstrides); + Aux::dft_vector_t out_vec(nstrides*stride); + func(in_data, out_vec.data(), nstrides, stride); - if (!arr.IsRowMajor) { - stride = arr.cols(); - nstrides = arr.rows(); + if (flipped) { + return Eigen::Map(out_vec.data(), arr.rows(), arr.cols()); } + return Eigen::Map(out_vec.data(), arr.rows(), arr.cols()); + +} - size_t size = stride*nstrides; - compvec_t cret(size); - trans->inv2d(arr.data(), cret.data(), stride, nstrides); +Aux::dft_array_t Aux::fwd(IDFT::pointer dft, const Aux::dft_array_t& arr) +{ + return doit(arr, [&](const complex_t* in_data, + complex_t* out_data, + int nstrides, int stride) { + dft->fwd2d(in_data, out_data, nstrides, stride); + }); +} - std::transform(cret.begin(), cret.end(), ret.data(), - [](const complex_t& c) { return std::real(c); }); - return ret; +Aux::dft_array_t Aux::inv(IDFT::pointer dft, const Aux::dft_array_t& arr) +{ + return doit(arr, [&](const complex_t* in_data, + complex_t* out_data, + int nstrides, int stride) { + dft->inv2d(in_data, out_data, nstrides, stride); + }); } diff --git a/aux/src/FftwDFT.cxx 
b/aux/src/FftwDFT.cxx index 2a655f5d6..f4a214760 100644 --- a/aux/src/FftwDFT.cxx +++ b/aux/src/FftwDFT.cxx @@ -41,7 +41,7 @@ plan_type get_plan(std::shared_mutex& mutex, plan_map_t& plans, plan_key_t key) template void doit(std::shared_mutex& mutex, plan_map_t& plans, - int fwdrev, plan_val_t* src, plan_val_t* dst, int stride, int nstrides, + int fwdrev, plan_val_t* src, plan_val_t* dst, int nstrides, int stride, planner_function make_plan) { auto key = make_key(fwdrev == FFTW_BACKWARD, src, dst, stride, nstrides); @@ -77,7 +77,7 @@ void Aux::FftwDFT::fwd1d(const complex_t* in, complex_t* out, int stride) const static const int dir = FFTW_FORWARD; auto src = pval_cast(in); auto dst = pval_cast(out); - doit(mutex, plans, dir, src, dst, stride, 0, [&]( ) { + doit(mutex, plans, dir, src, dst, 0, stride, [&]( ) { return fftwf_plan_dft_1d(stride, src, dst, dir, FFTW_ESTIMATE|FFTW_PRESERVE_INPUT); }); } @@ -88,35 +88,46 @@ void Aux::FftwDFT::inv1d(const complex_t* in, complex_t* out, int stride) const static const int dir = FFTW_BACKWARD; auto src = pval_cast(in); auto dst = pval_cast(out); - doit(mutex, plans, dir, src, dst, stride, 0, [&]( ) { + doit(mutex, plans, dir, src, dst, 0, stride, [&]( ) { return fftwf_plan_dft_1d(stride, src, dst, dir, FFTW_ESTIMATE|FFTW_PRESERVE_INPUT); }); + + // reverse normalization + for (int ind=0; ind #include using namespace WireCell; -using namespace WireCell::Aux; -void test_1d_imp(IDFT::pointer trans) +using real_t = float; +using RV = std::vector; +using complex_t = std::complex; +using CV = std::vector; + +void test_1d(IDFT::pointer dft) { - realvec_t rimp(64, 0); + RV rimp(64, 0); rimp[1] = 1.0; - auto cimp = dft(trans, rimp); + + auto cimp = Aux::fwd(dft, Waveform::complex(rimp)); for (auto c : cimp) { std::cerr << c << " "; } std::cerr << "\n"; + + RV rimp2 = Waveform::real(Aux::inv(dft, cimp)); + for (auto r : rimp2) { + std::cerr << r << " "; + } + std::cerr << "\n"; + for (int ind=0; ind<64; ++ind) { + if (ind == 1) { + 
assert(std::abs(rimp2[ind]-1.0) < 1e-6); + continue; + } + assert(std::abs(rimp2[ind]) < 1e-6); + } +} + +using FA = Eigen::Array; + +void test_2d(IDFT::pointer dft) +{ + const int nrows=16; + const int ncols=8; + FA r = FA::Zero(nrows, ncols); + r(10,1) = 1.0; + std::cerr << r << std::endl; + auto c = Aux::fwd(dft, r.cast()); + std::cerr << c << std::endl; + FA r2 = Aux::inv(dft, c).real(); + std::cerr << r2 << std::endl; + for (int irow=0; irow +void dump(std::string name, const array_type& arr) +{ + std::cerr << name << ":(" << arr.rows() << "," << arr.cols() << ") row-major:" << arr.IsRowMajor << "\n"; +} + +void test_2d_transpose(IDFT::pointer dft) +{ + const int nrows=16; + const int ncols=8; + + FA r = FA::Zero(nrows, ncols); // shape:(16,8) + dump("r", r); + + // do not remove the auto in this next line + auto rt = r.transpose(); // shape:(8,16) + dump("rt", rt); + rt(1,10) = 1.0; + + auto c = Aux::fwd(dft, rt.cast()); + dump("c", c); + + auto r2 = Aux::inv(dft, c).real(); + dump("r2",r2); + + // transpose access + const int nrowst = r2.rows(); + const int ncolst = r2.cols(); + + for (int irow=0; irow(); + auto dft = std::make_shared(); - test_1d_imp(trans); + test_1d(dft); + test_2d(dft); + test_2d_transpose(dft); return 0; } diff --git a/aux/test/test_idft.cxx b/aux/test/test_idft.cxx index 97f85b14d..73bc95df2 100644 --- a/aux/test/test_idft.cxx +++ b/aux/test/test_idft.cxx @@ -31,8 +31,8 @@ void test_2d_zero(IDFT::pointer dft, int size = 1024) std::vector inter(stride*nstrides,0); std::vector freq(stride*nstrides,0); - dft->fwd2d(inter.data(), freq.data(), stride, nstrides); - dft->inv2d(freq.data(), inter.data(), stride, nstrides); + dft->fwd2d(inter.data(), freq.data(), nstrides, stride); + dft->inv2d(freq.data(), inter.data(), nstrides, stride); auto tot = Waveform::sum(inter); assert(std::real(tot) == 0); @@ -48,8 +48,8 @@ void fwdrev(IDFT::pointer dft, int id, int ntimes, int size) while (ntimes) { //std::cerr << ntimes << "\n"; - 
dft->fwd2d(inter.data(), freq.data(), stride, nstrides); - dft->inv2d(freq.data(), inter.data(), stride, nstrides); + dft->fwd2d(inter.data(), freq.data(), nstrides, stride); + dft->inv2d(freq.data(), inter.data(), nstrides, stride); --ntimes; auto tot = Waveform::sum(inter); diff --git a/aux/test/test_tensor_tools.cxx b/aux/test/test_tensor_tools.cxx new file mode 100644 index 000000000..933a777e7 --- /dev/null +++ b/aux/test/test_tensor_tools.cxx @@ -0,0 +1,102 @@ +#include "WireCellAux/TensorTools.h" +#include "WireCellAux/SimpleTensor.h" + + +#include +#include + +using real_t = float; +using RV = std::vector; +using complex_t = std::complex; +using CV = std::vector; +using RT = WireCell::Aux::SimpleTensor; +using CT = WireCell::Aux::SimpleTensor; + +// test fodder +const RV real_vector{0,1,2,3,4,5}; +const RV real_vector_cw{0,3,1,4,2,5}; +const CV complex_vector{{0,0},{1,1},{2,2},{3,3},{4,4},{5,5}}; +const WireCell::ITensor::shape_t shape{2,3}; + +using namespace WireCell; + +void test_is_type() +{ + auto rt = std::make_shared(shape, real_vector.data()); + assert (Aux::is_type(rt)); + assert (!Aux::is_type(rt)); +} + +void test_is_row_major() +{ + // ST actually does not let us do anything but C-order/row-major + auto rm = std::make_shared(shape, real_vector.data()); + assert(Aux::is_row_major(rm)); +} + +template +void assert_equal(const VectorType& v1, const VectorType& v2) +{ + assert(v1.size() == v2.size()); + for (size_t ind=0; ind(shape, real_vector.data()); + auto ct = std::make_shared(shape, complex_vector.data()); + auto got_rt = Aux::asvec(rt); + auto got_ct = Aux::asvec(ct); + assert_equal(real_vector, got_rt); + assert_equal(complex_vector, got_ct); + + try { + auto oops = Aux::asvec(rt); + } + catch (ValueError& err) { + } +} + +void test_asarray() +{ + // as array 2x2: (1d,2d) x (rw,cw) + + // make mutable copy to test that TT returns a copy + RV my_vec(real_vector.begin(), real_vector.end()); + + // test 2d + auto rt = std::make_shared(shape, 
my_vec.data()); + auto ra = Aux::asarray(rt); + auto shape = rt->shape(); + for (size_t irow = 0; irow < shape[0]; ++irow) { + for (size_t icol = 0; icol < shape[1]; ++icol) { + assert(ra(irow, icol) == my_vec[irow*shape[1] + icol]); + } + } + + // test 1d + const WireCell::ITensor::shape_t shape1d{6,}; + auto rt1d = std::make_shared(shape1d, my_vec.data()); + auto ra1d = Aux::asarray(rt1d); + for (size_t ind = 0; ind < shape[0]; ++ind) { + assert(ra1d(ind) == my_vec[ind]); + } + + // Assure the internal use of Eigen::Map leads to a copy on return + my_vec[0] = 42; + assert(ra(0,0) == 0); + assert(ra1d(0) == 0); +} + +int main() +{ + test_is_type(); + test_is_row_major(); + test_asvec(); + test_asarray(); + + return 0; +} diff --git a/iface/inc/WireCellIface/IDFT.h b/iface/inc/WireCellIface/IDFT.h index dd1d39e8c..448c04858 100644 --- a/iface/inc/WireCellIface/IDFT.h +++ b/iface/inc/WireCellIface/IDFT.h @@ -1,5 +1,19 @@ /** Interface to perform discrete single-precision Fourier transforms. + + Note, implementations MUST NOT normalize forward transforms and + MUST normalize reverse/inverse transforms by 1/n where n is the + number of elements in the 1D array being reverse transformed. + + The number "stride" describes how many elements of the array are + contiguous. For "C-order" aka row-major ordering of 2D arrays, + stride is the size of a row, aka number of columns. + + The number "nstrides" describe how many arrays of length "stride" + are placed end-to-end in the memory. For "C-order" aka row-major + ordering of 2D arrays, the "nstrides" counts the size of the + columns, aka the number of rows. With this ordering, the + (nstrides, stride) pair maps to the usual (nrows, ncols). 
*/ #ifndef WIRECELL_IDFT @@ -34,20 +48,20 @@ namespace WireCell { virtual void fwd1b(const complex_t* in, complex_t* out, - int stride, int nstrides) const; + int nstrides, int stride) const; virtual void inv1b(const complex_t* in, complex_t* out, - int stride, int nstrides) const; + int nstrides, int stride) const; // 2D, transform both dimensions virtual void fwd2d(const complex_t* in, complex_t* out, - int stride, int nstrides) const = 0; + int nstrides, int stride) const = 0; virtual void inv2d(const complex_t* in, complex_t* out, - int stride, int nstrides) const = 0; + int nstrides, int stride) const = 0; }; } diff --git a/iface/inc/WireCellIface/ITensor.h b/iface/inc/WireCellIface/ITensor.h index 5154b68fd..23bf52964 100644 --- a/iface/inc/WireCellIface/ITensor.h +++ b/iface/inc/WireCellIface/ITensor.h @@ -19,7 +19,12 @@ namespace WireCell { public: /// Shape gives size of each dimension. Size of shape give Ndim. typedef std::vector shape_t; - /// Storage order. Empty implies C order. + /// Storage order. Empty implies C order. If non-empty the + /// vector holds the "majority" of the dimension. C-order + /// implies a vector of {1,0} which means if the array is + /// accessed as array[a][b] then "b" is most major and "a" is + /// next most major. Ie, row-major. A fortran order would be + /// given as {0,1}. typedef std::vector order_t; /// The type of the element. diff --git a/iface/src/IDFT.cxx b/iface/src/IDFT.cxx index 2f9ee3543..c0a62655c 100644 --- a/iface/src/IDFT.cxx +++ b/iface/src/IDFT.cxx @@ -9,7 +9,7 @@ IDFT::~IDFT() {} // with some GPU FFTs, override these methods! 
void IDFT::fwd1b(const complex_t* in, complex_t* out, - int stride, int nstrides) const + int nstrides, int stride) const { for (int istride=0; istride +#include +#include + +using real_t = float; +using RV = std::vector; +using complex_t = std::complex; +using CV = std::vector; +const RV real_vector{0,1,2,3,4,5}; +const RV real_vector_cw{0,3,1,4,2,5}; +const CV complex_vector{{0,0},{1,10},{2,20},{3,30},{4,40},{5,50}}; + +using RA = Eigen::Array; +using RARM = Eigen::Array; +using CA = Eigen::Array; +using CARM = Eigen::Array; + +int main() +{ + RA ra = Eigen::Map((real_t*)real_vector.data(), 2, 3); + CA ca = Eigen::Map((complex_t*)complex_vector.data(), 2, 3); + + CA ra2c = ra.cast(); + RA ca2r = ca.real(); + + for (int irow = 0; irow<2; ++irow) { + for (int icol = 0; icol<3; ++icol) { + int ind = irow*3 + icol; + complex_t c = ra2c(irow, icol); + real_t r = c.real(); + real_t r2 = ca2r(irow, icol); + real_t rwant = real_vector[ind]; + complex_t cwant = complex_vector[ind]; + + std::cerr << ind << ": c=" << c << " r=" << r << " r2=" << r2 << " rwant=" << rwant << " cwant=" << cwant << "\n"; + assert(c.imag() == 0.0); + assert(r==rwant); + assert(r2==rwant); + + } + } + return 0; +} diff --git a/util/test/test_eigen_rowcol.cxx b/util/test/test_eigen_rowcol.cxx new file mode 100644 index 000000000..aec1fd763 --- /dev/null +++ b/util/test/test_eigen_rowcol.cxx @@ -0,0 +1,150 @@ +#include +#include + +using DEFM = Eigen::Array; // should be ColMajor +using COLM = Eigen::Array; +using ROWM = Eigen::Array; + +COLM get_mapped_cw() +{ + std::vector col_major{11,21,12,22,13,23}; + Eigen::Map ret(col_major.data(), 2,3); + return ret; +} +ROWM get_mapped_rw() +{ + std::vector row_major{11, 12, 13, 21, 22, 23}; + Eigen::Map ret(row_major.data(), 2,3); + return ret; +} + + +COLM get_colwise() +{ + COLM ret(2,3); + for (int major=0; major<2; ++major) { + for (int minor=0; minor<3; ++minor) { + ret(major,minor) = (major+1)*10 + minor+1; + } + } + return ret; +} +ROWM 
get_rowwise() +{ + ROWM ret(2,3); + for (int major=0; major<2; ++major) { + for (int minor=0; minor<3; ++minor) { + ret(major,minor) = (major+1)*10 + minor+1; + } + } + return ret; +} + +void dump_def(DEFM arr) +{ + std::cout << "DEFM" << "("< Date: Mon, 8 Nov 2021 16:31:54 -0500 Subject: [PATCH 09/46] Elaborate on comments --- iface/inc/WireCellIface/ITensor.h | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/iface/inc/WireCellIface/ITensor.h b/iface/inc/WireCellIface/ITensor.h index 23bf52964..d1cbf6adb 100644 --- a/iface/inc/WireCellIface/ITensor.h +++ b/iface/inc/WireCellIface/ITensor.h @@ -19,12 +19,22 @@ namespace WireCell { public: /// Shape gives size of each dimension. Size of shape give Ndim. typedef std::vector shape_t; + /// Storage order. Empty implies C order. If non-empty the /// vector holds the "majority" of the dimension. C-order /// implies a vector of {1,0} which means if the array is - /// accessed as array[a][b] then "b" is most major and "a" is - /// next most major. Ie, row-major. A fortran order would be - /// given as {0,1}. + /// accessed as array[a][b] "axis" 0 (indexed by "a") is the + /// "major index" and "axis" 1 (indexed by "b") is the "minor + /// index". It is thus "row-major" ordering as the major + /// index counts rows. An array in fortran-order + /// (column-major order) would be given as {0,1}. + /// + /// A note as this can be confusing: The "logical" rows and + /// columns, eg when used in an Eigen array are independent + /// from memory order. An Eigen array is always indexed as + /// arr(r,c). Storage order only matters when, well, you + /// access the array storage such as from Eigen array's + /// .data() method - and indeed ITensor::data(). typedef std::vector order_t; /// The type of the element. 
From cd3369d3385300e58f88de3e6388348fd5ac697e Mon Sep 17 00:00:00 2001 From: Brett Viren Date: Mon, 8 Nov 2021 16:31:58 -0500 Subject: [PATCH 10/46] Flesh out and test DftTools interface --- aux/inc/WireCellAux/DftTools.h | 20 +++++++++--- aux/src/DftTools.cxx | 58 ++++++++++++++++++++++++++++++++++ aux/test/test_dfttools.cxx | 35 +++++++++++++++++++- 3 files changed, 107 insertions(+), 6 deletions(-) diff --git a/aux/inc/WireCellAux/DftTools.h b/aux/inc/WireCellAux/DftTools.h index 10f2d8cbf..bf4894291 100644 --- a/aux/inc/WireCellAux/DftTools.h +++ b/aux/inc/WireCellAux/DftTools.h @@ -47,11 +47,21 @@ namespace WireCell::Aux { dft_array_t fwd(IDFT::pointer dft, const dft_array_t& arr); dft_array_t inv(IDFT::pointer dft, const dft_array_t& arr); - // Transform one dimesions. For example axis=0 transforms each - // logical row of the Eigen array so that column=0 of each row - // would hold the frequency=0 component of each row's spectrum. - // array_xxc fwd(IDFT::pointer dft, const array_xxc& arr, int axis); - // array_xxc inv(IDFT::pointer dft, const array_xxc& arr, int axis); + // Transform a 2D array along one axis. + // + // The axis identifies the logical array "dimension" over which + // the transform is applied. For example, axis=1 means the + // transforms are applied along columns (ie, on a per-row basis). + // Note: this is the same convention as held by numpy.fft. + // + // The axis is interpreted in the "logical" sense Eigen arrays + // indexed as array(irow, icol). Ie, the dimension traversing + // rows is axis 0 and the dimension traversing columns is axis 1. + // Note: internal storage order of an Eigen array may differ from + // the logical order and indeed that of the array template type + // order. Neither is pertinent in setting the axis. 
+ dft_array_t fwd(IDFT::pointer dft, const dft_array_t& arr, int axis); + dft_array_t inv(IDFT::pointer dft, const dft_array_t& arr, int axis); } diff --git a/aux/src/DftTools.cxx b/aux/src/DftTools.cxx index 5ef53e166..7297d8806 100644 --- a/aux/src/DftTools.cxx +++ b/aux/src/DftTools.cxx @@ -55,3 +55,61 @@ Aux::dft_array_t Aux::inv(IDFT::pointer dft, const Aux::dft_array_t& arr) dft->inv2d(in_data, out_data, nstrides, stride); }); } + +#include // debug + +template +Aux::dft_array_t doit1b(const Aux::dft_array_t& arr, int axis, trans func) +{ + // We must provide a flat array with storage order such with + // logical axis-major ordering. + const Aux::complex_t* in_data = arr.data(); + const int nrows = arr.rows(); // "logical" + const int ncols = arr.cols(); // shape + + std::cerr << "nrows="<(out_vec.data(), nrows, ncols); + } + return Eigen::Map(out_vec.data(), nrows, ncols); + } + + // Either we have row-major and want column-major storage order or + // vice versa. + + // Here, we must copy and not use "auto" to get actual storage + // order transpose and avoid the IsRowMajor flip optimization. 
+ COLM flipped = arr.transpose(); + COLM got = doit1b(flipped, (axis+1)%2, func); + return got.transpose(); +} + +Aux::dft_array_t Aux::fwd(IDFT::pointer dft, const Aux::dft_array_t& arr, int axis) +{ + return doit1b(arr, axis, + [&](const complex_t* in_data, + complex_t* out_data, + int nstrides, int stride) { + dft->fwd1b(in_data, out_data, nstrides, stride); + }); +} + +Aux::dft_array_t Aux::inv(IDFT::pointer dft, const Aux::dft_array_t& arr, int axis) +{ + return doit1b(arr, axis, + [&](const complex_t* in_data, + complex_t* out_data, + int nstrides, int stride) { + dft->inv1b(in_data, out_data, nstrides, stride); + }); +} diff --git a/aux/test/test_dfttools.cxx b/aux/test/test_dfttools.cxx index c5bf85930..e506da80f 100644 --- a/aux/test/test_dfttools.cxx +++ b/aux/test/test_dfttools.cxx @@ -107,6 +107,36 @@ void test_2d_transpose(IDFT::pointer dft) } +void test_1b(IDFT::pointer dft, int axis) +{ + const int nrows=8; + const int ncols=4; + FA r = FA::Zero(nrows, ncols); + r(6,1) = 1.0; + dump("impulse", r); + std::cerr << r << std::endl; + auto c = Aux::fwd(dft, r.cast(), axis); + dump("spectra", c); + if (axis==0) { + + } + std::cerr << c << std::endl; +} +void test_1bt(IDFT::pointer dft, int axis) +{ + const int nrows=8; + const int ncols=4; + FA r = FA::Zero(nrows, ncols); + r(6,1) = 1.0; + auto rc = r.cast(); + auto rct = rc.transpose(); + dump("impulse.T", rct); + std::cerr << rct << std::endl; + auto c = Aux::fwd(dft, rct, axis); + dump("spectra", c); + std::cerr << c << std::endl; +} + int main() { auto dft = std::make_shared(); @@ -114,6 +144,9 @@ int main() test_1d(dft); test_2d(dft); test_2d_transpose(dft); - + test_1b(dft, 0); + test_1b(dft, 1); + test_1bt(dft, 0); + test_1bt(dft, 1); return 0; } From a83cba82e38e6e8bcfc579e35c64ffbcfc0afa6e Mon Sep 17 00:00:00 2001 From: Brett Viren Date: Mon, 15 Nov 2021 16:00:39 -0500 Subject: [PATCH 11/46] Add axis, transpose, tests --- aux/inc/WireCellAux/DftTools.h | 2 +- aux/inc/WireCellAux/FftwDFT.h | 31 
+++-- aux/src/DftTools.cxx | 62 ++++++---- aux/src/FftwDFT.cxx | 209 ++++++++++++++++++++++++++++----- aux/test/test_idft.cxx | 155 +++++++++++++++++++++++- iface/inc/WireCellIface/IDFT.h | 123 ++++++++++++++----- iface/src/IDFT.cxx | 84 +++++++++++-- 7 files changed, 564 insertions(+), 102 deletions(-) diff --git a/aux/inc/WireCellAux/DftTools.h b/aux/inc/WireCellAux/DftTools.h index bf4894291..20272bc3f 100644 --- a/aux/inc/WireCellAux/DftTools.h +++ b/aux/inc/WireCellAux/DftTools.h @@ -41,7 +41,7 @@ namespace WireCell::Aux { using dft_array_t = Eigen::ArrayXXcf; // 2D with Eigen arrays. Use eg arr.cast() to provde - // from real or arr.cast() to convert result to real. + // from real or arr.real()() to convert result to real. // Transform both dimesions. dft_array_t fwd(IDFT::pointer dft, const dft_array_t& arr); diff --git a/aux/inc/WireCellAux/FftwDFT.h b/aux/inc/WireCellAux/FftwDFT.h index 365190f0c..f26b5bcf9 100644 --- a/aux/inc/WireCellAux/FftwDFT.h +++ b/aux/inc/WireCellAux/FftwDFT.h @@ -6,7 +6,12 @@ namespace WireCell::Aux { /** - FftwDFT provides IDFT based on FFTW3. + The FftwDFT component provides IDFT based on FFTW3. + + All instances share a common thread-safe plan cache. There is + no benefit to using more than one instance in a process. + + See IDFT.h for important comments. 
*/ class FftwDFT : public IDFT { public: @@ -18,23 +23,33 @@ namespace WireCell::Aux { virtual void fwd1d(const complex_t* in, complex_t* out, - int stride) const; + int size) const; virtual void inv1d(const complex_t* in, complex_t* out, - int stride) const; + int size) const; - // batched 1D ("1b") - rely on base implementation + virtual + void fwd1b(const complex_t* in, complex_t* out, + int nrows, int ncols, int axis) const; - // 2d + virtual + void inv1b(const complex_t* in, complex_t* out, + int nrows, int ncols, int axis) const; virtual void fwd2d(const complex_t* in, complex_t* out, - int nstrides, int stride) const; + int nrows, int ncols) const; virtual void inv2d(const complex_t* in, complex_t* out, - int nstrides, int stride) const; - + int nrows, int ncols) const; + + virtual + void transpose(const scalar_t* in, scalar_t* out, + int nrows, int ncols) const; + virtual + void transpose(const complex_t* in, complex_t* out, + int nrows, int ncols) const; }; } diff --git a/aux/src/DftTools.cxx b/aux/src/DftTools.cxx index 7297d8806..905cbcfa2 100644 --- a/aux/src/DftTools.cxx +++ b/aux/src/DftTools.cxx @@ -16,20 +16,20 @@ using COLM = Eigen::Array Aux::dft_array_t doit(const Aux::dft_array_t& arr, trans func) { - // Nominally, memory is in column-major order + // Nominally, eigen storage memory is in column-major order const Aux::complex_t* in_data = arr.data(); - int stride = arr.rows(); - int nstrides = arr.cols(); + int ncols = arr.rows(); + int nrows = arr.cols(); // except when it isn't bool flipped = arr.IsRowMajor; if (flipped) { - stride = arr.cols(); - nstrides = arr.rows(); + ncols = arr.cols(); + nrows = arr.rows(); } - Aux::dft_vector_t out_vec(nstrides*stride); - func(in_data, out_vec.data(), nstrides, stride); + Aux::dft_vector_t out_vec(nrows*ncols); + func(in_data, out_vec.data(), nrows, ncols); if (flipped) { return Eigen::Map(out_vec.data(), arr.rows(), arr.cols()); @@ -42,8 +42,8 @@ Aux::dft_array_t Aux::fwd(IDFT::pointer dft, const 
Aux::dft_array_t& arr) { return doit(arr, [&](const complex_t* in_data, complex_t* out_data, - int nstrides, int stride) { - dft->fwd2d(in_data, out_data, nstrides, stride); + int nrows, int ncols) { + dft->fwd2d(in_data, out_data, nrows, ncols); }); } @@ -51,8 +51,8 @@ Aux::dft_array_t Aux::inv(IDFT::pointer dft, const Aux::dft_array_t& arr) { return doit(arr, [&](const complex_t* in_data, complex_t* out_data, - int nstrides, int stride) { - dft->inv2d(in_data, out_data, nstrides, stride); + int nrows, int ncols) { + dft->inv2d(in_data, out_data, nrows, ncols); }); } @@ -94,22 +94,38 @@ Aux::dft_array_t doit1b(const Aux::dft_array_t& arr, int axis, trans func) return got.transpose(); } +// Implementation notes for fwd()/inv(): +// +// - We make an initial copy to get rid of any potential IsRowMajor +// optimization/confusion over storage order. This suffers a copy +// but we need to allocate return anyways. +// +// - We then have column-wise storage order but IDFT assumes row-wise +// - so we reverse (nrows, ncols) and meaning of axis. 
+ Aux::dft_array_t Aux::fwd(IDFT::pointer dft, const Aux::dft_array_t& arr, int axis) { - return doit1b(arr, axis, - [&](const complex_t* in_data, - complex_t* out_data, - int nstrides, int stride) { - dft->fwd1b(in_data, out_data, nstrides, stride); - }); + Aux::dft_array_t ret = arr; + dft->fwd1b(ret.data(), ret.data(), ret.cols(), ret.rows(), !axis); + return ret; + + // return doit1b(arr, axis, + // [&](const complex_t* in_data, + // complex_t* out_data, + // int nrows, int ncols) { + // dft->fwd1b(in_data, out_data, nrows, ncols); + // }); } Aux::dft_array_t Aux::inv(IDFT::pointer dft, const Aux::dft_array_t& arr, int axis) { - return doit1b(arr, axis, - [&](const complex_t* in_data, - complex_t* out_data, - int nstrides, int stride) { - dft->inv1b(in_data, out_data, nstrides, stride); - }); + Aux::dft_array_t ret = arr; + dft->inv1b(ret.data(), ret.data(), ret.cols(), ret.rows(), !axis); + return ret; + // return doit1b(arr, axis, + // [&](const complex_t* in_data, + // complex_t* out_data, + // int nrows, int ncols) { + // dft->inv1b(in_data, out_data, nrows, ncols); + // }); } diff --git a/aux/src/FftwDFT.cxx b/aux/src/FftwDFT.cxx index f4a214760..1289b8f99 100644 --- a/aux/src/FftwDFT.cxx +++ b/aux/src/FftwDFT.cxx @@ -16,15 +16,25 @@ using plan_type = fftwf_plan; using plan_map_t = std::unordered_map; using plan_val_t = fftwf_complex; +// Make a key by which a plan is known. dir should be FFTW_FORWARD or +// FFTW_BACKWARD and "axis" is -1 for all or in {0,1} for one of 2D. +// For 1D, use the default axis=-1. +// +// Imp note: The key is slightly over-specified as we keep one +// independent cache for each of the six methods. The "dir" is +// thus redundant. 
static -plan_key_t make_key(bool inverse, const void * src, void * dst, int n0, int n1) +plan_key_t make_key(const void * src, void * dst, int nrows, int ncols, int dir, int axis=-1) { + ++axis; // need three positive values, default is both axis + bool inverse = dir == FFTW_BACKWARD; bool inplace = (dst==src); bool aligned = ( (reinterpret_cast(src)&15) | (reinterpret_cast(dst)&15) ) == 0; - int64_t key = ( ( (((int64_t)n0) << 30)|(n1<<3 ) | (inverse<<2) | (inplace<<1) | aligned ) << 1 ) + 1; + int64_t key = ( ( (((int64_t)nrows) << 32)| (ncols<<5 ) | (axis<<3) | (inverse<<2) | (inplace<<1) | aligned ) << 1 ) + 1; return key; } +// Look up a plan by key or return NULL static plan_type get_plan(std::shared_mutex& mutex, plan_map_t& plans, plan_key_t key) { @@ -39,12 +49,16 @@ plan_type get_plan(std::shared_mutex& mutex, plan_map_t& plans, plan_key_t key) // #include // debugging -template -void doit(std::shared_mutex& mutex, plan_map_t& plans, - int fwdrev, plan_val_t* src, plan_val_t* dst, int nstrides, int stride, - planner_function make_plan) +using planner_function = std::function; + +// This wraps plan lookup, possible plan creation and subsequent plan +// execution so that we get thread-safe plan caching. 
+template +void doit(std::shared_mutex& mutex, plan_map_t& plans, plan_key_t key, + ValueType* src, ValueType* dst, + planner_function make_plan, + std::function exec_plan) { - auto key = make_key(fwdrev == FFTW_BACKWARD, src, dst, stride, nstrides); auto plan = get_plan(mutex, plans, key); if (!plan) { std::unique_lock lock(mutex); @@ -59,7 +73,8 @@ void doit(std::shared_mutex& mutex, plan_map_t& plans, plan = it->second; } } - fftwf_execute_dft(plan, src, dst); + //fftwf_execute_dft(plan, src, dst); + exec_plan(plan, src, dst); } @@ -70,65 +85,203 @@ plan_val_t* pval_cast( const IDFT::complex_t * p) } -void Aux::FftwDFT::fwd1d(const complex_t* in, complex_t* out, int stride) const +void Aux::FftwDFT::fwd1d(const complex_t* in, complex_t* out, int ncols) const { static std::shared_mutex mutex; static plan_map_t plans; static const int dir = FFTW_FORWARD; auto src = pval_cast(in); auto dst = pval_cast(out); - doit(mutex, plans, dir, src, dst, 0, stride, [&]( ) { - return fftwf_plan_dft_1d(stride, src, dst, dir, FFTW_ESTIMATE|FFTW_PRESERVE_INPUT); - }); + auto key = make_key(src, dst, 1, ncols, dir); + doit(mutex, plans, key, src, dst, [&]( ) { + return fftwf_plan_dft_1d(ncols, src, dst, dir, FFTW_ESTIMATE|FFTW_PRESERVE_INPUT); + }, fftwf_execute_dft); } -void Aux::FftwDFT::inv1d(const complex_t* in, complex_t* out, int stride) const +void Aux::FftwDFT::inv1d(const complex_t* in, complex_t* out, int ncols) const { static std::shared_mutex mutex; static plan_map_t plans; static const int dir = FFTW_BACKWARD; auto src = pval_cast(in); auto dst = pval_cast(out); - doit(mutex, plans, dir, src, dst, 0, stride, [&]( ) { - return fftwf_plan_dft_1d(stride, src, dst, dir, FFTW_ESTIMATE|FFTW_PRESERVE_INPUT); - }); + auto key = make_key(src, dst, 1, ncols, dir); - // reverse normalization - for (int ind=0; ind(mutex, plans, key, src, dst, [&]( ) { + return fftwf_plan_dft_1d(ncols, src, dst, dir, FFTW_ESTIMATE|FFTW_PRESERVE_INPUT); + }, fftwf_execute_dft); + + // Apply 1/n 
normalization + for (int ind=0; ind(mutex, plans, key, src, dst, [&]( ) { + return plan_1b(src, dst, nrows, ncols, dir, axis); + }, fftwf_execute_dft); +} + + +void Aux::FftwDFT::inv1b(const complex_t* in, complex_t* out, int nrows, int ncols, int axis) const +{ + static std::shared_mutex mutex; + static plan_map_t plans; + static const int dir = FFTW_BACKWARD; + auto src = pval_cast(in); + auto dst = pval_cast(out); + auto key = make_key(src, dst, nrows, ncols, dir, axis); + + doit(mutex, plans, key, src, dst, [&]( ) { + return plan_1b(src, dst, nrows, ncols, dir, axis); + }, fftwf_execute_dft); + + // 1/n normalization + const int norm = axis ? ncols : nrows; + const int ntot = ncols*nrows; + for (int ind=0; ind(mutex, plans, key, src, dst, [&]( ) { + return fftwf_plan_dft_2d(ncols, nrows, src, dst, dir, FFTW_ESTIMATE|FFTW_PRESERVE_INPUT); + }, fftwf_execute_dft); } -void Aux::FftwDFT::inv2d(const complex_t* in, complex_t* out, int nstrides, int stride) const +void Aux::FftwDFT::inv2d(const complex_t* in, complex_t* out, int nrows, int ncols) const { static std::shared_mutex mutex; static plan_map_t plans; static const int dir = FFTW_BACKWARD; auto src = pval_cast(in); auto dst = pval_cast(out); - doit(mutex, plans, dir, src, dst, nstrides, stride, [&]( ) { - return fftwf_plan_dft_2d(stride, nstrides, src, dst, dir, FFTW_ESTIMATE|FFTW_PRESERVE_INPUT); - }); + auto key = make_key(src, dst, nrows, ncols, dir); + doit(mutex, plans, key, src, dst, [&]( ) { + return fftwf_plan_dft_2d(ncols, nrows, src, dst, dir, FFTW_ESTIMATE|FFTW_PRESERVE_INPUT); + }, fftwf_execute_dft); // reverse normalization - const int ntot = stride*nstrides; + const int ntot = ncols*nrows; for (int ind=0; ind(mutex, plans, key, src, dst, [&]( ) { + return transpose_plan_complex(src, dst, nrows, ncols); + }, fftwf_execute_dft); +} + +static +plan_type transpose_plan_real(float *in, float *out, int rows, int cols) +{ + const unsigned flags = FFTW_ESTIMATE; /* other flags are possible */ + 
fftw_iodim howmany_dims[2]; + + howmany_dims[0].n = rows; + howmany_dims[0].is = cols; + howmany_dims[0].os = 1; + + howmany_dims[1].n = cols; + howmany_dims[1].is = 1; + howmany_dims[1].os = rows; + + return fftwf_plan_guru_r2r(/*rank=*/ 0, /*dims=*/ NULL, + /*howmany_rank=*/ 2, howmany_dims, + in, out, /*kind=*/ NULL, flags); +} +void Aux::FftwDFT::transpose(const scalar_t* in, scalar_t* out, + int nrows, int ncols) const +{ + static std::shared_mutex mutex; + static plan_map_t plans; + static const int dir = 0; + auto src = const_cast(in); + auto dst = out; + auto key = make_key(src, dst, nrows, ncols, dir); + doit(mutex, plans, key, src, dst, [&]( ) { + return transpose_plan_real(src, dst, nrows, ncols); + }, fftwf_execute_r2r); +} + Aux::FftwDFT::FftwDFT() { } diff --git a/aux/test/test_idft.cxx b/aux/test/test_idft.cxx index 73bc95df2..cfd109a8b 100644 --- a/aux/test/test_idft.cxx +++ b/aux/test/test_idft.cxx @@ -8,10 +8,36 @@ #include #include #include +#include #include using namespace WireCell; +const float eps = 1e-8; + +static void assert_impulse_at_index(const std::vector& vec, size_t index=0) +{ + const size_t size = vec.size(); + auto tot = Waveform::sum(vec); + assert(std::abs(std::real(tot) - 1.0) < eps); + assert(std::abs(std::real(vec[index]) - 1.0) < eps); + assert(std::abs(std::imag(tot)) < eps); + assert(std::abs(std::imag(vec[index])) < eps); + for (size_t ind=0; ind& vec, IDFT::scalar_t val = 1.0) +{ + const auto size = vec.size(); + auto tot = Waveform::sum(vec); + assert(std::abs(std::abs(tot) - val*size) < eps); + for (const auto& v : vec) { + assert(std::abs(std::abs(v) - val) < eps); + } +} static void test_1d_zero(IDFT::pointer dft, int size = 1024) @@ -19,11 +45,25 @@ void test_1d_zero(IDFT::pointer dft, int size = 1024) std::vector inter(size,0), freq(size,0); dft->fwd1d(inter.data(), freq.data(), inter.size()); + assert_flat_value(freq, 0); dft->inv1d(freq.data(), inter.data(), freq.size()); + assert_flat_value(inter, 0); +} +static 
+void test_1d_impulse(IDFT::pointer dft, int size=1024) +{ + std::vector inter(size,0), freq(size,0), back(size,0); + inter[0] = 1.0; + + dft->fwd1d(inter.data(), freq.data(), freq.size()); + assert_flat_value(freq); - auto tot = Waveform::sum(inter); - assert(std::real(tot) == 0); + dft->inv1d(freq.data(), back.data(), back.size()); + assert_impulse_at_index(back); } + + + static void test_2d_zero(IDFT::pointer dft, int size = 1024) { @@ -32,10 +72,71 @@ void test_2d_zero(IDFT::pointer dft, int size = 1024) std::vector freq(stride*nstrides,0); dft->fwd2d(inter.data(), freq.data(), nstrides, stride); + assert_flat_value(inter, 0); dft->inv2d(freq.data(), inter.data(), nstrides, stride); + assert_flat_value(freq, 0); +} +static +void test_2d_impulse(IDFT::pointer dft, int nrows=128, int ncols=128) +{ + const int size = nrows*ncols; + std::vector inter(size,0), freq(size,0), back(size,0); + inter[0] = 1.0; + dft->fwd2d(inter.data(), freq.data(), nrows, ncols); + assert_flat_value(freq); + + dft->inv2d(freq.data(), back.data(), nrows, ncols); + assert_impulse_at_index(back); + +} + + +static void assert_on_axis(const std::vector& freq, + int axis, int nrows=128, int ncols=128) +{ + for (int irow=0; irow inter(size,0), freq(size,0), back(size,0); + inter[0] = 1.0; + dft->fwd1b(inter.data(), freq.data(), nrows, ncols, axis); + assert_on_axis(freq, axis, nrows, ncols); + dft->inv1b(freq.data(), back.data(), nrows, ncols, axis); + assert_impulse_at_index(back, 0); - auto tot = Waveform::sum(inter); - assert(std::real(tot) == 0); + std::vector inplace(size,0); + inplace[0] = 1.0; + dft->fwd1b(inplace.data(), inplace.data(), nrows, ncols, axis); + assert_on_axis(inplace, axis, nrows, ncols); + + std::vector inback(inplace.begin(), inplace.end()); + dft->inv1b(inback.data(), inback.data(), nrows, ncols, axis); + assert_impulse_at_index(inback, 0); } void fwdrev(IDFT::pointer dft, int id, int ntimes, int size) @@ -84,6 +185,42 @@ void test_2d_threads(IDFT::pointer dft, int 
nthreads, int nloops, int size = 102 << " " << dt1.count() << std::endl; } +template +void dump(ValueType* data, int nrows, int ncols, std::string msg="") +{ + std::cerr << msg << "("< +void test_2d_transpose(IDFT::pointer dft, int nrows, int ncols) +{ + std::vector arr(nrows*ncols); + std::iota(arr.begin(), arr.end(), 0); + + std::vector arr2(nrows*ncols, 0); + std::vector arr3(arr.begin(), arr.end()); + + dft->transpose(arr.data(), arr2.data(), nrows, ncols); + dft->transpose(arr3.data(), arr3.data(), nrows, ncols); + + for (int irow=0; irow(idft, 2, 8); + test_2d_transpose(idft, 8, 2); + test_2d_transpose(idft, 2, 8); + test_2d_transpose(idft, 8, 2); std::vector sizes = {128,256,512,1024}; for (auto size : sizes) { diff --git a/iface/inc/WireCellIface/IDFT.h b/iface/inc/WireCellIface/IDFT.h index 448c04858..e94afb391 100644 --- a/iface/inc/WireCellIface/IDFT.h +++ b/iface/inc/WireCellIface/IDFT.h @@ -1,21 +1,3 @@ -/** - Interface to perform discrete single-precision Fourier transforms. - - Note, implementations MUST NOT normalize forward transforms and - MUST normalize reverse/inverse transforms by 1/n where n is the - number of elements in the 1D array being reverse transformed. - - The number "stride" describes how many elements of the array are - contiguous. For "C-order" aka row-major ordering of 2D arrays, - stride is the size of a row, aka number of columns. - - The number "nstrides" describe how many arrays of length "stride" - are placed end-to-end in the memory. For "C-order" aka row-major - ordering of 2D arrays, the "nstrides" counts the size of the - columns, aka the number of rows. With this ordering, the - (nstrides, stride) pair maps to the usual (nrows, ncols). -*/ - #ifndef WIRECELL_IDFT #define WIRECELL_IDFT @@ -24,6 +6,77 @@ namespace WireCell { + /** + Interface to perform discrete Fourier transforms on arrays of + signal precision, complex floating point values. 
+ + There are 6 DFT methods which are formed as the outer product + of two lists: + + - fwd, inv + - 1d, 1b, 2d + + The "fwd" methods provide forward transform, no normalization. + The "inv" methods provide reverse/inverse transform normalized + by 1/size. + + The 1d transforms take rank=1 / 1D arrays and perform a single + transform. + + The 2d transforms take rank=2 / 2D arrays and perform nrows of + transforms along rows and ncols of transforms along columns. + The order over which each dimension is transformed is + implementation-defined (and immaterial). + + The 1b transforms take rank=1 / 2D arrays and perform + transforms along a single dimension as determined by the value + of the "axis" parameter. An axis=1 means to perform nrows + transforms along rows. Note, this is the same convention + followed by numpy.fft functions. + + There is also a special rank=0 DFT on rank=2 arrays which is + more commonly known as a "matrix transpose". + + Requirements on implementations: + + - Forward transforms SHALL NOT apply normalization. + + - Reverse transforms SHALL apply 1/n normalization. + + - The arrays SHALL be assumed to follow C-ordering aka + row-major storage order. + + - Transform methods SHALL allow the input and output array + pointers to be identical. + + - The IDFT interface provides 1b methods implemented in terms + of 1d calls and an implementation MAY override these (for + example, if the implementation can exploit batch optimization). + + - Implementation SHALL allow safe concurrent calls to methods + by different threads of execution. + + Requirements on callers: + + - Input and output arrays SHALL be pre-allocated and be sized + at least as large as indicated by accompanying size arguments. + + - Input and output arrays MUST either be non-overlapping in + memory or MUST be identical. + + Notes: + + - All arrays are of type single precision complex floating + point. Functions and methods to easily convert between the + two exist. 
+ + - Eigen arrays are column-wise by default and so their + arr.data() method cannot directly supply input to this + interface. Likewise, use of arr.transpose().data() may run + afoul of Eigen's IsRowMajor optimization flag. Copy your + default array into an Eigen::RowMajor array first or use IDFT + via Aux::DftTools functions. + */ class IDFT : public IComponent { public: virtual ~IDFT(); @@ -34,35 +87,43 @@ namespace WireCell { /// The type for the spectrum in each bin. using complex_t = std::complex; - // 1D + // 1d virtual - void fwd1d(const complex_t* in, complex_t* out, - int stride) const = 0; + void fwd1d(const complex_t* in, complex_t* out, int size) const = 0; virtual - void inv1d(const complex_t* in, complex_t* out, - int stride) const = 0; + void inv1d(const complex_t* in, complex_t* out, int size) const = 0; - // batched 1D ("1b") + // 1b virtual void fwd1b(const complex_t* in, complex_t* out, - int nstrides, int stride) const; + int nrows, int ncols, int axis) const; + virtual void inv1b(const complex_t* in, complex_t* out, - int nstrides, int stride) const; - + int nrows, int ncols, int axis) const; - // 2D, transform both dimensions + // 2d virtual void fwd2d(const complex_t* in, complex_t* out, - int nstrides, int stride) const = 0; + int nrows, int ncols) const = 0; virtual void inv2d(const complex_t* in, complex_t* out, - int nstrides, int stride) const = 0; - + int nrows, int ncols) const = 0; + + + // Fill "out" with the transpose of "in", may be in-place. + // The nrows/ncols refers to the shape of the input. 
+ virtual + void transpose(const scalar_t* in, scalar_t* out, + int nrows, int ncols) const; + virtual + void transpose(const complex_t* in, complex_t* out, + int nrows, int ncols) const; + }; } diff --git a/iface/src/IDFT.cxx b/iface/src/IDFT.cxx index c0a62655c..58b1d335b 100644 --- a/iface/src/IDFT.cxx +++ b/iface/src/IDFT.cxx @@ -1,25 +1,95 @@ #include "WireCellIface/IDFT.h" +#include +#include // std::swap since c++11 + using namespace WireCell; IDFT::~IDFT() {} // Trivial default "batched" implementations. If your concrete // implementation provides some kind of "batch optimization", such as -// with some GPU FFTs, override these methods! +// with FFTW3's advanced interface or with some GPU FFT library, +// override these dumb methods for the win. void IDFT::fwd1b(const complex_t* in, complex_t* out, - int nstrides, int stride) const + int nrows, int ncols, int axis) const { - for (int istride=0; istridetranspose(in, out, nrows, ncols); + this->fwd1b(out, out, ncols, nrows, 1); + this->transpose(out, out, ncols, nrows); } } void IDFT::inv1b(const complex_t* in, complex_t* out, - int nstrides, int stride) const + int nrows, int ncols, int axis) const { - for (int istride=0; istridetranspose(in, out, nrows, ncols); + this->inv1b(out, out, ncols, nrows, 1); + this->transpose(out, out, ncols, nrows); } } + +// Trivial default transpose. Implementations, please override if you +// can offer something faster. + +template +void transpose_type(const ValueType* in, ValueType* out, + int nrows, int ncols) +{ + if (in != out) { + for (int irow=0; irow visited(size); + ValueType* first = out + size; + const ValueType* last = first + size; + ValueType* cycle = out; + while (++cycle != last) { + if (visited[cycle - first]) + continue; + int a = cycle - first; + do { + a = a == mn1 ? 
mn1 : (n * a) % mn1; + std::swap(*(first + a), *cycle); + visited[a] = true; + } while ((first + a) != cycle); + } + +} + + +void IDFT::transpose(const IDFT::scalar_t* in, IDFT::scalar_t* out, + int nrows, int ncols) const +{ + transpose_type(in, out, nrows, ncols); +} +void IDFT::transpose(const IDFT::complex_t* in, IDFT::complex_t* out, + int nrows, int ncols) const +{ + transpose_type(in, out, nrows, ncols); +} From e8e0625bff7284fcfe93d7cdbfca00c852f410a6 Mon Sep 17 00:00:00 2001 From: Brett Viren Date: Tue, 16 Nov 2021 11:47:00 -0500 Subject: [PATCH 12/46] Improve tests --- aux/test/test_dfttools.cxx | 131 ++++++++++++++++--------------------- aux/test/test_idft.cxx | 28 +------- 2 files changed, 58 insertions(+), 101 deletions(-) diff --git a/aux/test/test_dfttools.cxx b/aux/test/test_dfttools.cxx index e506da80f..3ee006a92 100644 --- a/aux/test/test_dfttools.cxx +++ b/aux/test/test_dfttools.cxx @@ -1,3 +1,5 @@ +#include "aux_test_dft_helpers.h" + #include "WireCellAux/DftTools.h" #include "WireCellAux/FftwDFT.h" #include "WireCellUtil/Waveform.h" @@ -6,79 +8,67 @@ #include using namespace WireCell; +using namespace WireCell::Aux::Test; using real_t = float; using RV = std::vector; using complex_t = std::complex; using CV = std::vector; -void test_1d(IDFT::pointer dft) +void test_1d_impulse(IDFT::pointer dft, int size = 64) { - RV rimp(64, 0); - rimp[1] = 1.0; + RV rimp(size, 0); + rimp[0] = 1.0; auto cimp = Aux::fwd(dft, Waveform::complex(rimp)); - for (auto c : cimp) { - std::cerr << c << " "; - } - std::cerr << "\n"; + assert_flat_value(cimp.data(), cimp.size()); RV rimp2 = Waveform::real(Aux::inv(dft, cimp)); - for (auto r : rimp2) { - std::cerr << r << " "; - } - std::cerr << "\n"; - for (int ind=0; ind<64; ++ind) { - if (ind == 1) { - assert(std::abs(rimp2[ind]-1.0) < 1e-6); - continue; - } - assert(std::abs(rimp2[ind]) < 1e-6); - } + assert_impulse_at_index(rimp2.data(), rimp2.size()); } -using FA = Eigen::Array; +using FA = Eigen::Array; +using CA = 
Eigen::Array; +using FARM = Eigen::Array; +using CARM = Eigen::Array; -void test_2d(IDFT::pointer dft) +void test_2d_impulse(IDFT::pointer dft, int nrows=16, int ncols=8) { - const int nrows=16; - const int ncols=8; + const size_t size = nrows*ncols; FA r = FA::Zero(nrows, ncols); - r(10,1) = 1.0; - std::cerr << r << std::endl; - auto c = Aux::fwd(dft, r.cast()); - std::cerr << c << std::endl; - FA r2 = Aux::inv(dft, c).real(); - std::cerr << r2 << std::endl; - for (int irow=0; irow -void dump(std::string name, const array_type& arr) -{ - std::cerr << name << ":(" << arr.rows() << "," << arr.cols() << ") row-major:" << arr.IsRowMajor << "\n"; + CA rc = r.cast(); + dump("rc", rc); + assert_impulse_at_index(rc.data(), size); + + CA c = Aux::fwd(dft, rc); + dump("c", c); + assert_flat_value(c.data(), size); + + FA r2 = Aux::inv(dft, c).real(); + dump("r2", r2); + assert_impulse_at_index(r2.data(), size); } -void test_2d_transpose(IDFT::pointer dft) +void test_2d_eigen_transpose(IDFT::pointer dft) { const int nrows=16; const int ncols=8; + // where the impulse lives (off axis) + const int imp_row = 1; + const int imp_col = 10; + FA r = FA::Zero(nrows, ncols); // shape:(16,8) dump("r", r); // do not remove the auto in this next line auto rt = r.transpose(); // shape:(8,16) dump("rt", rt); - rt(1,10) = 1.0; + rt(imp_row, imp_col) = 1.0; auto c = Aux::fwd(dft, rt.cast()); dump("c", c); @@ -95,7 +85,7 @@ void test_2d_transpose(IDFT::pointer dft) float val = rt(irow, icol); float val2 = r2(irow, icol); // access with transposed indices std::cerr << "(" << irow << ","<< icol << "):" << val << " ? 
" << val2 << "\n"; - if (irow==1 and icol==10) { + if (irow==imp_row and icol==imp_col) { assert(std::abs(val-1.0) < 1e-6); continue; } @@ -103,50 +93,41 @@ void test_2d_transpose(IDFT::pointer dft) } std::cerr << "\n"; } - - } -void test_1b(IDFT::pointer dft, int axis) +void test_1b(IDFT::pointer dft, int axis, int nrows=8, int ncols=4) { - const int nrows=8; - const int ncols=4; FA r = FA::Zero(nrows, ncols); - r(6,1) = 1.0; + r(0,0) = 1.0; dump("impulse", r); - std::cerr << r << std::endl; - auto c = Aux::fwd(dft, r.cast(), axis); - dump("spectra", c); - if (axis==0) { - - } - std::cerr << c << std::endl; -} -void test_1bt(IDFT::pointer dft, int axis) -{ - const int nrows=8; - const int ncols=4; - FA r = FA::Zero(nrows, ncols); - r(6,1) = 1.0; - auto rc = r.cast(); - auto rct = rc.transpose(); - dump("impulse.T", rct); - std::cerr << rct << std::endl; - auto c = Aux::fwd(dft, rct, axis); + CA c = Aux::fwd(dft, r.cast(), axis); + dump("spectra", c); std::cerr << c << std::endl; + + if (axis) { // transform along rows + CA ct = c.transpose(); // convert to along columns (native Eigen storage order) + c = ct; + std::swap(nrows, ncols); + dump("transpose", c); + std::cerr << c << std::endl; + } + + // first column has flat abs value of 1.0. 
+ assert_flat_value(c.data(), nrows, complex_t(1,0)); + // rest should be flat, zero value + assert_flat_value(c.data()+nrows, nrows*ncols - nrows, complex_t(0,0)); + } int main() { auto dft = std::make_shared(); - test_1d(dft); - test_2d(dft); - test_2d_transpose(dft); + test_1d_impulse(dft); + test_2d_impulse(dft); + test_2d_eigen_transpose(dft); test_1b(dft, 0); test_1b(dft, 1); - test_1bt(dft, 0); - test_1bt(dft, 1); return 0; } diff --git a/aux/test/test_idft.cxx b/aux/test/test_idft.cxx index cfd109a8b..babc5ed9d 100644 --- a/aux/test/test_idft.cxx +++ b/aux/test/test_idft.cxx @@ -3,7 +3,8 @@ #include "WireCellUtil/Waveform.h" #include "WireCellUtil/PluginManager.h" #include "WireCellIface/IConfigurable.h" -#include "WireCellIface/IDFT.h" + +#include "aux_test_dft_helpers.h" #include #include @@ -13,31 +14,6 @@ using namespace WireCell; -const float eps = 1e-8; - -static void assert_impulse_at_index(const std::vector& vec, size_t index=0) -{ - const size_t size = vec.size(); - auto tot = Waveform::sum(vec); - assert(std::abs(std::real(tot) - 1.0) < eps); - assert(std::abs(std::real(vec[index]) - 1.0) < eps); - assert(std::abs(std::imag(tot)) < eps); - assert(std::abs(std::imag(vec[index])) < eps); - for (size_t ind=0; ind& vec, IDFT::scalar_t val = 1.0) -{ - const auto size = vec.size(); - auto tot = Waveform::sum(vec); - assert(std::abs(std::abs(tot) - val*size) < eps); - for (const auto& v : vec) { - assert(std::abs(std::abs(v) - val) < eps); - } -} static void test_1d_zero(IDFT::pointer dft, int size = 1024) From 28d9ec9129fb7f816e2c87cf55d769b25b995281 Mon Sep 17 00:00:00 2001 From: Brett Viren Date: Tue, 16 Nov 2021 12:11:59 -0500 Subject: [PATCH 13/46] More testing --- aux/test/aux_test_dft_helpers.h | 119 ++++++++++++++++++++++++++++++++ aux/test/test_dfttools.cxx | 15 ++-- aux/test/test_idft.cxx | 34 +++------ 3 files changed, 136 insertions(+), 32 deletions(-) create mode 100644 aux/test/aux_test_dft_helpers.h diff --git 
a/aux/test/aux_test_dft_helpers.h b/aux/test/aux_test_dft_helpers.h new file mode 100644 index 000000000..c4b12e012 --- /dev/null +++ b/aux/test/aux_test_dft_helpers.h @@ -0,0 +1,119 @@ +// This is only for sharing some common code betweeen different +// aux/test/*.cxx tests. Not for "real" use. + +#include "WireCellUtil/NamedFactory.h" +#include "WireCellUtil/PluginManager.h" +#include "WireCellUtil/Exceptions.h" + +#include "WireCellIface/IConfigurable.h" +#include "WireCellIface/IDFT.h" + +#include +#include + +namespace WireCell::Aux::Test { + + + // fixme: add support for config + IDFT::pointer make_dft(const std::string& tn="FftwDFT", + const std::string& pi="WireCellAux") + { + PluginManager& pm = PluginManager::instance(); + pm.add(pi); + + // create first + auto idft = Factory::lookup_tn(tn); + assert(idft); + // configure before use if configurable + auto icfg = Factory::find_maybe_tn(tn); + if (icfg) { + auto cfg = icfg->default_configuration(); + icfg->configure(cfg); + } + return idft; + } + IDFT::pointer make_dft_args(int argc, char* argv[]) + { + std::string dft_tn="FftwDFT"; + std::string dft_pi="WireCellAux"; + if (argc > 1) dft_tn = argv[1]; + if (argc > 2) dft_pi = argv[2]; + return make_dft(dft_tn, dft_pi); + } + + const double default_eps = 1e-8; + const std::complex czero = 0.0; + const std::complex cone = 1.0; + + void assert_small(double val, double eps = default_eps) { + if (val < eps) { + return; + } + std::stringstream ss; + ss << "value " << val << " >= " << eps; + std::cerr << ss.str() << std::endl; + THROW(WireCell::ValueError() << errmsg{ss.str()}); + } + + // Assert the array has only value val at index and near zero elsewhere + template + void assert_impulse_at_index(const ValueType* vec, size_t size, + size_t index=0, ValueType val = 1.0) + { + ValueType tot = 0; + for (size_t ind=0; ind + void assert_impulse_at_index(const VectorType& vec, + size_t index=0, const typename VectorType::value_type& val = 1.0) + { + 
assert_impulse_at_index(vec.data(), vec.size(), index, val); + } + + // Assert all values in array are near given val + template + void assert_flat_value(const ValueType* vec, size_t size, ValueType val = 1.0) + { + ValueType tot = 0; + for (size_t ind=0; ind + void assert_flat_value(const VectorType& vec, const typename VectorType::value_type& val = 1.0) + { + assert_flat_value(vec.data(), vec.size(), val); + } + + // Print eigen array + template + void dump(std::string name, const array_type& arr) + { + std::cerr << name << ":(" << arr.rows() << "," << arr.cols() << ") row-major:" << arr.IsRowMajor << "\n"; + } + + + // Like std::iota, but dummer + template + void iota(ValueType* vec, size_t size, ValueType start = 0) + { + for (size_t ind=0; ind(); + auto idft = make_dft_args(argc, argv); + + test_1d_impulse(idft); + test_2d_impulse(idft); + test_2d_eigen_transpose(idft); + test_1b(idft, 0); + test_1b(idft, 1); - test_1d_impulse(dft); - test_2d_impulse(dft); - test_2d_eigen_transpose(dft); - test_1b(dft, 0); - test_1b(dft, 1); return 0; } diff --git a/aux/test/test_idft.cxx b/aux/test/test_idft.cxx index babc5ed9d..1b47646b8 100644 --- a/aux/test/test_idft.cxx +++ b/aux/test/test_idft.cxx @@ -1,8 +1,5 @@ // Test IDFT implementations. 
-#include "WireCellUtil/NamedFactory.h" #include "WireCellUtil/Waveform.h" -#include "WireCellUtil/PluginManager.h" -#include "WireCellIface/IConfigurable.h" #include "aux_test_dft_helpers.h" @@ -13,6 +10,7 @@ #include using namespace WireCell; +using namespace WireCell::Aux::Test; static @@ -21,9 +19,9 @@ void test_1d_zero(IDFT::pointer dft, int size = 1024) std::vector inter(size,0), freq(size,0); dft->fwd1d(inter.data(), freq.data(), inter.size()); - assert_flat_value(freq, 0); + assert_flat_value(freq, czero); dft->inv1d(freq.data(), inter.data(), freq.size()); - assert_flat_value(inter, 0); + assert_flat_value(inter, czero); } static void test_1d_impulse(IDFT::pointer dft, int size=1024) @@ -76,18 +74,18 @@ static void assert_on_axis(const std::vector& freq, auto val = std::abs(freq[ind]); if (axis) { if (irow==0) { - assert(std::abs(val - 1.0) < eps); + assert_small(std::abs(val - 1.0)); } else { - assert(val < eps); + assert_small(val); } } else { if (icol==0) { - assert(std::abs(val - 1.0) < eps); + assert_small(std::abs(val - 1.0)); } else { - assert(val < eps); + assert_small(val); } } } @@ -197,24 +195,10 @@ void test_2d_transpose(IDFT::pointer dft, int nrows, int ncols) } + int main(int argc, char* argv[]) { - // fixme, add CLI parsing to add plugins, config and name another - // dft. For now, just use the one in aux. 
- PluginManager& pm = PluginManager::instance(); - pm.add("WireCellAux"); - std::string dft_tn = "FftwDFT"; - - // creates - auto idft = Factory::lookup_tn(dft_tn); - assert(idft); - { // configure before use if configurable - auto icfg = Factory::find_maybe_tn(dft_tn); - if (icfg) { - auto cfg = icfg->default_configuration(); - icfg->configure(cfg); - } - } + auto idft = make_dft_args(argc, argv); test_1d_zero(idft); test_1d_impulse(idft); From c798cd4ba874353f64efb0a061e87f8ea7116e45 Mon Sep 17 00:00:00 2001 From: Brett Viren Date: Tue, 16 Nov 2021 15:03:45 -0500 Subject: [PATCH 14/46] Make a semaphore interface and implement with what is in util --- aux/inc/WireCellAux/Semaphore.h | 34 +++++++++++++++++++ aux/src/Semaphore.cxx | 51 ++++++++++++++++++++++++++++ iface/inc/WireCellIface/ISemaphore.h | 31 +++++++++++++++++ iface/src/IfaceDesctructors.cxx | 2 ++ 4 files changed, 118 insertions(+) create mode 100644 aux/inc/WireCellAux/Semaphore.h create mode 100644 aux/src/Semaphore.cxx create mode 100644 iface/inc/WireCellIface/ISemaphore.h diff --git a/aux/inc/WireCellAux/Semaphore.h b/aux/inc/WireCellAux/Semaphore.h new file mode 100644 index 000000000..1385bf51d --- /dev/null +++ b/aux/inc/WireCellAux/Semaphore.h @@ -0,0 +1,34 @@ +/** Implement a semaphore component interace. 
*/ + +#ifndef WIRECELLAUX_SEMAPHORE +#define WIRECELLAUX_SEMAPHORE + +#include "WireCellIface/IConfigurable.h" +#include "WireCellIface/ISemaphore.h" +#include "WireCellUtil/Semaphore.h" + + +namespace WireCell::Aux { + class Semaphore : public ISemaphore, + public IConfigurable + { + public: + Semaphore(); + virtual ~Semaphore(); + + // IConfigurable interface + virtual void configure(const WireCell::Configuration& config); + virtual WireCell::Configuration default_configuration() const; + + // ISemaphore + virtual void acquire() const; + virtual void release() const; + + private: + + mutable FastSemaphore m_sem; + + }; +} // namespace WireCell::Pytorch + +#endif // WIRECELLPYTORCH_TORCHSERVICE diff --git a/aux/src/Semaphore.cxx b/aux/src/Semaphore.cxx new file mode 100644 index 000000000..d76999837 --- /dev/null +++ b/aux/src/Semaphore.cxx @@ -0,0 +1,51 @@ +#include "WireCellAux/Semaphore.h" + +#include "WireCellUtil/NamedFactory.h" +#include "WireCellUtil/Semaphore.h" + +WIRECELL_FACTORY(Semaphore, + WireCell::Aux::Semaphore, + WireCell::ISemaphore, + WireCell::IConfigurable) + +using namespace WireCell; + +Aux::Semaphore::Semaphore() + : m_sem(0) +{ +} +Aux::Semaphore::~Semaphore() +{ +} + +WireCell::Configuration Aux::Semaphore::default_configuration() const +{ + Configuration cfg; + + // The maximum allowed number concurrent calls to forward(). A + // value of unity means all calls will be serialized. When made + // smaller than the number of threads, the difference gives the + // number of threads that may block on the semaphore. 
+ cfg["concurrency"] = 1; + + return cfg; +} + +void Aux::Semaphore::configure(const WireCell::Configuration& cfg) +{ + auto count = get(cfg, "concurrency", 1); + if (count < 1 ) { + count = 1; + } + m_sem.set_count(count); +} + +void Aux::Semaphore::acquire() const +{ + m_sem.acquire(); +} + +void Aux::Semaphore::release() const +{ + m_sem.release(); +} diff --git a/iface/inc/WireCellIface/ISemaphore.h b/iface/inc/WireCellIface/ISemaphore.h new file mode 100644 index 000000000..99a55396d --- /dev/null +++ b/iface/inc/WireCellIface/ISemaphore.h @@ -0,0 +1,31 @@ +/** An interface to the semaphore pattern */ + +#ifndef WIRECELL_ISEMAPHORE +#define WIRECELL_ISEMAPHORE + +#include "WireCellUtil/IComponent.h" + +namespace WireCell { + class ISemaphore : public IComponent { + public: + virtual ~ISemaphore(); + + /// Block until available spot to hold the semaphore is + /// available. + virtual void acquire() const = 0; + + /// Release hold on the semaphore + virtual void release() const = 0; + + /// Use Construct a Context on a semaphore in a local scope to + /// automate release + struct Context { + ISemaphore::pointer sem; + Context(ISemaphore::pointer sem) : sem(sem) { sem->acquire(); } + ~Context() { sem->release(); } + }; + + }; +} // namespace WireCell + +#endif // WIRECELL_ITENSORFORWARD diff --git a/iface/src/IfaceDesctructors.cxx b/iface/src/IfaceDesctructors.cxx index 76efb5ecd..59b1597d1 100644 --- a/iface/src/IfaceDesctructors.cxx +++ b/iface/src/IfaceDesctructors.cxx @@ -74,6 +74,7 @@ #include "WireCellIface/IRandom.h" #include "WireCellIface/IRecombinationModel.h" #include "WireCellIface/IScalarFieldSink.h" +#include "WireCellIface/ISemaphore.h" #include "WireCellIface/ISequence.h" #include "WireCellIface/ISinkNode.h" #include "WireCellIface/ISlice.h" @@ -172,6 +173,7 @@ IQueuedoutNodeBase::~IQueuedoutNodeBase() {} IRandom::~IRandom() {} IRecombinationModel::~IRecombinationModel() {} IScalarFieldSink::~IScalarFieldSink() {} +ISemaphore::~ISemaphore() {} 
ISinkNodeBase::~ISinkNodeBase() {} ISlice::~ISlice() {} ISliceFanout::~ISliceFanout() {} From fc8dc584d0355aba82f460d7ac98dea11947cfa2 Mon Sep 17 00:00:00 2001 From: Brett Viren Date: Tue, 16 Nov 2021 15:04:16 -0500 Subject: [PATCH 15/46] Move common code to a 'context' mix-in, add initial torch imp of IDFT --- pytorch/inc/WireCellPytorch/DFT.h | 62 +++++++++++ pytorch/inc/WireCellPytorch/TorchContext.h | 62 +++++++++++ pytorch/inc/WireCellPytorch/TorchService.h | 9 +- pytorch/src/DFT.cxx | 118 +++++++++++++++++++++ pytorch/src/TorchContext.cxx | 38 +++++++ pytorch/src/TorchService.cxx | 68 ++---------- 6 files changed, 294 insertions(+), 63 deletions(-) create mode 100644 pytorch/inc/WireCellPytorch/DFT.h create mode 100644 pytorch/inc/WireCellPytorch/TorchContext.h create mode 100644 pytorch/src/DFT.cxx create mode 100644 pytorch/src/TorchContext.cxx diff --git a/pytorch/inc/WireCellPytorch/DFT.h b/pytorch/inc/WireCellPytorch/DFT.h new file mode 100644 index 000000000..41e154152 --- /dev/null +++ b/pytorch/inc/WireCellPytorch/DFT.h @@ -0,0 +1,62 @@ +/** + TorchDFT provides a libtorch based implementation of IDFT. 
+ + The libtorch API is documented at: + + https://pytorch.org/cppdocs/api/namespace_torch__fft.html + */ + +#ifndef WIRECELL_PYTORCH_DFT +#define WIRECELL_PYTORCH_DFT + +#include "WireCellIface/IDFT.h" +#include "WireCellIface/IConfigurable.h" +#include "WireCellPytorch/TorchContext.h" + +namespace WireCell::Pytorch { + class DFT : public IDFT, + public IConfigurable + { + public: + DFT(); + virtual ~DFT(); + + // IConfigurable interface + virtual void configure(const WireCell::Configuration& config); + virtual WireCell::Configuration default_configuration() const; + + // 1d + + virtual + void fwd1d(const complex_t* in, complex_t* out, + int size) const; + + virtual + void inv1d(const complex_t* in, complex_t* out, + int size) const; + + // batched 1D ("1b") - rely on base implementation + virtual + void fwd1b(const complex_t* in, complex_t* out, + int nrows, int ncols, int axis) const; + virtual + void inv1b(const complex_t* in, complex_t* out, + int nrows, int ncols, int axis) const; + + // 2d + + virtual + void fwd2d(const complex_t* in, complex_t* out, + int nrows, int ncols) const; + virtual + void inv2d(const complex_t* in, complex_t* out, + int nrows, int ncols) const; + + private: + TorchContext m_ctx; + + }; + +} + +#endif diff --git a/pytorch/inc/WireCellPytorch/TorchContext.h b/pytorch/inc/WireCellPytorch/TorchContext.h new file mode 100644 index 000000000..66979a1c5 --- /dev/null +++ b/pytorch/inc/WireCellPytorch/TorchContext.h @@ -0,0 +1,62 @@ +/** A mixin class to provide a torch context + + */ + +#include "WireCellIface/ISemaphore.h" +#include + +namespace WireCell::Pytorch { + + class TorchContext { + public: + + // The "devname" is "cpu" or "gpu" or "gpuN" where N is a GPU + // number. If "semname" is given, use it for semaphore, + // otherwise use canonically tn=Semaphore:torch-. 
+ TorchContext(const std::string& devname, + const std::string& semname=""); + TorchContext(); + ~TorchContext(); + + // Default constructor makes context with no device nor + // semaphore. This will make the "connection" to them. + void connect(const std::string& devname, + const std::string& semname=""); + + torch::Device device() const { return m_dev; } + std::string devname() const { return m_devname; } + + bool is_gpu() const { return m_devname != "cpu"; } + + // Context manager methods. Caller should prefer using a + // TorchSemaphore class but if called directly, caller MUST + // balance an enter() with an exit(). These can and should be + // used in multi-thread run stage. + void enter() const { if (m_sem) m_sem->acquire(); } + void exit() const { if (m_sem) m_sem->release(); } + + private: + + torch::Device m_dev{torch::kCPU}; + std::string m_devname; + ISemaphore::pointer m_sem; + }; + + /// Use like: + /// + /// void mymeth() { + /// TorchSemaphore sem(m_ctx); + /// ... more code may return/throw + /// } // end of scope + class TorchSemaphore { + const TorchContext& m_th; + public: + TorchSemaphore(const TorchContext& th) : m_th(th) { + m_th.enter(); + } + ~TorchSemaphore() { + m_th.exit(); + } + }; + +} diff --git a/pytorch/inc/WireCellPytorch/TorchService.h b/pytorch/inc/WireCellPytorch/TorchService.h index 01504724e..ca8fe4eb1 100644 --- a/pytorch/inc/WireCellPytorch/TorchService.h +++ b/pytorch/inc/WireCellPytorch/TorchService.h @@ -7,7 +7,7 @@ #include "WireCellIface/ITensorForward.h" #include "WireCellUtil/Logging.h" #include "WireCellAux/Logger.h" -#include "WireCellUtil/Semaphore.h" +#include "WireCellPytorch/TorchContext.h" #include // One-stop header. @@ -29,15 +29,14 @@ namespace WireCell::Pytorch { private: - // Mark which device is used - torch::Device m_dev; - // for read-only access, claim is that .forward() is thread // safe. However .forward() is not const so we must make this // mutable. 
mutable torch::jit::script::Module m_module; - mutable FastSemaphore m_sem; + // Even though thread safe, we want to honor a per device + // semaphore to give user chance ot limit us. + TorchContext m_ctx; }; } // namespace WireCell::Pytorch diff --git a/pytorch/src/DFT.cxx b/pytorch/src/DFT.cxx new file mode 100644 index 000000000..cba68164e --- /dev/null +++ b/pytorch/src/DFT.cxx @@ -0,0 +1,118 @@ +#include "WireCellPytorch/DFT.h" +#include "WireCellUtil/NamedFactory.h" + +#include +#include + + +WIRECELL_FACTORY(FftwDFT, WireCell::Pytorch::DFT, + WireCell::IDFT, + WireCell::IConfigurable) + +using namespace WireCell; +using namespace WireCell::Pytorch; + +DFT::DFT() +{ +} + +DFT::~DFT() +{ +} + +Configuration DFT::default_configuration() const +{ + Configuration cfg; + + // one of: {cpu, gpu, gpuN} where "N" is a GPU number. "gpu" + // alone will use GPU 0. + cfg["device"] = "cpu"; + return cfg; +} + +void DFT::configure(const WireCell::Configuration& cfg) +{ + auto dev = get(cfg, "device", "cpu"); + m_ctx.connect(dev); +} + + +using torch_transform = std::function; + +static +void doit(const TorchContext& ctx, + const IDFT::complex_t* in, IDFT::complex_t* out, + int64_t nrows, int64_t ncols, // 1d vec should have nrows=1 + torch_transform func) +{ + TorchSemaphore sem(ctx); + torch::NoGradGuard no_grad; + + int64_t size = nrows*ncols; + + auto options = torch::TensorOptions().device(ctx.device()).dtype(torch::kComplexFloat); + + // 1) in->src + if (in != out) { // from_blob() doesn't like const data + memcpy(out, in, sizeof(IDFT::complex_t)*size); + } + + torch::Tensor src = torch::from_blob(out, {nrows, ncols}, options); + + // 2) dst = func(src) + auto dst = func(src); + dst = dst.cpu(); + + // 3) dst->out + if (out != dst.data_ptr()) { + memcpy(out, dst.data_ptr(), sizeof(IDFT::complex_t)*size); + } + +} + + +void DFT::fwd1d(const IDFT::complex_t* in, IDFT::complex_t* out, int size) const +{ + doit(m_ctx, in, out, 1, size, + [](const torch::Tensor& src) { 
return torch::fft::fft(src); }); +} + + +void DFT::inv1d(const IDFT::complex_t* in, IDFT::complex_t* out, int size) const +{ + doit(m_ctx, in, out, 1, size, // fixme: check norm + [](const torch::Tensor& src) { return torch::fft::ifft(src); }); +} + + +void DFT::fwd1b(const IDFT::complex_t* in, IDFT::complex_t* out, + int nrows, int ncols, int axis) const +{ + doit(m_ctx, in, out, nrows, ncols, [&](const torch::Tensor& src) { + return torch::fft::fft2(src, torch::nullopt, {axis}); }); +} + + +void DFT::inv1b(const IDFT::complex_t* in, IDFT::complex_t* out, + int nrows, int ncols, int axis) const +{ + doit(m_ctx, in, out, nrows, ncols, [&](const torch::Tensor& src) { + return torch::fft::ifft2(src, torch::nullopt, {axis}); }); +} + + +void DFT::fwd2d(const IDFT::complex_t* in, IDFT::complex_t* out, + int nrows, int ncols) const +{ + doit(m_ctx, in, out, nrows, ncols, + [](const torch::Tensor& src) { return torch::fft::fft2(src); }); +} + + +void DFT::inv2d(const IDFT::complex_t* in, IDFT::complex_t* out, + int nrows, int ncols) const +{ + doit(m_ctx, in, out, nrows, ncols, + [](const torch::Tensor& src) { return torch::fft::ifft2(src); }); +} + diff --git a/pytorch/src/TorchContext.cxx b/pytorch/src/TorchContext.cxx new file mode 100644 index 000000000..e5de936b8 --- /dev/null +++ b/pytorch/src/TorchContext.cxx @@ -0,0 +1,38 @@ +#include "WireCellPytorch/TorchContext.h" +#include "WireCellUtil/NamedFactory.h" + +using namespace WireCell; +using namespace WireCell::Pytorch; + +TorchContext::TorchContext() {} +TorchContext::~TorchContext() { } +TorchContext::TorchContext(const std::string& devname, + const std::string& semname) +{ + connect(devname, semname); +} +void TorchContext::connect(const std::string& devname, + const std::string& semname) +{ + // Use almost 1/2 the memory and 3/4 the time. 
+ torch::NoGradGuard no_grad; + + if (devname == "cpu") { + m_dev = torch::Device(torch::kCPU); + } + else { + int devnum = 0; + if (devname.size() > 3) { + devnum = atoi(devname.substr(3).c_str()); + } + m_dev = torch::Device(torch::kCUDA, devnum); + } + + std::string s_tn = "Semaphore:torch-" + devname; + if (not semname.empty()) { + s_tn = semname; + } + + m_sem = Factory::find_tn(s_tn); +} + diff --git a/pytorch/src/TorchService.cxx b/pytorch/src/TorchService.cxx index 76bbfe436..b78a3a8d0 100644 --- a/pytorch/src/TorchService.cxx +++ b/pytorch/src/TorchService.cxx @@ -14,8 +14,6 @@ using namespace WireCell; Pytorch::TorchService::TorchService() : Aux::Logger("TorchService", "torch") - , m_dev(torch::kCPU, 0) - , m_sem(0) { } @@ -26,30 +24,17 @@ Configuration Pytorch::TorchService::default_configuration() const // TorchScript model cfg["model"] = "model.ts"; - // one of: {cpu, gpu, gpucpu}. Latter allows fail-over to cpu - // when there is a failure to load the model. - // fixme: we may want to allow user to give a GPU index number - // here so like eg gpu:1, gpucpu:2. An index is not meaningful - // for cpu. - cfg["device"] = "gpucpu"; + // one of: {cpu, gpu, gpuN} where "N" is a GPU number. "gpu" + // alone will use GPU 0. + cfg["device"] = "cpu"; - // The maximum allowed number concurrent calls to forward(). A - // value of unity means all calls will be serialized. When made - // smaller than the number of threads, the difference gives the - // number of threads that may block on the semaphore. 
- cfg["concurrency"] = 1; - return cfg; } void Pytorch::TorchService::configure(const WireCell::Configuration& cfg) { - auto dev = get(cfg, "device", "gpucpu"); - auto count = get(cfg, "concurrency", 1); - if (count < 1 ) { - count = 1; - } - m_sem.set_count(count); + auto dev = get(cfg, "device", "cpu"); + m_ctx.connect(dev); auto model_path = cfg["model"].asString(); if (model_path.empty()) { @@ -58,34 +43,10 @@ void Pytorch::TorchService::configure(const WireCell::Configuration& cfg) } // Use almost 1/2 the memory and 3/4 the time. - // but, fixme: check with Haiwng that this is okay. torch::NoGradGuard no_grad; - // Maybe first try to load torch script model on GPU. - if (dev == "gpucpu") { - try { - m_dev = torch::Device(torch::kCUDA, 0); - m_module = torch::jit::load(model_path, m_dev); - log->debug("loaded model {} to {}", model_path, dev); - return; - } - catch (const c10::Error& e) { - log->warn("failed to load model: {} to GPU will try CPU: {}", - model_path, e.what()); - } - } - - if (dev == "cpu") { - m_dev = torch::Device(torch::kCPU); - } - else { - m_dev = torch::Device(torch::kCUDA, 0); - } - - // from now, we either succeed or we throw - try { - m_module = torch::jit::load(model_path, m_dev); + m_module = torch::jit::load(model_path, m_ctx.device()); } catch (const c10::Error& e) { log->critical("error loading model: {} to {}: {}", @@ -96,19 +57,15 @@ void Pytorch::TorchService::configure(const WireCell::Configuration& cfg) log->debug("loaded model {} to {}", model_path, dev); } -#include - ITensorSet::pointer Pytorch::TorchService::forward(const ITensorSet::pointer& in) const { + TorchSemaphore sem(m_ctx); - m_sem.acquire(); - - const bool is_gpu = ! (m_dev == torch::kCPU); - log->debug("running model on {}", is_gpu ? 
"GPU" : "CPU"); + log->debug("running model on {}", m_ctx.devname()); torch::NoGradGuard no_grad; - std::vector iival = Pytorch::from_itensor(in, is_gpu); + std::vector iival = Pytorch::from_itensor(in, m_ctx.is_gpu()); torch::IValue oival; try { @@ -116,16 +73,11 @@ ITensorSet::pointer Pytorch::TorchService::forward(const ITensorSet::pointer& in } catch (const std::runtime_error& err) { log->error("error running model on {}: {}", - is_gpu ? "GPU" : "CPU", err.what()); - m_sem.release(); + m_ctx.devname(), err.what()); return nullptr; } ITensorSet::pointer ret = Pytorch::to_itensor({oival}); - // maybe needs a mutex? - c10::cuda::CUDACachingAllocator::emptyCache(); - - m_sem.release(); return ret; } From 22d5ea0bd66fb447d24f21b55dc3937c994ce4c6 Mon Sep 17 00:00:00 2001 From: Brett Viren Date: Wed, 17 Nov 2021 12:30:19 -0500 Subject: [PATCH 16/46] Work out brain bugs in understanding torch tensor storage --- aux/src/Semaphore.cxx | 2 +- aux/test/aux_test_dft_helpers.h | 31 +++++++++++-- aux/test/test_idft.cxx | 34 ++++++++------ aux/test/test_idft_pytorch.jsonnet | 7 +++ iface/src/IDFT.cxx | 2 +- pytorch/src/DFT.cxx | 22 +++++---- pytorch/src/TorchContext.cxx | 2 +- pytorch/test/test_from_blob.cxx | 73 ++++++++++++++++++++++++++++++ 8 files changed, 146 insertions(+), 27 deletions(-) create mode 100644 aux/test/test_idft_pytorch.jsonnet create mode 100644 pytorch/test/test_from_blob.cxx diff --git a/aux/src/Semaphore.cxx b/aux/src/Semaphore.cxx index d76999837..841debbb7 100644 --- a/aux/src/Semaphore.cxx +++ b/aux/src/Semaphore.cxx @@ -11,7 +11,7 @@ WIRECELL_FACTORY(Semaphore, using namespace WireCell; Aux::Semaphore::Semaphore() - : m_sem(0) + : m_sem(1) { } Aux::Semaphore::~Semaphore() diff --git a/aux/test/aux_test_dft_helpers.h b/aux/test/aux_test_dft_helpers.h index c4b12e012..d83754976 100644 --- a/aux/test/aux_test_dft_helpers.h +++ b/aux/test/aux_test_dft_helpers.h @@ -4,6 +4,7 @@ #include "WireCellUtil/NamedFactory.h" #include 
"WireCellUtil/PluginManager.h" #include "WireCellUtil/Exceptions.h" +#include "WireCellUtil/Persist.h" #include "WireCellIface/IConfigurable.h" #include "WireCellIface/IDFT.h" @@ -16,8 +17,11 @@ namespace WireCell::Aux::Test { // fixme: add support for config IDFT::pointer make_dft(const std::string& tn="FftwDFT", - const std::string& pi="WireCellAux") + const std::string& pi="WireCellAux", + Configuration cfg = Configuration()) { + std::cerr << "Making DFT " << tn << " from plugin " << pi << std::endl; + PluginManager& pm = PluginManager::instance(); pm.add(pi); @@ -27,8 +31,9 @@ namespace WireCell::Aux::Test { // configure before use if configurable auto icfg = Factory::find_maybe_tn(tn); if (icfg) { - auto cfg = icfg->default_configuration(); - icfg->configure(cfg); + auto def_cfg = icfg->default_configuration(); + def_cfg = update(def_cfg, cfg); + icfg->configure(def_cfg); } return idft; } @@ -38,6 +43,26 @@ namespace WireCell::Aux::Test { std::string dft_pi="WireCellAux"; if (argc > 1) dft_tn = argv[1]; if (argc > 2) dft_pi = argv[2]; + Configuration cfg; + if (argc > 3) { + // Either we get directly a "data" object + cfg = Persist::load(argv[3]); + // or we go searching a list for matching type/name. 
+ if (cfg.isArray()) { + for (auto one : cfg) { + std::string tn = get(one, "type"); + std::string n = get(one, "name", ""); + if (not n.empty()) { + tn = tn + ":" + n; + } + if (tn == dft_tn) { + cfg = one["data"]; + break; + } + } + } + + } return make_dft(dft_tn, dft_pi); } diff --git a/aux/test/test_idft.cxx b/aux/test/test_idft.cxx index 1b47646b8..a4f337a5e 100644 --- a/aux/test/test_idft.cxx +++ b/aux/test/test_idft.cxx @@ -12,6 +12,17 @@ using namespace WireCell; using namespace WireCell::Aux::Test; +template +void dump(ValueType* data, int nrows, int ncols, std::string msg="") +{ + std::cerr << msg << "("<& freq, void test_1b_impulse(IDFT::pointer dft, int axis, int nrows=128, int ncols=128) { const int size = nrows*ncols; + std::cerr << "1b impulse freq axis="< inter(size,0), freq(size,0), back(size,0); inter[0] = 1.0; + dft->fwd1b(inter.data(), freq.data(), nrows, ncols, axis); + dump(freq.data(), nrows, ncols, "freq"); assert_on_axis(freq, axis, nrows, ncols); + dft->inv1b(freq.data(), back.data(), nrows, ncols, axis); + dump(back.data(), nrows, ncols, "back"); assert_impulse_at_index(back, 0); + std::vector inplace(size,0); inplace[0] = 1.0; dft->fwd1b(inplace.data(), inplace.data(), nrows, ncols, axis); @@ -159,17 +176,6 @@ void test_2d_threads(IDFT::pointer dft, int nthreads, int nloops, int size = 102 << " " << dt1.count() << std::endl; } -template -void dump(ValueType* data, int nrows, int ncols, std::string msg="") -{ - std::cerr << msg << "("< void test_2d_transpose(IDFT::pointer dft, int nrows, int ncols) @@ -205,8 +211,10 @@ int main(int argc, char* argv[]) test_2d_zero(idft); test_2d_impulse(idft); - test_1b_impulse(idft, 0); - test_1b_impulse(idft, 1); + test_1b_impulse(idft, 0, 2, 8); + test_1b_impulse(idft, 1, 2, 8); + test_1b_impulse(idft, 0, 8, 2); + test_1b_impulse(idft, 1, 8, 2); test_2d_transpose(idft, 2, 8); test_2d_transpose(idft, 8, 2); diff --git a/aux/test/test_idft_pytorch.jsonnet b/aux/test/test_idft_pytorch.jsonnet new file 
mode 100644 index 000000000..10e1f788a --- /dev/null +++ b/aux/test/test_idft_pytorch.jsonnet @@ -0,0 +1,7 @@ +// a configuration "data" portion for TorchDFT. +// call like: +// ❯ ./build/aux/test_idft TorchDFT WireCellPytorch aux/test/test_idft_pytorch.jsonnet +{ + device: "gpu", +} + diff --git a/iface/src/IDFT.cxx b/iface/src/IDFT.cxx index 58b1d335b..abc3d6120 100644 --- a/iface/src/IDFT.cxx +++ b/iface/src/IDFT.cxx @@ -66,7 +66,7 @@ void transpose_type(const ValueType* in, ValueType* out, const int size = nrows*ncols; const int mn1 = (size - 1); std::vector visited(size); - ValueType* first = out + size; + ValueType* first = out; const ValueType* last = first + size; ValueType* cycle = out; while (++cycle != last) { diff --git a/pytorch/src/DFT.cxx b/pytorch/src/DFT.cxx index cba68164e..052013ded 100644 --- a/pytorch/src/DFT.cxx +++ b/pytorch/src/DFT.cxx @@ -5,7 +5,7 @@ #include -WIRECELL_FACTORY(FftwDFT, WireCell::Pytorch::DFT, +WIRECELL_FACTORY(TorchDFT, WireCell::Pytorch::DFT, WireCell::IDFT, WireCell::IConfigurable) @@ -50,24 +50,30 @@ void doit(const TorchContext& ctx, int64_t size = nrows*ncols; - auto options = torch::TensorOptions().device(ctx.device()).dtype(torch::kComplexFloat); + auto dtype = torch::TensorOptions().dtype(torch::kComplexFloat); // 1) in->src if (in != out) { // from_blob() doesn't like const data memcpy(out, in, sizeof(IDFT::complex_t)*size); } - torch::Tensor src = torch::from_blob(out, {nrows, ncols}, options); + torch::Tensor src = torch::from_blob(out, {nrows, ncols}, dtype); // 2) dst = func(src) + src = src.to(ctx.device()); auto dst = func(src); - dst = dst.cpu(); + + // Making contiguous costs a copy but gets the data in row-major + // so the (2nd) copy next actually gives correct results. This + // corrects optimizations that libtorch makes for transpose (and + // others) eg when our func is 1b. Alternatively, may avoid both + // copies by iterating over indices but presumably (?) that is + // slower. 
Likewise we make contiguous on the device as that is + // presumably (?) faster when the device is GPU. + dst = dst.contiguous().cpu(); // 3) dst->out - if (out != dst.data_ptr()) { - memcpy(out, dst.data_ptr(), sizeof(IDFT::complex_t)*size); - } - + memcpy(out, dst.data_ptr(), sizeof(IDFT::complex_t)*size); } diff --git a/pytorch/src/TorchContext.cxx b/pytorch/src/TorchContext.cxx index e5de936b8..ed219f93b 100644 --- a/pytorch/src/TorchContext.cxx +++ b/pytorch/src/TorchContext.cxx @@ -33,6 +33,6 @@ void TorchContext::connect(const std::string& devname, s_tn = semname; } - m_sem = Factory::find_tn(s_tn); + m_sem = Factory::lookup_tn(s_tn); } diff --git a/pytorch/test/test_from_blob.cxx b/pytorch/test/test_from_blob.cxx new file mode 100644 index 000000000..b6e45e23e --- /dev/null +++ b/pytorch/test/test_from_blob.cxx @@ -0,0 +1,73 @@ +#include +#include + +#include +#include +#include + +using complex_t = std::complex; + +void dump(const std::vector& v, int nrows, int ncols, const std::string& msg="") +{ + std::cerr << msg << ": ("< v(size, 0); + //std::iota(v.begin(), v.end(), 0); + v[ncols+2] = 1.0; + dump(v, nrows, ncols, "v"); + + // Note: gpu is almost 10x SLOWER than CPU due to kernel load time! + // auto device = at::Device(at::kCPU); + auto device = at::Device(at::kCUDA); + + auto typ_options = at::TensorOptions().dtype(at::kComplexFloat); + // auto dev_options = typ_options.device(device); + + for (int axis = 0; axis < 2; ++ axis) { + at::Tensor src = at::from_blob(v.data(), {nrows, ncols}, typ_options); + dump(src, "src"); + + src = src.to(device); + // In pytorch dim=(0,) transforms along columns, ie follows + // numpy.fft convention. Both directions work on both CPU and + // CUDA. + at::Tensor dst = torch::fft::fft2(src, {}, {axis,}); + + // BUT, BEWARE that the underlying storage will NOT reflect + // logical row-major ordering. Indexing is as expected but + // memory returned by data_ptr() will reflect transpose + // optimizations. 
At the expense of a copy the contiguous() + // method provides expected row-major storage order. + dst = dst.contiguous().cpu(); + + dump(dst, "dst"); + + std::vector v2(size, 0); + memcpy(v2.data(), dst.data_ptr(), sizeof(complex_t)*size); + + std::cerr << "axis=" << axis << " dim=" << axis + << " shape=(" << src.size(0) << "," << src.size(1) << ")\n"; + dump(v2, nrows, ncols, "dft(v)"); + } + return 0; +} From 47e67dd12f086ea02a3066deb48320791dd77cc3 Mon Sep 17 00:00:00 2001 From: Brett Viren Date: Thu, 18 Nov 2021 08:40:10 -0500 Subject: [PATCH 17/46] Typo in plugin name fixed --- cfg/pgrapher/common/helpers/utils.jsonnet | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cfg/pgrapher/common/helpers/utils.jsonnet b/cfg/pgrapher/common/helpers/utils.jsonnet index 19e31618b..08ed7781f 100644 --- a/cfg/pgrapher/common/helpers/utils.jsonnet +++ b/cfg/pgrapher/common/helpers/utils.jsonnet @@ -18,7 +18,7 @@ local pg = import "pgraph.jsonnet"; local app_plugins = { 'TbbFlow': ["WireCellTbb"], - 'PGrapher': ["WireCellPgraph"], + 'PGrapher': ["WireCellPgrapher"], }, main(graph, app='Pgrapher', extra_plugins = []) :: From d90bdcc8dfe32b073765cb882499be685f7b301e Mon Sep 17 00:00:00 2001 From: Brett Viren Date: Thu, 18 Nov 2021 08:40:23 -0500 Subject: [PATCH 18/46] Initial draft of an IDFT benchmarker --- aux/test/test_idft_bench.cxx | 100 +++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 aux/test/test_idft_bench.cxx diff --git a/aux/test/test_idft_bench.cxx b/aux/test/test_idft_bench.cxx new file mode 100644 index 000000000..01b10d91d --- /dev/null +++ b/aux/test/test_idft_bench.cxx @@ -0,0 +1,100 @@ +/** + A simple benchmark of IDFT for payloads relevant to WCT + */ + +#include "WireCellUtil/TimeKeeper.h" + +#include "aux_test_dft_helpers.h" + +#include +#include +#include + +using namespace WireCell; +using namespace WireCell::Aux::Test; + +using benchmark_function = std::function; +using complex_t = std::complex; + 
+void timeit(TimeKeeper& tk, const std::string& msg, benchmark_function func) +{ + tk("\tINIT\t" + msg); + func(); + tk("\tFINI\t" + msg); +} + +// benchmarks span outer product of: +// - in-place / out-place +// - 1d, 1b, 2d +// - sizes: perfect powers of 2 and with larger prime factors +// - use repitition numbers to keep each test roughly same runtime + +using transform_function = std::function; + +const int onedfull = 100'000'000; +void doit(TimeKeeper& tk, const std::string& name, int nrows, int ncols, bool inplace, transform_function func) +{ + const int size = nrows*ncols; + const int ntimes = onedfull / size; + std::stringstream ss; + ss << "\t(" << nrows << "," << ncols << ")\t" << ntimes << "\t"; + std::string s = ss.str(); + + if (inplace) { + timeit(tk, s + "in-place\t" + name, [&]() { + std::vector in(size); + for (int count=0; count in(size), out(size); + for (int count=0; countfwd1d(in, out, size); + }); + + doit(tk, "inv1d", 1, size, false, [&](const complex_t* in, complex_t* out) { + idft->inv1d(in, out, size); + }); + + int nrows = 1000; + int ncols = 1000; + doit(tk, "fwd2d", nrows, ncols, false, [&](const complex_t* in, complex_t* out) { + idft->fwd2d(in, out, nrows, ncols); + }); + doit(tk, "inv2d", nrows, ncols, false, [&](const complex_t* in, complex_t* out) { + idft->inv2d(in, out, nrows, ncols); + }); + + doit(tk, "fwd1b0", nrows, ncols, false, [&](const complex_t* in, complex_t* out) { + idft->fwd1b(in, out, nrows, ncols, 0); + }); + doit(tk, "inv1b0", nrows, ncols, false, [&](const complex_t* in, complex_t* out) { + idft->inv1b(in, out, nrows, ncols, 0); + }); + doit(tk, "fwd1b1", nrows, ncols, false, [&](const complex_t* in, complex_t* out) { + idft->fwd1b(in, out, nrows, ncols, 1); + }); + doit(tk, "inv1b1", nrows, ncols, false, [&](const complex_t* in, complex_t* out) { + idft->inv1b(in, out, nrows, ncols, 1); + }); + + + std::cerr << tk.summary() << std::endl; + + return 0; +} From a46fd9a3be5f39508660583a18c26404bc2819ce Mon Sep 17 
00:00:00 2001 From: Brett Viren Date: Thu, 18 Nov 2021 13:08:40 -0500 Subject: [PATCH 19/46] Improve timing measurements --- aux/test/aux_test_dft_helpers.h | 86 +++++++++++++++-- aux/test/test_idft_bench.cxx | 158 ++++++++++++++++++-------------- 2 files changed, 165 insertions(+), 79 deletions(-) diff --git a/aux/test/aux_test_dft_helpers.h b/aux/test/aux_test_dft_helpers.h index d83754976..db83d0fa8 100644 --- a/aux/test/aux_test_dft_helpers.h +++ b/aux/test/aux_test_dft_helpers.h @@ -9,11 +9,70 @@ #include "WireCellIface/IConfigurable.h" #include "WireCellIface/IDFT.h" +#include // std::clock +#include + +#include #include +#include #include +// note: likely will move in the future. +#include "custard/nlohmann/json.hpp" + namespace WireCell::Aux::Test { + using object_t = nlohmann::json; + + // probably move this to util + struct Stopwatch { + using clock = std::chrono::high_resolution_clock; + using time_point = clock::time_point; + using function_type = std::function; + + std::clock_t c_ini = std::clock(); + time_point t_ini = clock::now(); + + object_t results; + + Stopwatch(const object_t& first = object_t{}) { + (*this)([](){}, first); + } + + // Run the func, add timing info to a "stopwatch" attribute of + // data and save data to results. A pair of clock objects are + // saved, "clock" (std::clock) and "time" (std::chrono). Each + // have "start" and "elapsed" which are the number of + // nanoseconds from creation of stopwatch and for just this + // job, respectively. 
+ void operator()(function_type func, object_t data = object_t{}) + { + auto c_now =std::clock(); + auto t_now = clock::now(); + func(); + auto c_fin =std::clock(); + auto t_fin = clock::now(); + + double dc_now = 1e9 * (c_now - c_ini) / ((double) CLOCKS_PER_SEC); + double dc_fin = 1e9 * (c_fin - c_now) / ((double) CLOCKS_PER_SEC); + double dt_now = std::chrono::duration_cast(t_now - t_ini).count(); + double dt_fin = std::chrono::duration_cast(t_fin - t_now).count(); + + data["stopwatch"]["clock"]["start"] = dc_now; + data["stopwatch"]["clock"]["elapsed"] = dc_fin; + data["stopwatch"]["time"]["start"] = dt_now; + data["stopwatch"]["time"]["elapsed"] = dt_fin; + + results.push_back(data); + } + + void save(const std::string& jsonfile) { + std::ofstream fp(jsonfile.c_str()); + fp << results.dump(4) << std::endl; + } + + + }; // fixme: add support for config IDFT::pointer make_dft(const std::string& tn="FftwDFT", @@ -37,16 +96,23 @@ namespace WireCell::Aux::Test { } return idft; } - IDFT::pointer make_dft_args(int argc, char* argv[]) - { - std::string dft_tn="FftwDFT"; - std::string dft_pi="WireCellAux"; - if (argc > 1) dft_tn = argv[1]; - if (argc > 2) dft_pi = argv[2]; + struct DftArgs { + std::string tn{"FftwDFT"}; + std::string pi{"WireCellAux"}; + std::string cfg_name{""}; Configuration cfg; + }; + + DftArgs make_dft_args(int argc, char* argv[]) + { + DftArgs ret; + + if (argc > 1) ret.tn = argv[1]; + if (argc > 2) ret.pi = argv[2]; if (argc > 3) { // Either we get directly a "data" object - cfg = Persist::load(argv[3]); + ret.cfg_name = argv[3]; + auto cfg = Persist::load(argv[3]); // or we go searching a list for matching type/name. 
if (cfg.isArray()) { for (auto one : cfg) { @@ -55,15 +121,17 @@ namespace WireCell::Aux::Test { if (not n.empty()) { tn = tn + ":" + n; } - if (tn == dft_tn) { + if (tn == ret.tn) { cfg = one["data"]; break; } } } + ret.cfg = cfg; } - return make_dft(dft_tn, dft_pi); + return ret; + //return make_dft(dft_tn, dft_pi, cfg); } const double default_eps = 1e-8; diff --git a/aux/test/test_idft_bench.cxx b/aux/test/test_idft_bench.cxx index 01b10d91d..3d223a787 100644 --- a/aux/test/test_idft_bench.cxx +++ b/aux/test/test_idft_bench.cxx @@ -2,26 +2,14 @@ A simple benchmark of IDFT for payloads relevant to WCT */ -#include "WireCellUtil/TimeKeeper.h" - #include "aux_test_dft_helpers.h" -#include -#include -#include - using namespace WireCell; using namespace WireCell::Aux::Test; using benchmark_function = std::function; using complex_t = std::complex; -void timeit(TimeKeeper& tk, const std::string& msg, benchmark_function func) -{ - tk("\tINIT\t" + msg); - func(); - tk("\tFINI\t" + msg); -} // benchmarks span outer product of: // - in-place / out-place @@ -31,70 +19,100 @@ void timeit(TimeKeeper& tk, const std::string& msg, benchmark_function func) using transform_function = std::function; -const int onedfull = 100'000'000; -void doit(TimeKeeper& tk, const std::string& name, int nrows, int ncols, bool inplace, transform_function func) +const int nominal = 100'000'000; +void doit(Stopwatch& sw, const std::string& name, int nrows, int ncols, transform_function func) { const int size = nrows*ncols; - const int ntimes = onedfull / size; - std::stringstream ss; - ss << "\t(" << nrows << "," << ncols << ")\t" << ntimes << "\t"; - std::string s = ss.str(); - - if (inplace) { - timeit(tk, s + "in-place\t" + name, [&]() { - std::vector in(size); - for (int count=0; count in(size), out(size); - for (int count=0; count in(size), out(size); + + sw([&](){func(in.data(), in.data());}, { + {"nrows",nrows}, {"ncols",ncols}, {"func",name}, {"ntimes",1}, {"first",true}, {"in-place",true}, 
+ }); + + sw([&](){ + for (int count=0; countfwd1d(in, out, size); - }); - - doit(tk, "inv1d", 1, size, false, [&](const complex_t* in, complex_t* out) { - idft->inv1d(in, out, size); - }); - - int nrows = 1000; - int ncols = 1000; - doit(tk, "fwd2d", nrows, ncols, false, [&](const complex_t* in, complex_t* out) { - idft->fwd2d(in, out, nrows, ncols); - }); - doit(tk, "inv2d", nrows, ncols, false, [&](const complex_t* in, complex_t* out) { - idft->inv2d(in, out, nrows, ncols); - }); - - doit(tk, "fwd1b0", nrows, ncols, false, [&](const complex_t* in, complex_t* out) { - idft->fwd1b(in, out, nrows, ncols, 0); - }); - doit(tk, "inv1b0", nrows, ncols, false, [&](const complex_t* in, complex_t* out) { - idft->inv1b(in, out, nrows, ncols, 0); - }); - doit(tk, "fwd1b1", nrows, ncols, false, [&](const complex_t* in, complex_t* out) { - idft->fwd1b(in, out, nrows, ncols, 1); - }); - doit(tk, "inv1b1", nrows, ncols, false, [&](const complex_t* in, complex_t* out) { - idft->inv1b(in, out, nrows, ncols, 1); - }); - - - std::cerr << tk.summary() << std::endl; + auto args = make_dft_args(argc, argv); + auto idft = make_dft(args.tn, args.pi, args.cfg); + + Stopwatch sw({ + {"typename",args.tn}, + {"plugin",args.pi}, + {"config", object_t::parse(Persist::dumps(args.cfg))}, + {"config_file",args.cfg_name}}); + + std::string cname = args.cfg_name; + auto slash = cname.rfind("/"); + if (slash != std::string::npos) { + cname = cname.substr(slash+1); + } + cname = cname.substr(0, cname.rfind(".")); + std::string fname = argv[0]; + fname += "_" + args.pi + "_" + args.tn + "_" + cname + ".json"; + std::cerr << "writing to: " << fname << std::endl; + + + std::vector oned_sizes{500, 512, 1000, 1024, 4096, 6000, 8192, 10000, 16384}; + for (auto size : oned_sizes) { + std::cerr << "1d " << size << std::endl; + doit(sw, "fwd1d", 1, size, [&](const complex_t* in, complex_t* out) { + idft->fwd1d(in, out, size); + }); + + doit(sw, "inv1d", 1, size, [&](const complex_t* in, complex_t* out) { + 
idft->inv1d(in, out, size); + }); + } + + // channel count from some detectors plus powers of 2 + std::vector twod_nrows{800, 960, 1024, 2048, 2400, 3456, 4096}; + // tick count from some detectors plus powers of 2 + std::vector twod_ncols{2000, 4096, 6000, 8192, 9375, 9595, 9600, 10000, 16384}; + for (int nrows : twod_nrows) { + for (int ncols : twod_ncols) { + std::cerr << "2d (" << nrows << "," << ncols << ")\n"; + doit(sw, "fwd2d", nrows, ncols, [&](const complex_t* in, complex_t* out) { + idft->fwd2d(in, out, nrows, ncols); + }); + doit(sw, "inv2d", nrows, ncols, [&](const complex_t* in, complex_t* out) { + idft->inv2d(in, out, nrows, ncols); + }); + + doit(sw, "fwd1b0", nrows, ncols, [&](const complex_t* in, complex_t* out) { + idft->fwd1b(in, out, nrows, ncols, 0); + }); + doit(sw, "inv1b0", nrows, ncols, [&](const complex_t* in, complex_t* out) { + idft->inv1b(in, out, nrows, ncols, 0); + }); + doit(sw, "fwd1b1", nrows, ncols, [&](const complex_t* in, complex_t* out) { + idft->fwd1b(in, out, nrows, ncols, 1); + }); + doit(sw, "inv1b1", nrows, ncols, [&](const complex_t* in, complex_t* out) { + idft->inv1b(in, out, nrows, ncols, 1); + }); + } + } + + sw.save(fname); return 0; } From 3c65ffb21a0b298ea56663a423ae2200e9d3c91e Mon Sep 17 00:00:00 2001 From: Brett Viren Date: Thu, 18 Nov 2021 14:21:47 -0500 Subject: [PATCH 20/46] Rename, too slow to run each time --- ...st_idft_bench.cxx => check_idft_bench.cxx} | 55 ++++++++++--------- 1 file changed, 28 insertions(+), 27 deletions(-) rename aux/test/{test_idft_bench.cxx => check_idft_bench.cxx} (67%) diff --git a/aux/test/test_idft_bench.cxx b/aux/test/check_idft_bench.cxx similarity index 67% rename from aux/test/test_idft_bench.cxx rename to aux/test/check_idft_bench.cxx index 3d223a787..b7bd29100 100644 --- a/aux/test/test_idft_bench.cxx +++ b/aux/test/check_idft_bench.cxx @@ -24,6 +24,8 @@ void doit(Stopwatch& sw, const std::string& name, int nrows, int ncols, transfor { const int size = nrows*ncols; const 
int ntimes = std::max(1, nominal / size); + std::cerr << name << ": (" << nrows << "," << ncols << ") x "< in(size), out(size); sw([&](){func(in.data(), in.data());}, { @@ -73,7 +75,6 @@ int main(int argc, char* argv[]) std::vector oned_sizes{500, 512, 1000, 1024, 4096, 6000, 8192, 10000, 16384}; for (auto size : oned_sizes) { - std::cerr << "1d " << size << std::endl; doit(sw, "fwd1d", 1, size, [&](const complex_t* in, complex_t* out) { idft->fwd1d(in, out, size); }); @@ -84,32 +85,32 @@ int main(int argc, char* argv[]) } // channel count from some detectors plus powers of 2 - std::vector twod_nrows{800, 960, 1024, 2048, 2400, 3456, 4096}; - // tick count from some detectors plus powers of 2 - std::vector twod_ncols{2000, 4096, 6000, 8192, 9375, 9595, 9600, 10000, 16384}; - for (int nrows : twod_nrows) { - for (int ncols : twod_ncols) { - std::cerr << "2d (" << nrows << "," << ncols << ")\n"; - doit(sw, "fwd2d", nrows, ncols, [&](const complex_t* in, complex_t* out) { - idft->fwd2d(in, out, nrows, ncols); - }); - doit(sw, "inv2d", nrows, ncols, [&](const complex_t* in, complex_t* out) { - idft->inv2d(in, out, nrows, ncols); - }); - - doit(sw, "fwd1b0", nrows, ncols, [&](const complex_t* in, complex_t* out) { - idft->fwd1b(in, out, nrows, ncols, 0); - }); - doit(sw, "inv1b0", nrows, ncols, [&](const complex_t* in, complex_t* out) { - idft->inv1b(in, out, nrows, ncols, 0); - }); - doit(sw, "fwd1b1", nrows, ncols, [&](const complex_t* in, complex_t* out) { - idft->fwd1b(in, out, nrows, ncols, 1); - }); - doit(sw, "inv1b1", nrows, ncols, [&](const complex_t* in, complex_t* out) { - idft->inv1b(in, out, nrows, ncols, 1); - }); - } + std::vector> twod_sizes{ + {800,6000}, {960,6000}, // protodune u/v and w 3ms + {2400, 9595}, {3456, 9595}, // uboone u/v daq size + {1024, 1024}, {2048, 2048}, {4096, 4096}, // perfect powers of 2 + }; + for (auto& [nrows,ncols] : twod_sizes) { + + doit(sw, "fwd2d", nrows, ncols, [&](const complex_t* in, complex_t* out) { + idft->fwd2d(in, 
out, nrows, ncols); + }); + doit(sw, "inv2d", nrows, ncols, [&](const complex_t* in, complex_t* out) { + idft->inv2d(in, out, nrows, ncols); + }); + + doit(sw, "fwd1b0", nrows, ncols, [&](const complex_t* in, complex_t* out) { + idft->fwd1b(in, out, nrows, ncols, 0); + }); + doit(sw, "inv1b0", nrows, ncols, [&](const complex_t* in, complex_t* out) { + idft->inv1b(in, out, nrows, ncols, 0); + }); + doit(sw, "fwd1b1", nrows, ncols, [&](const complex_t* in, complex_t* out) { + idft->fwd1b(in, out, nrows, ncols, 1); + }); + doit(sw, "inv1b1", nrows, ncols, [&](const complex_t* in, complex_t* out) { + idft->inv1b(in, out, nrows, ncols, 1); + }); } sw.save(fname); From 365c21c7fb7f93740f9ab96de1fb6dab1a4ea3f8 Mon Sep 17 00:00:00 2001 From: Brett Viren Date: Fri, 19 Nov 2021 13:44:12 -0500 Subject: [PATCH 21/46] Make this more globally accessible so test/check programs can use it --- util/inc/WireCellUtil/CLI11.hpp | 9066 +++++++++++++++++++++++++++++++ 1 file changed, 9066 insertions(+) create mode 100644 util/inc/WireCellUtil/CLI11.hpp diff --git a/util/inc/WireCellUtil/CLI11.hpp b/util/inc/WireCellUtil/CLI11.hpp new file mode 100644 index 000000000..dcb57c6c6 --- /dev/null +++ b/util/inc/WireCellUtil/CLI11.hpp @@ -0,0 +1,9066 @@ +// CLI11: Version 2.1.2 +// Originally designed by Henry Schreiner +// https://github.com/CLIUtils/CLI11 +// +// This is a standalone header file generated by MakeSingleHeader.py in CLI11/scripts +// from: v2.1.2 +// +// CLI11 2.1.2 Copyright (c) 2017-2021 University of Cincinnati, developed by Henry +// Schreiner under NSF AWARD 1414736. All rights reserved. +// +// Redistribution and use in source and binary forms of CLI11, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. 
Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// 3. Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +// ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +#pragma once + +// Standard combined includes: +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#define CLI11_VERSION_MAJOR 2 +#define CLI11_VERSION_MINOR 1 +#define CLI11_VERSION_PATCH 2 +#define CLI11_VERSION "2.1.2" + + + + +// The following version macro is very similar to the one in pybind11 +#if !(defined(_MSC_VER) && __cplusplus == 199711L) && !defined(__INTEL_COMPILER) +#if __cplusplus >= 201402L +#define CLI11_CPP14 +#if __cplusplus >= 201703L +#define CLI11_CPP17 +#if __cplusplus > 201703L +#define CLI11_CPP20 +#endif +#endif +#endif +#elif defined(_MSC_VER) && __cplusplus == 199711L +// MSVC sets _MSVC_LANG rather than __cplusplus (supposedly until the standard is fully implemented) +// Unless you use the /Zc:__cplusplus flag on Visual Studio 2017 15.7 Preview 3 or newer +#if _MSVC_LANG >= 201402L +#define CLI11_CPP14 +#if _MSVC_LANG > 201402L && _MSC_VER >= 1910 +#define CLI11_CPP17 +#if __MSVC_LANG > 201703L && _MSC_VER >= 1910 +#define CLI11_CPP20 +#endif +#endif +#endif +#endif + +#if defined(CLI11_CPP14) +#define CLI11_DEPRECATED(reason) [[deprecated(reason)]] +#elif defined(_MSC_VER) +#define CLI11_DEPRECATED(reason) __declspec(deprecated(reason)) +#else +#define CLI11_DEPRECATED(reason) __attribute__((deprecated(reason))) +#endif + + + + +// C standard library +// Only needed for existence checking +#if defined CLI11_CPP17 && defined __has_include && !defined CLI11_HAS_FILESYSTEM +#if __has_include() +// Filesystem cannot be used if targeting macOS < 10.15 +#if defined __MAC_OS_X_VERSION_MIN_REQUIRED && __MAC_OS_X_VERSION_MIN_REQUIRED < 101500 +#define CLI11_HAS_FILESYSTEM 0 +#else +#include +#if defined __cpp_lib_filesystem && __cpp_lib_filesystem >= 201703 +#if defined _GLIBCXX_RELEASE && _GLIBCXX_RELEASE >= 9 +#define CLI11_HAS_FILESYSTEM 1 +#elif defined(__GLIBCXX__) 
+// if we are using gcc and Version <9 default to no filesystem +#define CLI11_HAS_FILESYSTEM 0 +#else +#define CLI11_HAS_FILESYSTEM 1 +#endif +#else +#define CLI11_HAS_FILESYSTEM 0 +#endif +#endif +#endif +#endif + +#if defined CLI11_HAS_FILESYSTEM && CLI11_HAS_FILESYSTEM > 0 +#include // NOLINT(build/include) +#else +#include +#include +#endif + + + +namespace CLI { + + +/// Include the items in this namespace to get free conversion of enums to/from streams. +/// (This is available inside CLI as well, so CLI11 will use this without a using statement). +namespace enums { + +/// output streaming for enumerations +template ::value>::type> +std::ostream &operator<<(std::ostream &in, const T &item) { + // make sure this is out of the detail namespace otherwise it won't be found when needed + return in << static_cast::type>(item); +} + +} // namespace enums + +/// Export to CLI namespace +using enums::operator<<; + +namespace detail { +/// a constant defining an expected max vector size defined to be a big number that could be multiplied by 4 and not +/// produce overflow for some expected uses +constexpr int expected_max_vector_size{1 << 29}; +// Based on http://stackoverflow.com/questions/236129/split-a-string-in-c +/// Split a string by a delim +inline std::vector split(const std::string &s, char delim) { + std::vector elems; + // Check to see if empty string, give consistent result + if(s.empty()) { + elems.emplace_back(); + } else { + std::stringstream ss; + ss.str(s); + std::string item; + while(std::getline(ss, item, delim)) { + elems.push_back(item); + } + } + return elems; +} + +/// Simple function to join a string +template std::string join(const T &v, std::string delim = ",") { + std::ostringstream s; + auto beg = std::begin(v); + auto end = std::end(v); + if(beg != end) + s << *beg++; + while(beg != end) { + s << delim << *beg++; + } + return s.str(); +} + +/// Simple function to join a string from processed elements +template ::value>::type> +std::string 
join(const T &v, Callable func, std::string delim = ",") { + std::ostringstream s; + auto beg = std::begin(v); + auto end = std::end(v); + auto loc = s.tellp(); + while(beg != end) { + auto nloc = s.tellp(); + if(nloc > loc) { + s << delim; + loc = nloc; + } + s << func(*beg++); + } + return s.str(); +} + +/// Join a string in reverse order +template std::string rjoin(const T &v, std::string delim = ",") { + std::ostringstream s; + for(std::size_t start = 0; start < v.size(); start++) { + if(start > 0) + s << delim; + s << v[v.size() - start - 1]; + } + return s.str(); +} + +// Based roughly on http://stackoverflow.com/questions/25829143/c-trim-whitespace-from-a-string + +/// Trim whitespace from left of string +inline std::string <rim(std::string &str) { + auto it = std::find_if(str.begin(), str.end(), [](char ch) { return !std::isspace(ch, std::locale()); }); + str.erase(str.begin(), it); + return str; +} + +/// Trim anything from left of string +inline std::string <rim(std::string &str, const std::string &filter) { + auto it = std::find_if(str.begin(), str.end(), [&filter](char ch) { return filter.find(ch) == std::string::npos; }); + str.erase(str.begin(), it); + return str; +} + +/// Trim whitespace from right of string +inline std::string &rtrim(std::string &str) { + auto it = std::find_if(str.rbegin(), str.rend(), [](char ch) { return !std::isspace(ch, std::locale()); }); + str.erase(it.base(), str.end()); + return str; +} + +/// Trim anything from right of string +inline std::string &rtrim(std::string &str, const std::string &filter) { + auto it = + std::find_if(str.rbegin(), str.rend(), [&filter](char ch) { return filter.find(ch) == std::string::npos; }); + str.erase(it.base(), str.end()); + return str; +} + +/// Trim whitespace from string +inline std::string &trim(std::string &str) { return ltrim(rtrim(str)); } + +/// Trim anything from string +inline std::string &trim(std::string &str, const std::string filter) { return ltrim(rtrim(str, filter), filter); 
} + +/// Make a copy of the string and then trim it +inline std::string trim_copy(const std::string &str) { + std::string s = str; + return trim(s); +} + +/// remove quotes at the front and back of a string either '"' or '\'' +inline std::string &remove_quotes(std::string &str) { + if(str.length() > 1 && (str.front() == '"' || str.front() == '\'')) { + if(str.front() == str.back()) { + str.pop_back(); + str.erase(str.begin(), str.begin() + 1); + } + } + return str; +} + +/// Add a leader to the beginning of all new lines (nothing is added +/// at the start of the first line). `"; "` would be for ini files +/// +/// Can't use Regex, or this would be a subs. +inline std::string fix_newlines(const std::string &leader, std::string input) { + std::string::size_type n = 0; + while(n != std::string::npos && n < input.size()) { + n = input.find('\n', n); + if(n != std::string::npos) { + input = input.substr(0, n + 1) + leader + input.substr(n + 1); + n += leader.size(); + } + } + return input; +} + +/// Make a copy of the string and then trim it, any filter string can be used (any char in string is filtered) +inline std::string trim_copy(const std::string &str, const std::string &filter) { + std::string s = str; + return trim(s, filter); +} +/// Print a two part "help" string +inline std::ostream &format_help(std::ostream &out, std::string name, const std::string &description, std::size_t wid) { + name = " " + name; + out << std::setw(static_cast(wid)) << std::left << name; + if(!description.empty()) { + if(name.length() >= wid) + out << "\n" << std::setw(static_cast(wid)) << ""; + for(const char c : description) { + out.put(c); + if(c == '\n') { + out << std::setw(static_cast(wid)) << ""; + } + } + } + out << "\n"; + return out; +} + +/// Print subcommand aliases +inline std::ostream &format_aliases(std::ostream &out, const std::vector &aliases, std::size_t wid) { + if(!aliases.empty()) { + out << std::setw(static_cast(wid)) << " aliases: "; + bool front = true; + 
for(const auto &alias : aliases) { + if(!front) { + out << ", "; + } else { + front = false; + } + out << detail::fix_newlines(" ", alias); + } + out << "\n"; + } + return out; +} + +/// Verify the first character of an option +/// - is a trigger character, ! has special meaning and new lines would just be annoying to deal with +template bool valid_first_char(T c) { return ((c != '-') && (c != '!') && (c != ' ') && c != '\n'); } + +/// Verify following characters of an option +template bool valid_later_char(T c) { + // = and : are value separators, { has special meaning for option defaults, + // and \n would just be annoying to deal with in many places allowing space here has too much potential for + // inadvertent entry errors and bugs + return ((c != '=') && (c != ':') && (c != '{') && (c != ' ') && c != '\n'); +} + +/// Verify an option/subcommand name +inline bool valid_name_string(const std::string &str) { + if(str.empty() || !valid_first_char(str[0])) { + return false; + } + auto e = str.end(); + for(auto c = str.begin() + 1; c != e; ++c) + if(!valid_later_char(*c)) + return false; + return true; +} + +/// Verify an app name +inline bool valid_alias_name_string(const std::string &str) { + static const std::string badChars(std::string("\n") + '\0'); + return (str.find_first_of(badChars) == std::string::npos); +} + +/// check if a string is a container segment separator (empty or "%%") +inline bool is_separator(const std::string &str) { + static const std::string sep("%%"); + return (str.empty() || str == sep); +} + +/// Verify that str consists of letters only +inline bool isalpha(const std::string &str) { + return std::all_of(str.begin(), str.end(), [](char c) { return std::isalpha(c, std::locale()); }); +} + +/// Return a lower case version of a string +inline std::string to_lower(std::string str) { + std::transform(std::begin(str), std::end(str), std::begin(str), [](const std::string::value_type &x) { + return std::tolower(x, std::locale()); + }); + return 
str; +} + +/// remove underscores from a string +inline std::string remove_underscore(std::string str) { + str.erase(std::remove(std::begin(str), std::end(str), '_'), std::end(str)); + return str; +} + +/// Find and replace a substring with another substring +inline std::string find_and_replace(std::string str, std::string from, std::string to) { + + std::size_t start_pos = 0; + + while((start_pos = str.find(from, start_pos)) != std::string::npos) { + str.replace(start_pos, from.length(), to); + start_pos += to.length(); + } + + return str; +} + +/// check if the flag definitions has possible false flags +inline bool has_default_flag_values(const std::string &flags) { + return (flags.find_first_of("{!") != std::string::npos); +} + +inline void remove_default_flag_values(std::string &flags) { + auto loc = flags.find_first_of('{', 2); + while(loc != std::string::npos) { + auto finish = flags.find_first_of("},", loc + 1); + if((finish != std::string::npos) && (flags[finish] == '}')) { + flags.erase(flags.begin() + static_cast(loc), + flags.begin() + static_cast(finish) + 1); + } + loc = flags.find_first_of('{', loc + 1); + } + flags.erase(std::remove(flags.begin(), flags.end(), '!'), flags.end()); +} + +/// Check if a string is a member of a list of strings and optionally ignore case or ignore underscores +inline std::ptrdiff_t find_member(std::string name, + const std::vector names, + bool ignore_case = false, + bool ignore_underscore = false) { + auto it = std::end(names); + if(ignore_case) { + if(ignore_underscore) { + name = detail::to_lower(detail::remove_underscore(name)); + it = std::find_if(std::begin(names), std::end(names), [&name](std::string local_name) { + return detail::to_lower(detail::remove_underscore(local_name)) == name; + }); + } else { + name = detail::to_lower(name); + it = std::find_if(std::begin(names), std::end(names), [&name](std::string local_name) { + return detail::to_lower(local_name) == name; + }); + } + + } else if(ignore_underscore) { 
+ name = detail::remove_underscore(name); + it = std::find_if(std::begin(names), std::end(names), [&name](std::string local_name) { + return detail::remove_underscore(local_name) == name; + }); + } else { + it = std::find(std::begin(names), std::end(names), name); + } + + return (it != std::end(names)) ? (it - std::begin(names)) : (-1); +} + +/// Find a trigger string and call a modify callable function that takes the current string and starting position of the +/// trigger and returns the position in the string to search for the next trigger string +template inline std::string find_and_modify(std::string str, std::string trigger, Callable modify) { + std::size_t start_pos = 0; + while((start_pos = str.find(trigger, start_pos)) != std::string::npos) { + start_pos = modify(str, start_pos); + } + return str; +} + +/// Split a string '"one two" "three"' into 'one two', 'three' +/// Quote characters can be ` ' or " +inline std::vector split_up(std::string str, char delimiter = '\0') { + + const std::string delims("\'\"`"); + auto find_ws = [delimiter](char ch) { + return (delimiter == '\0') ? 
(std::isspace(ch, std::locale()) != 0) : (ch == delimiter); + }; + trim(str); + + std::vector output; + bool embeddedQuote = false; + char keyChar = ' '; + while(!str.empty()) { + if(delims.find_first_of(str[0]) != std::string::npos) { + keyChar = str[0]; + auto end = str.find_first_of(keyChar, 1); + while((end != std::string::npos) && (str[end - 1] == '\\')) { // deal with escaped quotes + end = str.find_first_of(keyChar, end + 1); + embeddedQuote = true; + } + if(end != std::string::npos) { + output.push_back(str.substr(1, end - 1)); + if(end + 2 < str.size()) { + str = str.substr(end + 2); + } else { + str.clear(); + } + + } else { + output.push_back(str.substr(1)); + str = ""; + } + } else { + auto it = std::find_if(std::begin(str), std::end(str), find_ws); + if(it != std::end(str)) { + std::string value = std::string(str.begin(), it); + output.push_back(value); + str = std::string(it + 1, str.end()); + } else { + output.push_back(str); + str = ""; + } + } + // transform any embedded quotes into the regular character + if(embeddedQuote) { + output.back() = find_and_replace(output.back(), std::string("\\") + keyChar, std::string(1, keyChar)); + embeddedQuote = false; + } + trim(str); + } + return output; +} + +/// This function detects an equal or colon followed by an escaped quote after an argument +/// then modifies the string to replace the equality with a space. This is needed +/// to allow the split up function to work properly and is intended to be used with the find_and_modify function +/// the return value is the offset+1 which is required by the find_and_modify function. +inline std::size_t escape_detect(std::string &str, std::size_t offset) { + auto next = str[offset + 1]; + if((next == '\"') || (next == '\'') || (next == '`')) { + auto astart = str.find_last_of("-/ \"\'`", offset - 1); + if(astart != std::string::npos) { + if(str[astart] == ((str[offset] == '=') ? 
'-' : '/')) + str[offset] = ' '; // interpret this as a space so the split_up works properly + } + } + return offset + 1; +} + +/// Add quotes if the string contains spaces +inline std::string &add_quotes_if_needed(std::string &str) { + if((str.front() != '"' && str.front() != '\'') || str.front() != str.back()) { + char quote = str.find('"') < str.find('\'') ? '\'' : '"'; + if(str.find(' ') != std::string::npos) { + str.insert(0, 1, quote); + str.append(1, quote); + } + } + return str; +} + +} // namespace detail + + + + +// Use one of these on all error classes. +// These are temporary and are undef'd at the end of this file. +#define CLI11_ERROR_DEF(parent, name) \ + protected: \ + name(std::string ename, std::string msg, int exit_code) : parent(std::move(ename), std::move(msg), exit_code) {} \ + name(std::string ename, std::string msg, ExitCodes exit_code) \ + : parent(std::move(ename), std::move(msg), exit_code) {} \ + \ + public: \ + name(std::string msg, ExitCodes exit_code) : parent(#name, std::move(msg), exit_code) {} \ + name(std::string msg, int exit_code) : parent(#name, std::move(msg), exit_code) {} + +// This is added after the one above if a class is used directly and builds its own message +#define CLI11_ERROR_SIMPLE(name) \ + explicit name(std::string msg) : name(#name, msg, ExitCodes::name) {} + +/// These codes are part of every error in CLI. They can be obtained from e using e.exit_code or as a quick shortcut, +/// int values from e.get_error_code(). +enum class ExitCodes { + Success = 0, + IncorrectConstruction = 100, + BadNameString, + OptionAlreadyAdded, + FileError, + ConversionError, + ValidationError, + RequiredError, + RequiresError, + ExcludesError, + ExtrasError, + ConfigError, + InvalidError, + HorribleError, + OptionNotFound, + ArgumentMismatch, + BaseClass = 127 +}; + +// Error definitions + +/// @defgroup error_group Errors +/// @brief Errors thrown by CLI11 +/// +/// These are the errors that can be thrown. 
Some of them, like CLI::Success, are not really errors. +/// @{ + +/// All errors derive from this one +class Error : public std::runtime_error { + int actual_exit_code; + std::string error_name{"Error"}; + + public: + int get_exit_code() const { return actual_exit_code; } + + std::string get_name() const { return error_name; } + + Error(std::string name, std::string msg, int exit_code = static_cast(ExitCodes::BaseClass)) + : runtime_error(msg), actual_exit_code(exit_code), error_name(std::move(name)) {} + + Error(std::string name, std::string msg, ExitCodes exit_code) : Error(name, msg, static_cast(exit_code)) {} +}; + +// Note: Using Error::Error constructors does not work on GCC 4.7 + +/// Construction errors (not in parsing) +class ConstructionError : public Error { + CLI11_ERROR_DEF(Error, ConstructionError) +}; + +/// Thrown when an option is set to conflicting values (non-vector and multi args, for example) +class IncorrectConstruction : public ConstructionError { + CLI11_ERROR_DEF(ConstructionError, IncorrectConstruction) + CLI11_ERROR_SIMPLE(IncorrectConstruction) + static IncorrectConstruction PositionalFlag(std::string name) { + return IncorrectConstruction(name + ": Flags cannot be positional"); + } + static IncorrectConstruction Set0Opt(std::string name) { + return IncorrectConstruction(name + ": Cannot set 0 expected, use a flag instead"); + } + static IncorrectConstruction SetFlag(std::string name) { + return IncorrectConstruction(name + ": Cannot set an expected number for flags"); + } + static IncorrectConstruction ChangeNotVector(std::string name) { + return IncorrectConstruction(name + ": You can only change the expected arguments for vectors"); + } + static IncorrectConstruction AfterMultiOpt(std::string name) { + return IncorrectConstruction( + name + ": You can't change expected arguments after you've changed the multi option policy!"); + } + static IncorrectConstruction MissingOption(std::string name) { + return IncorrectConstruction("Option 
" + name + " is not defined"); + } + static IncorrectConstruction MultiOptionPolicy(std::string name) { + return IncorrectConstruction(name + ": multi_option_policy only works for flags and exact value options"); + } +}; + +/// Thrown on construction of a bad name +class BadNameString : public ConstructionError { + CLI11_ERROR_DEF(ConstructionError, BadNameString) + CLI11_ERROR_SIMPLE(BadNameString) + static BadNameString OneCharName(std::string name) { return BadNameString("Invalid one char name: " + name); } + static BadNameString BadLongName(std::string name) { return BadNameString("Bad long name: " + name); } + static BadNameString DashesOnly(std::string name) { + return BadNameString("Must have a name, not just dashes: " + name); + } + static BadNameString MultiPositionalNames(std::string name) { + return BadNameString("Only one positional name allowed, remove: " + name); + } +}; + +/// Thrown when an option already exists +class OptionAlreadyAdded : public ConstructionError { + CLI11_ERROR_DEF(ConstructionError, OptionAlreadyAdded) + explicit OptionAlreadyAdded(std::string name) + : OptionAlreadyAdded(name + " is already added", ExitCodes::OptionAlreadyAdded) {} + static OptionAlreadyAdded Requires(std::string name, std::string other) { + return OptionAlreadyAdded(name + " requires " + other, ExitCodes::OptionAlreadyAdded); + } + static OptionAlreadyAdded Excludes(std::string name, std::string other) { + return OptionAlreadyAdded(name + " excludes " + other, ExitCodes::OptionAlreadyAdded); + } +}; + +// Parsing errors + +/// Anything that can error in Parse +class ParseError : public Error { + CLI11_ERROR_DEF(Error, ParseError) +}; + +// Not really "errors" + +/// This is a successful completion on parsing, supposed to exit +class Success : public ParseError { + CLI11_ERROR_DEF(ParseError, Success) + Success() : Success("Successfully completed, should be caught and quit", ExitCodes::Success) {} +}; + +/// -h or --help on command line +class CallForHelp : 
public Success { + CLI11_ERROR_DEF(Success, CallForHelp) + CallForHelp() : CallForHelp("This should be caught in your main function, see examples", ExitCodes::Success) {} +}; + +/// Usually something like --help-all on command line +class CallForAllHelp : public Success { + CLI11_ERROR_DEF(Success, CallForAllHelp) + CallForAllHelp() + : CallForAllHelp("This should be caught in your main function, see examples", ExitCodes::Success) {} +}; + +/// -v or --version on command line +class CallForVersion : public Success { + CLI11_ERROR_DEF(Success, CallForVersion) + CallForVersion() + : CallForVersion("This should be caught in your main function, see examples", ExitCodes::Success) {} +}; + +/// Does not output a diagnostic in CLI11_PARSE, but allows main() to return with a specific error code. +class RuntimeError : public ParseError { + CLI11_ERROR_DEF(ParseError, RuntimeError) + explicit RuntimeError(int exit_code = 1) : RuntimeError("Runtime error", exit_code) {} +}; + +/// Thrown when parsing an INI file and it is missing +class FileError : public ParseError { + CLI11_ERROR_DEF(ParseError, FileError) + CLI11_ERROR_SIMPLE(FileError) + static FileError Missing(std::string name) { return FileError(name + " was not readable (missing?)"); } +}; + +/// Thrown when conversion call back fails, such as when an int fails to coerce to a string +class ConversionError : public ParseError { + CLI11_ERROR_DEF(ParseError, ConversionError) + CLI11_ERROR_SIMPLE(ConversionError) + ConversionError(std::string member, std::string name) + : ConversionError("The value " + member + " is not an allowed value for " + name) {} + ConversionError(std::string name, std::vector results) + : ConversionError("Could not convert: " + name + " = " + detail::join(results)) {} + static ConversionError TooManyInputsFlag(std::string name) { + return ConversionError(name + ": too many inputs for a flag"); + } + static ConversionError TrueFalse(std::string name) { + return ConversionError(name + ": Should be 
true/false or a number"); + } +}; + +/// Thrown when validation of results fails +class ValidationError : public ParseError { + CLI11_ERROR_DEF(ParseError, ValidationError) + CLI11_ERROR_SIMPLE(ValidationError) + explicit ValidationError(std::string name, std::string msg) : ValidationError(name + ": " + msg) {} +}; + +/// Thrown when a required option is missing +class RequiredError : public ParseError { + CLI11_ERROR_DEF(ParseError, RequiredError) + explicit RequiredError(std::string name) : RequiredError(name + " is required", ExitCodes::RequiredError) {} + static RequiredError Subcommand(std::size_t min_subcom) { + if(min_subcom == 1) { + return RequiredError("A subcommand"); + } + return RequiredError("Requires at least " + std::to_string(min_subcom) + " subcommands", + ExitCodes::RequiredError); + } + static RequiredError + Option(std::size_t min_option, std::size_t max_option, std::size_t used, const std::string &option_list) { + if((min_option == 1) && (max_option == 1) && (used == 0)) + return RequiredError("Exactly 1 option from [" + option_list + "]"); + if((min_option == 1) && (max_option == 1) && (used > 1)) { + return RequiredError("Exactly 1 option from [" + option_list + "] is required and " + std::to_string(used) + + " were given", + ExitCodes::RequiredError); + } + if((min_option == 1) && (used == 0)) + return RequiredError("At least 1 option from [" + option_list + "]"); + if(used < min_option) { + return RequiredError("Requires at least " + std::to_string(min_option) + " options used and only " + + std::to_string(used) + "were given from [" + option_list + "]", + ExitCodes::RequiredError); + } + if(max_option == 1) + return RequiredError("Requires at most 1 options be given from [" + option_list + "]", + ExitCodes::RequiredError); + + return RequiredError("Requires at most " + std::to_string(max_option) + " options be used and " + + std::to_string(used) + "were given from [" + option_list + "]", + ExitCodes::RequiredError); + } +}; + +/// Thrown 
when the wrong number of arguments has been received +class ArgumentMismatch : public ParseError { + CLI11_ERROR_DEF(ParseError, ArgumentMismatch) + CLI11_ERROR_SIMPLE(ArgumentMismatch) + ArgumentMismatch(std::string name, int expected, std::size_t received) + : ArgumentMismatch(expected > 0 ? ("Expected exactly " + std::to_string(expected) + " arguments to " + name + + ", got " + std::to_string(received)) + : ("Expected at least " + std::to_string(-expected) + " arguments to " + name + + ", got " + std::to_string(received)), + ExitCodes::ArgumentMismatch) {} + + static ArgumentMismatch AtLeast(std::string name, int num, std::size_t received) { + return ArgumentMismatch(name + ": At least " + std::to_string(num) + " required but received " + + std::to_string(received)); + } + static ArgumentMismatch AtMost(std::string name, int num, std::size_t received) { + return ArgumentMismatch(name + ": At Most " + std::to_string(num) + " required but received " + + std::to_string(received)); + } + static ArgumentMismatch TypedAtLeast(std::string name, int num, std::string type) { + return ArgumentMismatch(name + ": " + std::to_string(num) + " required " + type + " missing"); + } + static ArgumentMismatch FlagOverride(std::string name) { + return ArgumentMismatch(name + " was given a disallowed flag override"); + } +}; + +/// Thrown when a requires option is missing +class RequiresError : public ParseError { + CLI11_ERROR_DEF(ParseError, RequiresError) + RequiresError(std::string curname, std::string subname) + : RequiresError(curname + " requires " + subname, ExitCodes::RequiresError) {} +}; + +/// Thrown when an excludes option is present +class ExcludesError : public ParseError { + CLI11_ERROR_DEF(ParseError, ExcludesError) + ExcludesError(std::string curname, std::string subname) + : ExcludesError(curname + " excludes " + subname, ExitCodes::ExcludesError) {} +}; + +/// Thrown when too many positionals or options are found +class ExtrasError : public ParseError { + 
CLI11_ERROR_DEF(ParseError, ExtrasError) + explicit ExtrasError(std::vector args) + : ExtrasError((args.size() > 1 ? "The following arguments were not expected: " + : "The following argument was not expected: ") + + detail::rjoin(args, " "), + ExitCodes::ExtrasError) {} + ExtrasError(const std::string &name, std::vector args) + : ExtrasError(name, + (args.size() > 1 ? "The following arguments were not expected: " + : "The following argument was not expected: ") + + detail::rjoin(args, " "), + ExitCodes::ExtrasError) {} +}; + +/// Thrown when extra values are found in an INI file +class ConfigError : public ParseError { + CLI11_ERROR_DEF(ParseError, ConfigError) + CLI11_ERROR_SIMPLE(ConfigError) + static ConfigError Extras(std::string item) { return ConfigError("INI was not able to parse " + item); } + static ConfigError NotConfigurable(std::string item) { + return ConfigError(item + ": This option is not allowed in a configuration file"); + } +}; + +/// Thrown when validation fails before parsing +class InvalidError : public ParseError { + CLI11_ERROR_DEF(ParseError, InvalidError) + explicit InvalidError(std::string name) + : InvalidError(name + ": Too many positional arguments with unlimited expected args", ExitCodes::InvalidError) { + } +}; + +/// This is just a safety check to verify selection and parsing match - you should not ever see it +/// Strings are directly added to this error, but again, it should never be seen. 
+class HorribleError : public ParseError { + CLI11_ERROR_DEF(ParseError, HorribleError) + CLI11_ERROR_SIMPLE(HorribleError) +}; + +// After parsing + +/// Thrown when counting a non-existent option +class OptionNotFound : public Error { + CLI11_ERROR_DEF(Error, OptionNotFound) + explicit OptionNotFound(std::string name) : OptionNotFound(name + " not found", ExitCodes::OptionNotFound) {} +}; + +#undef CLI11_ERROR_DEF +#undef CLI11_ERROR_SIMPLE + +/// @} + + + + +// Type tools + +// Utilities for type enabling +namespace detail { +// Based generally on https://rmf.io/cxx11/almost-static-if +/// Simple empty scoped class +enum class enabler {}; + +/// An instance to use in EnableIf +constexpr enabler dummy = {}; +} // namespace detail + +/// A copy of enable_if_t from C++14, compatible with C++11. +/// +/// We could check to see if C++14 is being used, but it does not hurt to redefine this +/// (even Google does this: https://github.com/google/skia/blob/main/include/private/SkTLogic.h) +/// It is not in the std namespace anyway, so no harm done. 
+template using enable_if_t = typename std::enable_if::type; + +/// A copy of std::void_t from C++17 (helper for C++11 and C++14) +template struct make_void { using type = void; }; + +/// A copy of std::void_t from C++17 - same reasoning as enable_if_t, it does not hurt to redefine +template using void_t = typename make_void::type; + +/// A copy of std::conditional_t from C++14 - same reasoning as enable_if_t, it does not hurt to redefine +template using conditional_t = typename std::conditional::type; + +/// Check to see if something is bool (fail check by default) +template struct is_bool : std::false_type {}; + +/// Check to see if something is bool (true if actually a bool) +template <> struct is_bool : std::true_type {}; + +/// Check to see if something is a shared pointer +template struct is_shared_ptr : std::false_type {}; + +/// Check to see if something is a shared pointer (True if really a shared pointer) +template struct is_shared_ptr> : std::true_type {}; + +/// Check to see if something is a shared pointer (True if really a shared pointer) +template struct is_shared_ptr> : std::true_type {}; + +/// Check to see if something is copyable pointer +template struct is_copyable_ptr { + static bool const value = is_shared_ptr::value || std::is_pointer::value; +}; + +/// This can be specialized to override the type deduction for IsMember. +template struct IsMemberType { using type = T; }; + +/// The main custom type needed here is const char * should be a string. +template <> struct IsMemberType { using type = std::string; }; + +namespace detail { + +// These are utilities for IsMember and other transforming objects + +/// Handy helper to access the element_type generically. This is not part of is_copyable_ptr because it requires that +/// pointer_traits be valid. 
+ +/// not a pointer +template struct element_type { using type = T; }; + +template struct element_type::value>::type> { + using type = typename std::pointer_traits::element_type; +}; + +/// Combination of the element type and value type - remove pointer (including smart pointers) and get the value_type of +/// the container +template struct element_value_type { using type = typename element_type::type::value_type; }; + +/// Adaptor for set-like structure: This just wraps a normal container in a few utilities that do almost nothing. +template struct pair_adaptor : std::false_type { + using value_type = typename T::value_type; + using first_type = typename std::remove_const::type; + using second_type = typename std::remove_const::type; + + /// Get the first value (really just the underlying value) + template static auto first(Q &&pair_value) -> decltype(std::forward(pair_value)) { + return std::forward(pair_value); + } + /// Get the second value (really just the underlying value) + template static auto second(Q &&pair_value) -> decltype(std::forward(pair_value)) { + return std::forward(pair_value); + } +}; + +/// Adaptor for map-like structure (true version, must have key_type and mapped_type). +/// This wraps a mapped container in a few utilities access it in a general way. 
+template +struct pair_adaptor< + T, + conditional_t, void>> + : std::true_type { + using value_type = typename T::value_type; + using first_type = typename std::remove_const::type; + using second_type = typename std::remove_const::type; + + /// Get the first value (really just the underlying value) + template static auto first(Q &&pair_value) -> decltype(std::get<0>(std::forward(pair_value))) { + return std::get<0>(std::forward(pair_value)); + } + /// Get the second value (really just the underlying value) + template static auto second(Q &&pair_value) -> decltype(std::get<1>(std::forward(pair_value))) { + return std::get<1>(std::forward(pair_value)); + } +}; + +// Warning is suppressed due to "bug" in gcc<5.0 and gcc 7.0 with c++17 enabled that generates a Wnarrowing warning +// in the unevaluated context even if the function that was using this wasn't used. The standard says narrowing in +// brace initialization shouldn't be allowed but for backwards compatibility gcc allows it in some contexts. It is a +// little fuzzy what happens in template constructs and I think that was something GCC took a little while to work out. +// But regardless some versions of gcc generate a warning when they shouldn't from the following code so that should be +// suppressed +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wnarrowing" +#endif +// check for constructibility from a specific type and copy assignable used in the parse detection +template class is_direct_constructible { + template + static auto test(int, std::true_type) -> decltype( +// NVCC warns about narrowing conversions here +#ifdef __CUDACC__ +#pragma diag_suppress 2361 +#endif + TT { std::declval() } +#ifdef __CUDACC__ +#pragma diag_default 2361 +#endif + , + std::is_move_assignable()); + + template static auto test(int, std::false_type) -> std::false_type; + + template static auto test(...) 
-> std::false_type; + + public: + static constexpr bool value = decltype(test(0, typename std::is_constructible::type()))::value; +}; +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +// Check for output streamability +// Based on https://stackoverflow.com/questions/22758291/how-can-i-detect-if-a-type-can-be-streamed-to-an-stdostream + +template class is_ostreamable { + template + static auto test(int) -> decltype(std::declval() << std::declval(), std::true_type()); + + template static auto test(...) -> std::false_type; + + public: + static constexpr bool value = decltype(test(0))::value; +}; + +/// Check for input streamability +template class is_istreamable { + template + static auto test(int) -> decltype(std::declval() >> std::declval(), std::true_type()); + + template static auto test(...) -> std::false_type; + + public: + static constexpr bool value = decltype(test(0))::value; +}; + +/// Check for complex +template class is_complex { + template + static auto test(int) -> decltype(std::declval().real(), std::declval().imag(), std::true_type()); + + template static auto test(...) -> std::false_type; + + public: + static constexpr bool value = decltype(test(0))::value; +}; + +/// Templated operation to get a value from a stream +template ::value, detail::enabler> = detail::dummy> +bool from_stream(const std::string &istring, T &obj) { + std::istringstream is; + is.str(istring); + is >> obj; + return !is.fail() && !is.rdbuf()->in_avail(); +} + +template ::value, detail::enabler> = detail::dummy> +bool from_stream(const std::string & /*istring*/, T & /*obj*/) { + return false; +} + +// check to see if an object is a mutable container (fail by default) +template struct is_mutable_container : std::false_type {}; + +/// type trait to test if a type is a mutable container meaning it has a value_type, it has an iterator, a clear, and +/// end methods and an insert function. 
And for our purposes we exclude std::string and types that can be constructed +/// from a std::string +template +struct is_mutable_container< + T, + conditional_t().end()), + decltype(std::declval().clear()), + decltype(std::declval().insert(std::declval().end())>(), + std::declval()))>, + void>> + : public conditional_t::value, std::false_type, std::true_type> {}; + +// check to see if an object is a readable container (fail by default) +template struct is_readable_container : std::false_type {}; + +/// type trait to test if a type is a readable container; the check visible here only requires end() and begin() +/// methods. Unlike is_mutable_container it does not ask for clear() or insert(), and this specialization always +/// derives from std::true_type +template +struct is_readable_container< + T, + conditional_t().end()), decltype(std::declval().begin())>, void>> + : public std::true_type {}; + +// check to see if an object is a wrapper (fail by default) +template struct is_wrapper : std::false_type {}; + +// check if an object is a wrapper (it has a value_type defined) +template +struct is_wrapper, void>> : public std::true_type {}; + +// Check for tuple like types, as in classes with a tuple_size type trait +template class is_tuple_like { + template + // static auto test(int) + // -> decltype(std::conditional<(std::tuple_size::value > 0), std::true_type, std::false_type>::type()); + static auto test(int) -> decltype(std::tuple_size::type>::value, std::true_type{}); + template static auto test(...)
-> std::false_type; + + public: + static constexpr bool value = decltype(test(0))::value; +}; + +/// Convert an object to a string (directly forward if this can become a string) +template ::value, detail::enabler> = detail::dummy> +auto to_string(T &&value) -> decltype(std::forward(value)) { + return std::forward(value); +} + +/// Construct a string from the object +template ::value && !std::is_convertible::value, + detail::enabler> = detail::dummy> +std::string to_string(const T &value) { + return std::string(value); +} + +/// Convert an object to a string (streaming must be supported for that type) +template ::value && !std::is_constructible::value && + is_ostreamable::value, + detail::enabler> = detail::dummy> +std::string to_string(T &&value) { + std::stringstream stream; + stream << value; + return stream.str(); +} + +/// If conversion is not supported, return an empty string (streaming is not supported for that type) +template ::value && !is_ostreamable::value && + !is_readable_container::type>::value, + detail::enabler> = detail::dummy> +std::string to_string(T &&) { + return std::string{}; +} + +/// convert a readable container to a string +template ::value && !is_ostreamable::value && + is_readable_container::value, + detail::enabler> = detail::dummy> +std::string to_string(T &&variable) { + std::vector defaults; + auto cval = variable.begin(); + auto end = variable.end(); + while(cval != end) { + defaults.emplace_back(CLI::detail::to_string(*cval)); + ++cval; + } + return std::string("[" + detail::join(defaults) + "]"); +} + +/// special template overload +template ::value, detail::enabler> = detail::dummy> +auto checked_to_string(T &&value) -> decltype(to_string(std::forward(value))) { + return to_string(std::forward(value)); +} + +/// special template overload +template ::value, detail::enabler> = detail::dummy> +std::string checked_to_string(T &&) { + return std::string{}; +} +/// get a string as a convertible value for arithmetic types +template 
::value, detail::enabler> = detail::dummy> +std::string value_string(const T &value) { + return std::to_string(value); +} +/// get a string as a convertible value for enumerations +template ::value, detail::enabler> = detail::dummy> +std::string value_string(const T &value) { + return std::to_string(static_cast::type>(value)); +} +/// for other types just use the regular to_string function +template ::value && !std::is_arithmetic::value, detail::enabler> = detail::dummy> +auto value_string(const T &value) -> decltype(to_string(value)) { + return to_string(value); +} + +/// template to get the underlying value type if it exists or use a default +template struct wrapped_type { using type = def; }; + +/// Type size for regular object types that do not look like a tuple +template struct wrapped_type::value>::type> { + using type = typename T::value_type; +}; + +/// This will only trigger for actual void type +template struct type_count_base { static const int value{0}; }; + +/// Type size for regular object types that do not look like a tuple +template +struct type_count_base::value && !is_mutable_container::value && + !std::is_void::value>::type> { + static constexpr int value{1}; +}; + +/// the base tuple size +template +struct type_count_base::value && !is_mutable_container::value>::type> { + static constexpr int value{std::tuple_size::value}; +}; + +/// Type count base for containers is the type_count_base of the individual element +template struct type_count_base::value>::type> { + static constexpr int value{type_count_base::value}; +}; + +/// Set of overloads to get the type size of an object + +/// forward declare the subtype_count structure +template struct subtype_count; + +/// forward declare the subtype_count_min structure +template struct subtype_count_min; + +/// This will only trigger for actual void type +template struct type_count { static const int value{0}; }; + +/// Type size for regular object types that do not look like a tuple +template +struct 
type_count::value && !is_tuple_like::value && !is_complex::value && + !std::is_void::value>::type> { + static constexpr int value{1}; +}; + +/// Type size for complex since it sometimes looks like a wrapper +template struct type_count::value>::type> { + static constexpr int value{2}; +}; + +/// Type size of types that are wrappers,except complex and tuples(which can also be wrappers sometimes) +template struct type_count::value>::type> { + static constexpr int value{subtype_count::value}; +}; + +/// Type size of types that are wrappers,except containers complex and tuples(which can also be wrappers sometimes) +template +struct type_count::value && !is_complex::value && !is_tuple_like::value && + !is_mutable_container::value>::type> { + static constexpr int value{type_count::value}; +}; + +/// 0 if the index > tuple size +template +constexpr typename std::enable_if::value, int>::type tuple_type_size() { + return 0; +} + +/// Recursively generate the tuple type name +template + constexpr typename std::enable_if < I::value, int>::type tuple_type_size() { + return subtype_count::type>::value + tuple_type_size(); +} + +/// Get the type size of the sum of type sizes for all the individual tuple types +template struct type_count::value>::type> { + static constexpr int value{tuple_type_size()}; +}; + +/// definition of subtype count +template struct subtype_count { + static constexpr int value{is_mutable_container::value ? 
expected_max_vector_size : type_count::value}; +}; + +/// This will only trigger for actual void type +template struct type_count_min { static const int value{0}; }; + +/// Type size for regular object types that do not look like a tuple +template +struct type_count_min< + T, + typename std::enable_if::value && !is_tuple_like::value && !is_wrapper::value && + !is_complex::value && !std::is_void::value>::type> { + static constexpr int value{type_count::value}; +}; + +/// Type size for complex since it sometimes looks like a wrapper +template struct type_count_min::value>::type> { + static constexpr int value{1}; +}; + +/// Type size min of types that are wrappers,except complex and tuples(which can also be wrappers sometimes) +template +struct type_count_min< + T, + typename std::enable_if::value && !is_complex::value && !is_tuple_like::value>::type> { + static constexpr int value{subtype_count_min::value}; +}; + +/// 0 if the index > tuple size +template +constexpr typename std::enable_if::value, int>::type tuple_type_size_min() { + return 0; +} + +/// Recursively generate the tuple type name +template + constexpr typename std::enable_if < I::value, int>::type tuple_type_size_min() { + return subtype_count_min::type>::value + tuple_type_size_min(); +} + +/// Get the type size of the sum of type sizes for all the individual tuple types +template struct type_count_min::value>::type> { + static constexpr int value{tuple_type_size_min()}; +}; + +/// definition of subtype count +template struct subtype_count_min { + static constexpr int value{is_mutable_container::value + ? ((type_count::value < expected_max_vector_size) ? 
type_count::value : 0) + : type_count_min::value}; +}; + +/// This will only trigger for actual void type +template struct expected_count { static const int value{0}; }; + +/// For most types the number of expected items is 1 +template +struct expected_count::value && !is_wrapper::value && + !std::is_void::value>::type> { + static constexpr int value{1}; +}; +/// number of expected items in a vector +template struct expected_count::value>::type> { + static constexpr int value{expected_max_vector_size}; +}; + +/// number of expected items in a vector +template +struct expected_count::value && is_wrapper::value>::type> { + static constexpr int value{expected_count::value}; +}; + +// Enumeration of the different supported categorizations of objects +enum class object_category : int { + char_value = 1, + integral_value = 2, + unsigned_integral = 4, + enumeration = 6, + boolean_value = 8, + floating_point = 10, + number_constructible = 12, + double_constructible = 14, + integer_constructible = 16, + // string like types + string_assignable = 23, + string_constructible = 24, + other = 45, + // special wrapper or container types + wrapper_value = 50, + complex_number = 60, + tuple_value = 70, + container_value = 80, + +}; + +/// Set of overloads to classify an object according to type + +/// some type that is not otherwise recognized +template struct classify_object { + static constexpr object_category value{object_category::other}; +}; + +/// Signed integers +template +struct classify_object< + T, + typename std::enable_if::value && !std::is_same::value && std::is_signed::value && + !is_bool::value && !std::is_enum::value>::type> { + static constexpr object_category value{object_category::integral_value}; +}; + +/// Unsigned integers +template +struct classify_object::value && std::is_unsigned::value && + !std::is_same::value && !is_bool::value>::type> { + static constexpr object_category value{object_category::unsigned_integral}; +}; + +/// single character values 
+template +struct classify_object::value && !std::is_enum::value>::type> { + static constexpr object_category value{object_category::char_value}; +}; + +/// Boolean values +template struct classify_object::value>::type> { + static constexpr object_category value{object_category::boolean_value}; +}; + +/// Floats +template struct classify_object::value>::type> { + static constexpr object_category value{object_category::floating_point}; +}; + +/// String and similar direct assignment +template +struct classify_object::value && !std::is_integral::value && + std::is_assignable::value>::type> { + static constexpr object_category value{object_category::string_assignable}; +}; + +/// String and similar constructible and copy assignment +template +struct classify_object< + T, + typename std::enable_if::value && !std::is_integral::value && + !std::is_assignable::value && (type_count::value == 1) && + std::is_constructible::value>::type> { + static constexpr object_category value{object_category::string_constructible}; +}; + +/// Enumerations +template struct classify_object::value>::type> { + static constexpr object_category value{object_category::enumeration}; +}; + +template struct classify_object::value>::type> { + static constexpr object_category value{object_category::complex_number}; +}; + +/// Handy helper to contain a bunch of checks that rule out many common types (integers, string like, floating point, +/// vectors, and enumerations +template struct uncommon_type { + using type = typename std::conditional::value && !std::is_integral::value && + !std::is_assignable::value && + !std::is_constructible::value && !is_complex::value && + !is_mutable_container::value && !std::is_enum::value, + std::true_type, + std::false_type>::type; + static constexpr bool value = type::value; +}; + +/// wrapper type +template +struct classify_object::value && is_wrapper::value && + !is_tuple_like::value && uncommon_type::value)>::type> { + static constexpr object_category 
value{object_category::wrapper_value}; +}; + +/// Assignable from double or int +template +struct classify_object::value && type_count::value == 1 && + !is_wrapper::value && is_direct_constructible::value && + is_direct_constructible::value>::type> { + static constexpr object_category value{object_category::number_constructible}; +}; + +/// Assignable from int +template +struct classify_object::value && type_count::value == 1 && + !is_wrapper::value && !is_direct_constructible::value && + is_direct_constructible::value>::type> { + static constexpr object_category value{object_category::integer_constructible}; +}; + +/// Assignable from double +template +struct classify_object::value && type_count::value == 1 && + !is_wrapper::value && is_direct_constructible::value && + !is_direct_constructible::value>::type> { + static constexpr object_category value{object_category::double_constructible}; +}; + +/// Tuple type +template +struct classify_object< + T, + typename std::enable_if::value && + ((type_count::value >= 2 && !is_wrapper::value) || + (uncommon_type::value && !is_direct_constructible::value && + !is_direct_constructible::value))>::type> { + static constexpr object_category value{object_category::tuple_value}; + // the condition on this class requires it be like a tuple, but on some compilers (like Xcode) tuples can be + // constructed from just the first element so tuples of can be constructed from a string, which + // could lead to issues so there are two variants of the condition, the first isolates things with a type size >=2 + // mainly to get tuples on Xcode with the exception of wrappers, the second is the main one and just separating out + // those cases that are caught by other object classifications +}; + +/// container type +template struct classify_object::value>::type> { + static constexpr object_category value{object_category::container_value}; +}; + +// Type name print + +/// Was going to be based on +/// 
http://stackoverflow.com/questions/1055452/c-get-name-of-type-in-template +/// But this is cleaner and works better in this case + +/// Print name for single character types +template ::value == object_category::char_value, detail::enabler> = detail::dummy> +constexpr const char *type_name() { + return "CHAR"; +} + +/// Print name for signed integral types and types classified as integer_constructible +template ::value == object_category::integral_value || + classify_object::value == object_category::integer_constructible, + detail::enabler> = detail::dummy> +constexpr const char *type_name() { + return "INT"; +} + +/// Print name for unsigned integral types +template ::value == object_category::unsigned_integral, detail::enabler> = detail::dummy> +constexpr const char *type_name() { + return "UINT"; +} + +/// Print name for floating point types and types classified as number_constructible or double_constructible +template ::value == object_category::floating_point || + classify_object::value == object_category::number_constructible || + classify_object::value == object_category::double_constructible, + detail::enabler> = detail::dummy> +constexpr const char *type_name() { + return "FLOAT"; +} + +/// Print name for enumeration types +template ::value == object_category::enumeration, detail::enabler> = detail::dummy> +constexpr const char *type_name() { + return "ENUM"; +} + +/// Print name for boolean types +template ::value == object_category::boolean_value, detail::enabler> = detail::dummy> +constexpr const char *type_name() { + return "BOOLEAN"; +} + +/// Print name for complex number types +template ::value == object_category::complex_number, detail::enabler> = detail::dummy> +constexpr const char *type_name() { + return "COMPLEX"; +} + +/// Print name for every category from string_assignable up to and including other (string-like and unclassified types) +template ::value >= object_category::string_assignable && + classify_object::value <= object_category::other, + detail::enabler> = detail::dummy> +constexpr const char *type_name() { + return "TEXT"; +} +/// typename for tuple value +template ::value == object_category::tuple_value && type_count_base::value >= 2, + detail::enabler> = detail::dummy> +std::string type_name(); // forward declaration + +/// Generate type name for a wrapper or container value +template ::value ==
object_category::container_value || + classify_object::value == object_category::wrapper_value, + detail::enabler> = detail::dummy> +std::string type_name(); // forward declaration + +/// Print name for single element tuple types +template ::value == object_category::tuple_value && type_count_base::value == 1, + detail::enabler> = detail::dummy> +inline std::string type_name() { + return type_name::type>::type>(); +} + +/// Empty string if the index > tuple size +template +inline typename std::enable_if::value, std::string>::type tuple_name() { + return std::string{}; +} + +/// Recursively generate the tuple type name +template +inline typename std::enable_if<(I < type_count_base::value), std::string>::type tuple_name() { + std::string str = std::string(type_name::type>::type>()) + + ',' + tuple_name(); + if(str.back() == ',') + str.pop_back(); + return str; +} + +/// Print type name for tuples with 2 or more elements +template ::value == object_category::tuple_value && type_count_base::value >= 2, + detail::enabler>> +inline std::string type_name() { + auto tname = std::string(1, '[') + tuple_name(); + tname.push_back(']'); + return tname; +} + +/// get the type name for a type that has a value_type member +template ::value == object_category::container_value || + classify_object::value == object_category::wrapper_value, + detail::enabler>> +inline std::string type_name() { + return type_name(); +} + +// Lexical cast + +/// Convert to an unsigned integral +template ::value, detail::enabler> = detail::dummy> +bool integral_conversion(const std::string &input, T &output) noexcept { + if(input.empty()) { + return false; + } + char *val = nullptr; + std::uint64_t output_ll = std::strtoull(input.c_str(), &val, 0); + output = static_cast(output_ll); + return val == (input.c_str() + input.size()) && static_cast(output) == output_ll; +} + +/// Convert to a signed integral +template ::value, detail::enabler> = detail::dummy> +bool integral_conversion(const std::string 
&input, T &output) noexcept { + if(input.empty()) { + return false; + } + char *val = nullptr; + std::int64_t output_ll = std::strtoll(input.c_str(), &val, 0); + output = static_cast(output_ll); + return val == (input.c_str() + input.size()) && static_cast(output) == output_ll; +} + +/// Convert a flag into an integer value typically binary flags +inline std::int64_t to_flag_value(std::string val) { + static const std::string trueString("true"); + static const std::string falseString("false"); + if(val == trueString) { + return 1; + } + if(val == falseString) { + return -1; + } + val = detail::to_lower(val); + std::int64_t ret; + if(val.size() == 1) { + if(val[0] >= '1' && val[0] <= '9') { + return (static_cast(val[0]) - '0'); + } + switch(val[0]) { + case '0': + case 'f': + case 'n': + case '-': + ret = -1; + break; + case 't': + case 'y': + case '+': + ret = 1; + break; + default: + throw std::invalid_argument("unrecognized character"); + } + return ret; + } + if(val == trueString || val == "on" || val == "yes" || val == "enable") { + ret = 1; + } else if(val == falseString || val == "off" || val == "no" || val == "disable") { + ret = -1; + } else { + ret = std::stoll(val); + } + return ret; +} + +/// Integer conversion +template ::value == object_category::integral_value || + classify_object::value == object_category::unsigned_integral, + detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + return integral_conversion(input, output); +} + +/// char values +template ::value == object_category::char_value, detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + if(input.size() == 1) { + output = static_cast(input[0]); + return true; + } + return integral_conversion(input, output); +} + +/// Boolean values +template ::value == object_category::boolean_value, detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + try { + auto out = 
to_flag_value(input); + output = (out > 0); + return true; + } catch(const std::invalid_argument &) { + return false; + } catch(const std::out_of_range &) { + // if the number is out of the range of a 64 bit value then it is still a number and for this purpose is still + // valid all we care about the sign + output = (input[0] != '-'); + return true; + } +} + +/// Floats +template ::value == object_category::floating_point, detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + if(input.empty()) { + return false; + } + char *val = nullptr; + auto output_ld = std::strtold(input.c_str(), &val); + output = static_cast(output_ld); + return val == (input.c_str() + input.size()); +} + +/// complex +template ::value == object_category::complex_number, detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + using XC = typename wrapped_type::type; + XC x{0.0}, y{0.0}; + auto str1 = input; + bool worked = false; + auto nloc = str1.find_last_of("+-"); + if(nloc != std::string::npos && nloc > 0) { + worked = detail::lexical_cast(str1.substr(0, nloc), x); + str1 = str1.substr(nloc); + if(str1.back() == 'i' || str1.back() == 'j') + str1.pop_back(); + worked = worked && detail::lexical_cast(str1, y); + } else { + if(str1.back() == 'i' || str1.back() == 'j') { + str1.pop_back(); + worked = detail::lexical_cast(str1, y); + x = XC{0}; + } else { + worked = detail::lexical_cast(str1, x); + y = XC{0}; + } + } + if(worked) { + output = T{x, y}; + return worked; + } + return from_stream(input, output); +} + +/// String and similar direct assignment +template ::value == object_category::string_assignable, detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + output = input; + return true; +} + +/// String and similar constructible and copy assignment +template < + typename T, + enable_if_t::value == object_category::string_constructible, detail::enabler> = detail::dummy> 
+bool lexical_cast(const std::string &input, T &output) { + output = T(input); + return true; +} + +/// Enumerations +template ::value == object_category::enumeration, detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + typename std::underlying_type::type val; + if(!integral_conversion(input, val)) { + return false; + } + output = static_cast(val); + return true; +} + +/// wrapper types +template ::value == object_category::wrapper_value && + std::is_assignable::value, + detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + typename T::value_type val; + if(lexical_cast(input, val)) { + output = val; + return true; + } + return from_stream(input, output); +} + +template ::value == object_category::wrapper_value && + !std::is_assignable::value && std::is_assignable::value, + detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + typename T::value_type val; + if(lexical_cast(input, val)) { + output = T{val}; + return true; + } + return from_stream(input, output); +} + +/// Assignable from double or int +template < + typename T, + enable_if_t::value == object_category::number_constructible, detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + int val; + if(integral_conversion(input, val)) { + output = T(val); + return true; + } else { + double dval; + if(lexical_cast(input, dval)) { + output = T{dval}; + return true; + } + } + return from_stream(input, output); +} + +/// Assignable from int +template < + typename T, + enable_if_t::value == object_category::integer_constructible, detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + int val; + if(integral_conversion(input, val)) { + output = T(val); + return true; + } + return from_stream(input, output); +} + +/// Assignable from double +template < + typename T, + enable_if_t::value == 
object_category::double_constructible, detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + double val; + if(lexical_cast(input, val)) { + output = T{val}; + return true; + } + return from_stream(input, output); +} + +/// Non-string convertible from an int +template ::value == object_category::other && std::is_assignable::value, + detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + int val; + if(integral_conversion(input, val)) { +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4800) +#endif + // with Atomic this could produce a warning due to the conversion but if atomic gets here it is an old style + // so will most likely still work + output = val; +#ifdef _MSC_VER +#pragma warning(pop) +#endif + return true; + } + // LCOV_EXCL_START + // This version of cast is only used for odd cases in an older compilers the fail over + // from_stream is tested elsewhere an not relevant for coverage here + return from_stream(input, output); + // LCOV_EXCL_STOP +} + +/// Non-string parsable by a stream +template ::value == object_category::other && !std::is_assignable::value, + detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + static_assert(is_istreamable::value, + "option object type must have a lexical cast overload or streaming input operator(>>) defined, if it " + "is convertible from another type use the add_option(...) 
with XC being the known type"); + return from_stream(input, output); +} + +/// Assign a value through lexical cast operations +/// Strings can be empty so we need to do a little different +template ::value && + (classify_object::value == object_category::string_assignable || + classify_object::value == object_category::string_constructible), + detail::enabler> = detail::dummy> +bool lexical_assign(const std::string &input, AssignTo &output) { + return lexical_cast(input, output); +} + +/// Assign a value through lexical cast operations +template ::value && std::is_assignable::value && + classify_object::value != object_category::string_assignable && + classify_object::value != object_category::string_constructible, + detail::enabler> = detail::dummy> +bool lexical_assign(const std::string &input, AssignTo &output) { + if(input.empty()) { + output = AssignTo{}; + return true; + } + + return lexical_cast(input, output); +} + +/// Assign a value through lexical cast operations +template ::value && !std::is_assignable::value && + classify_object::value == object_category::wrapper_value, + detail::enabler> = detail::dummy> +bool lexical_assign(const std::string &input, AssignTo &output) { + if(input.empty()) { + typename AssignTo::value_type emptyVal{}; + output = emptyVal; + return true; + } + return lexical_cast(input, output); +} + +/// Assign a value through lexical cast operations for int compatible values +/// mainly for atomic operations on some compilers +template ::value && !std::is_assignable::value && + classify_object::value != object_category::wrapper_value && + std::is_assignable::value, + detail::enabler> = detail::dummy> +bool lexical_assign(const std::string &input, AssignTo &output) { + if(input.empty()) { + output = 0; + return true; + } + int val; + if(lexical_cast(input, val)) { + output = val; + return true; + } + return false; +} + +/// Assign a value converted from a string in lexical cast to the output value directly +template ::value && 
std::is_assignable::value, + detail::enabler> = detail::dummy> +bool lexical_assign(const std::string &input, AssignTo &output) { + ConvertTo val{}; + bool parse_result = (!input.empty()) ? lexical_cast(input, val) : true; + if(parse_result) { + output = val; + } + return parse_result; +} + +/// Assign a value from a lexical cast through constructing a value and move assigning it +template < + typename AssignTo, + typename ConvertTo, + enable_if_t::value && !std::is_assignable::value && + std::is_move_assignable::value, + detail::enabler> = detail::dummy> +bool lexical_assign(const std::string &input, AssignTo &output) { + ConvertTo val{}; + bool parse_result = input.empty() ? true : lexical_cast(input, val); + if(parse_result) { + output = AssignTo(val); // use () form of constructor to allow some implicit conversions + } + return parse_result; +} + +/// primary lexical conversion operation, 1 string to 1 type of some kind +template ::value <= object_category::other && + classify_object::value <= object_category::wrapper_value, + detail::enabler> = detail::dummy> +bool lexical_conversion(const std::vector &strings, AssignTo &output) { + return lexical_assign(strings[0], output); +} + +/// Lexical conversion if there is only one element but the conversion type is for two, then call a two element +/// constructor +template ::value <= 2) && expected_count::value == 1 && + is_tuple_like::value && type_count_base::value == 2, + detail::enabler> = detail::dummy> +bool lexical_conversion(const std::vector &strings, AssignTo &output) { + // the remove const is to handle pair types coming from a container + typename std::remove_const::type>::type v1; + typename std::tuple_element<1, ConvertTo>::type v2; + bool retval = lexical_assign(strings[0], v1); + if(strings.size() > 1) { + retval = retval && lexical_assign(strings[1], v2); + } + if(retval) { + output = AssignTo{v1, v2}; + } + return retval; +} + +/// Lexical conversion of a container types of single elements 
+template ::value && is_mutable_container::value && + type_count::value == 1, + detail::enabler> = detail::dummy> +bool lexical_conversion(const std::vector &strings, AssignTo &output) { + output.erase(output.begin(), output.end()); + for(const auto &elem : strings) { + typename AssignTo::value_type out; + bool retval = lexical_assign(elem, out); + if(!retval) { + return false; + } + output.insert(output.end(), std::move(out)); + } + return (!output.empty()); +} + +/// Lexical conversion for complex types +template ::value, detail::enabler> = detail::dummy> +bool lexical_conversion(const std::vector &strings, AssignTo &output) { + + if(strings.size() >= 2 && !strings[1].empty()) { + using XC2 = typename wrapped_type::type; + XC2 x{0.0}, y{0.0}; + auto str1 = strings[1]; + if(str1.back() == 'i' || str1.back() == 'j') { + str1.pop_back(); + } + auto worked = detail::lexical_cast(strings[0], x) && detail::lexical_cast(str1, y); + if(worked) { + output = ConvertTo{x, y}; + } + return worked; + } else { + return lexical_assign(strings[0], output); + } +} + +/// Conversion to a vector type using a particular single type as the conversion type +template ::value && (expected_count::value == 1) && + (type_count::value == 1), + detail::enabler> = detail::dummy> +bool lexical_conversion(const std::vector &strings, AssignTo &output) { + bool retval = true; + output.clear(); + output.reserve(strings.size()); + for(const auto &elem : strings) { + + output.emplace_back(); + retval = retval && lexical_assign(elem, output.back()); + } + return (!output.empty()) && retval; +} + +// forward declaration + +/// Lexical conversion of a container types with conversion type of two elements +template ::value && is_mutable_container::value && + type_count_base::value == 2, + detail::enabler> = detail::dummy> +bool lexical_conversion(std::vector strings, AssignTo &output); + +/// Lexical conversion of a vector types with type_size >2 forward declaration +template ::value && 
is_mutable_container::value && + type_count_base::value != 2 && + ((type_count::value > 2) || + (type_count::value > type_count_base::value)), + detail::enabler> = detail::dummy> +bool lexical_conversion(const std::vector &strings, AssignTo &output); + +/// Conversion for tuples +template ::value && is_tuple_like::value && + (type_count_base::value != type_count::value || + type_count::value > 2), + detail::enabler> = detail::dummy> +bool lexical_conversion(const std::vector &strings, AssignTo &output); // forward declaration + +/// Conversion for operations where the assigned type is some class but the conversion is a mutable container or large +/// tuple +template ::value && !is_mutable_container::value && + classify_object::value != object_category::wrapper_value && + (is_mutable_container::value || type_count::value > 2), + detail::enabler> = detail::dummy> +bool lexical_conversion(const std::vector &strings, AssignTo &output) { + + if(strings.size() > 1 || (!strings.empty() && !(strings.front().empty()))) { + ConvertTo val; + auto retval = lexical_conversion(strings, val); + output = AssignTo{val}; + return retval; + } + output = AssignTo{}; + return true; +} + +/// function template for converting tuples if the static Index is greater than the tuple size +template +inline typename std::enable_if<(I >= type_count_base::value), bool>::type +tuple_conversion(const std::vector &, AssignTo &) { + return true; +} + +/// Conversion of a tuple element where the type size ==1 and not a mutable container +template +inline typename std::enable_if::value && type_count::value == 1, bool>::type +tuple_type_conversion(std::vector &strings, AssignTo &output) { + auto retval = lexical_assign(strings[0], output); + strings.erase(strings.begin()); + return retval; +} + +/// Conversion of a tuple element where the type size !=1 but the size is fixed and not a mutable container +template +inline typename std::enable_if::value && (type_count::value > 1) && + type_count::value == 
type_count_min::value, + bool>::type +tuple_type_conversion(std::vector &strings, AssignTo &output) { + auto retval = lexical_conversion(strings, output); + strings.erase(strings.begin(), strings.begin() + type_count::value); + return retval; +} + +/// Conversion of a tuple element where the type is a mutable container or a type with different min and max type sizes +template +inline typename std::enable_if::value || + type_count::value != type_count_min::value, + bool>::type +tuple_type_conversion(std::vector &strings, AssignTo &output) { + + std::size_t index{subtype_count_min::value}; + const std::size_t mx_count{subtype_count::value}; + const std::size_t mx{(std::max)(mx_count, strings.size())}; + + while(index < mx) { + if(is_separator(strings[index])) { + break; + } + ++index; + } + bool retval = lexical_conversion( + std::vector(strings.begin(), strings.begin() + static_cast(index)), output); + strings.erase(strings.begin(), strings.begin() + static_cast(index) + 1); + return retval; +} + +/// Tuple conversion operation +template +inline typename std::enable_if<(I < type_count_base::value), bool>::type +tuple_conversion(std::vector strings, AssignTo &output) { + bool retval = true; + using ConvertToElement = typename std:: + conditional::value, typename std::tuple_element::type, ConvertTo>::type; + if(!strings.empty()) { + retval = retval && tuple_type_conversion::type, ConvertToElement>( + strings, std::get(output)); + } + retval = retval && tuple_conversion(std::move(strings), output); + return retval; +} + +/// Lexical conversion of a container types with tuple elements of size 2 +template ::value && is_mutable_container::value && + type_count_base::value == 2, + detail::enabler>> +bool lexical_conversion(std::vector strings, AssignTo &output) { + output.clear(); + while(!strings.empty()) { + + typename std::remove_const::type>::type v1; + typename std::tuple_element<1, typename ConvertTo::value_type>::type v2; + bool retval = 
tuple_type_conversion(strings, v1); + if(!strings.empty()) { + retval = retval && tuple_type_conversion(strings, v2); + } + if(retval) { + output.insert(output.end(), typename AssignTo::value_type{v1, v2}); + } else { + return false; + } + } + return (!output.empty()); +} + +/// lexical conversion of tuples with type count>2 or tuples of types of some element with a type size>=2 +template ::value && is_tuple_like::value && + (type_count_base::value != type_count::value || + type_count::value > 2), + detail::enabler>> +bool lexical_conversion(const std::vector &strings, AssignTo &output) { + static_assert( + !is_tuple_like::value || type_count_base::value == type_count_base::value, + "if the conversion type is defined as a tuple it must be the same size as the type you are converting to"); + return tuple_conversion(strings, output); +} + +/// Lexical conversion of a vector types for everything but tuples of two elements and types of size 1 +template ::value && is_mutable_container::value && + type_count_base::value != 2 && + ((type_count::value > 2) || + (type_count::value > type_count_base::value)), + detail::enabler>> +bool lexical_conversion(const std::vector &strings, AssignTo &output) { + bool retval = true; + output.clear(); + std::vector temp; + std::size_t ii{0}; + std::size_t icount{0}; + std::size_t xcm{type_count::value}; + auto ii_max = strings.size(); + while(ii < ii_max) { + temp.push_back(strings[ii]); + ++ii; + ++icount; + if(icount == xcm || is_separator(temp.back()) || ii == ii_max) { + if(static_cast(xcm) > type_count_min::value && is_separator(temp.back())) { + temp.pop_back(); + } + typename AssignTo::value_type temp_out; + retval = retval && + lexical_conversion(temp, temp_out); + temp.clear(); + if(!retval) { + return false; + } + output.insert(output.end(), std::move(temp_out)); + icount = 0; + } + } + return retval; +} + +/// conversion for wrapper types +template ::value == object_category::wrapper_value && + std::is_assignable::value, + 
detail::enabler> = detail::dummy> +bool lexical_conversion(const std::vector &strings, AssignTo &output) { + if(strings.empty() || strings.front().empty()) { + output = ConvertTo{}; + return true; + } + typename ConvertTo::value_type val; + if(lexical_conversion(strings, val)) { + output = ConvertTo{val}; + return true; + } + return false; +} + +/// conversion for wrapper types +template ::value == object_category::wrapper_value && + !std::is_assignable::value, + detail::enabler> = detail::dummy> +bool lexical_conversion(const std::vector &strings, AssignTo &output) { + using ConvertType = typename ConvertTo::value_type; + if(strings.empty() || strings.front().empty()) { + output = ConvertType{}; + return true; + } + ConvertType val; + if(lexical_conversion(strings, val)) { + output = val; + return true; + } + return false; +} + +/// Sum a vector of flag representations +/// The flag vector produces a series of strings in a vector, simple true is represented by a "1", simple false is +/// by +/// "-1" an if numbers are passed by some fashion they are captured as well so the function just checks for the most +/// common true and false strings then uses stoll to convert the rest for summing +template ::value, detail::enabler> = detail::dummy> +void sum_flag_vector(const std::vector &flags, T &output) { + std::int64_t count{0}; + for(auto &flag : flags) { + count += detail::to_flag_value(flag); + } + output = (count > 0) ? 
static_cast(count) : T{0}; +} + +/// Sum a vector of flag representations +/// The flag vector produces a series of strings in a vector, simple true is represented by a "1", simple false is +/// by +/// "-1" an if numbers are passed by some fashion they are captured as well so the function just checks for the most +/// common true and false strings then uses stoll to convert the rest for summing +template ::value, detail::enabler> = detail::dummy> +void sum_flag_vector(const std::vector &flags, T &output) { + std::int64_t count{0}; + for(auto &flag : flags) { + count += detail::to_flag_value(flag); + } + output = static_cast(count); +} + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4800) +#endif +// with Atomic this could produce a warning due to the conversion but if atomic gets here it is an old style so will +// most likely still work + +/// Sum a vector of flag representations +/// The flag vector produces a series of strings in a vector, simple true is represented by a "1", simple false is +/// by +/// "-1" an if numbers are passed by some fashion they are captured as well so the function just checks for the most +/// common true and false strings then uses stoll to convert the rest for summing +template ::value && !std::is_unsigned::value, detail::enabler> = detail::dummy> +void sum_flag_vector(const std::vector &flags, T &output) { + std::int64_t count{0}; + for(auto &flag : flags) { + count += detail::to_flag_value(flag); + } + std::string out = detail::to_string(count); + lexical_cast(out, output); +} + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +} // namespace detail + + + +namespace detail { + +// Returns false if not a short option. 
Otherwise, sets opt name and rest and returns true +inline bool split_short(const std::string ¤t, std::string &name, std::string &rest) { + if(current.size() > 1 && current[0] == '-' && valid_first_char(current[1])) { + name = current.substr(1, 1); + rest = current.substr(2); + return true; + } + return false; +} + +// Returns false if not a long option. Otherwise, sets opt name and other side of = and returns true +inline bool split_long(const std::string ¤t, std::string &name, std::string &value) { + if(current.size() > 2 && current.substr(0, 2) == "--" && valid_first_char(current[2])) { + auto loc = current.find_first_of('='); + if(loc != std::string::npos) { + name = current.substr(2, loc - 2); + value = current.substr(loc + 1); + } else { + name = current.substr(2); + value = ""; + } + return true; + } + return false; +} + +// Returns false if not a windows style option. Otherwise, sets opt name and value and returns true +inline bool split_windows_style(const std::string ¤t, std::string &name, std::string &value) { + if(current.size() > 1 && current[0] == '/' && valid_first_char(current[1])) { + auto loc = current.find_first_of(':'); + if(loc != std::string::npos) { + name = current.substr(1, loc - 1); + value = current.substr(loc + 1); + } else { + name = current.substr(1); + value = ""; + } + return true; + } + return false; +} + +// Splits a string into multiple long and short names +inline std::vector split_names(std::string current) { + std::vector output; + std::size_t val; + while((val = current.find(",")) != std::string::npos) { + output.push_back(trim_copy(current.substr(0, val))); + current = current.substr(val + 1); + } + output.push_back(trim_copy(current)); + return output; +} + +/// extract default flag values either {def} or starting with a ! 
+inline std::vector> get_default_flag_values(const std::string &str) { + std::vector flags = split_names(str); + flags.erase(std::remove_if(flags.begin(), + flags.end(), + [](const std::string &name) { + return ((name.empty()) || (!(((name.find_first_of('{') != std::string::npos) && + (name.back() == '}')) || + (name[0] == '!')))); + }), + flags.end()); + std::vector> output; + output.reserve(flags.size()); + for(auto &flag : flags) { + auto def_start = flag.find_first_of('{'); + std::string defval = "false"; + if((def_start != std::string::npos) && (flag.back() == '}')) { + defval = flag.substr(def_start + 1); + defval.pop_back(); + flag.erase(def_start, std::string::npos); + } + flag.erase(0, flag.find_first_not_of("-!")); + output.emplace_back(flag, defval); + } + return output; +} + +/// Get a vector of short names, one of long names, and a single name +inline std::tuple, std::vector, std::string> +get_names(const std::vector &input) { + + std::vector short_names; + std::vector long_names; + std::string pos_name; + + for(std::string name : input) { + if(name.length() == 0) { + continue; + } + if(name.length() > 1 && name[0] == '-' && name[1] != '-') { + if(name.length() == 2 && valid_first_char(name[1])) + short_names.emplace_back(1, name[1]); + else + throw BadNameString::OneCharName(name); + } else if(name.length() > 2 && name.substr(0, 2) == "--") { + name = name.substr(2); + if(valid_name_string(name)) + long_names.push_back(name); + else + throw BadNameString::BadLongName(name); + } else if(name == "-" || name == "--") { + throw BadNameString::DashesOnly(name); + } else { + if(pos_name.length() > 0) + throw BadNameString::MultiPositionalNames(name); + pos_name = name; + } + } + + return std::tuple, std::vector, std::string>( + short_names, long_names, pos_name); +} + +} // namespace detail + + + +class App; + +/// Holds values to load into Options +struct ConfigItem { + /// This is the list of parents + std::vector parents{}; + + /// This is the name + 
std::string name{}; + + /// Listing of inputs + std::vector inputs{}; + + /// The list of parents and name joined by "." + std::string fullname() const { + std::vector tmp = parents; + tmp.emplace_back(name); + return detail::join(tmp, "."); + } +}; + +/// This class provides a converter for configuration files. +class Config { + protected: + std::vector items{}; + + public: + /// Convert an app into a configuration + virtual std::string to_config(const App *, bool, bool, std::string) const = 0; + + /// Convert a configuration into an app + virtual std::vector from_config(std::istream &) const = 0; + + /// Get a flag value + virtual std::string to_flag(const ConfigItem &item) const { + if(item.inputs.size() == 1) { + return item.inputs.at(0); + } + throw ConversionError::TooManyInputsFlag(item.fullname()); + } + + /// Parse a config file, throw an error (ParseError:ConfigParseError or FileError) on failure + std::vector from_file(const std::string &name) { + std::ifstream input{name}; + if(!input.good()) + throw FileError::Missing(name); + + return from_config(input); + } + + /// Virtual destructor + virtual ~Config() = default; +}; + +/// This converter works with INI/TOML files; to write INI files use ConfigINI +class ConfigBase : public Config { + protected: + /// the character used for comments + char commentChar = '#'; + /// the character used to start an array '\0' is a default to not use + char arrayStart = '['; + /// the character used to end an array '\0' is a default to not use + char arrayEnd = ']'; + /// the character used to separate elements in an array + char arraySeparator = ','; + /// the character used separate the name from the value + char valueDelimiter = '='; + /// the character to use around strings + char stringQuote = '"'; + /// the character to use around single characters + char characterQuote = '\''; + /// the maximum number of layers to allow + uint8_t maximumLayers{255}; + /// the separator used to separator parent layers + char 
parentSeparatorChar{'.'}; + /// Specify the configuration index to use for arrayed sections + int16_t configIndex{-1}; + /// Specify the configuration section that should be used + std::string configSection{}; + + public: + std::string + to_config(const App * /*app*/, bool default_also, bool write_description, std::string prefix) const override; + + std::vector from_config(std::istream &input) const override; + /// Specify the configuration for comment characters + ConfigBase *comment(char cchar) { + commentChar = cchar; + return this; + } + /// Specify the start and end characters for an array + ConfigBase *arrayBounds(char aStart, char aEnd) { + arrayStart = aStart; + arrayEnd = aEnd; + return this; + } + /// Specify the delimiter character for an array + ConfigBase *arrayDelimiter(char aSep) { + arraySeparator = aSep; + return this; + } + /// Specify the delimiter between a name and value + ConfigBase *valueSeparator(char vSep) { + valueDelimiter = vSep; + return this; + } + /// Specify the quote characters used around strings and characters + ConfigBase *quoteCharacter(char qString, char qChar) { + stringQuote = qString; + characterQuote = qChar; + return this; + } + /// Specify the maximum number of parents + ConfigBase *maxLayers(uint8_t layers) { + maximumLayers = layers; + return this; + } + /// Specify the separator to use for parent layers + ConfigBase *parentSeparator(char sep) { + parentSeparatorChar = sep; + return this; + } + /// get a reference to the configuration section + std::string §ionRef() { return configSection; } + /// get the section + const std::string §ion() const { return configSection; } + /// specify a particular section of the configuration file to use + ConfigBase *section(const std::string §ionName) { + configSection = sectionName; + return this; + } + + /// get a reference to the configuration index + int16_t &indexRef() { return configIndex; } + /// get the section index + int16_t index() const { return configIndex; } + /// 
specify a particular index in the section to use (-1) for all sections to use + ConfigBase *index(int16_t sectionIndex) { + configIndex = sectionIndex; + return this; + } +}; + +/// the default Config is the TOML file format +using ConfigTOML = ConfigBase; + +/// ConfigINI generates a "standard" INI compliant output +class ConfigINI : public ConfigTOML { + + public: + ConfigINI() { + commentChar = ';'; + arrayStart = '\0'; + arrayEnd = '\0'; + arraySeparator = ' '; + valueDelimiter = '='; + } +}; + + + +class Option; + +/// @defgroup validator_group Validators + +/// @brief Some validators that are provided +/// +/// These are simple `std::string(const std::string&)` validators that are useful. They return +/// a string if the validation fails. A custom struct is provided, as well, with the same user +/// semantics, but with the ability to provide a new type name. +/// @{ + +/// +class Validator { + protected: + /// This is the description function, if empty the description_ will be used + std::function desc_function_{[]() { return std::string{}; }}; + + /// This is the base function that is to be called. + /// Returns a string error message if validation fails. 
+ std::function func_{[](std::string &) { return std::string{}; }}; + /// The name for search purposes of the Validator + std::string name_{}; + /// A Validator will only apply to an indexed value (-1 is all elements) + int application_index_ = -1; + /// Enable for Validator to allow it to be disabled if need be + bool active_{true}; + /// specify that a validator should not modify the input + bool non_modifying_{false}; + + public: + Validator() = default; + /// Construct a Validator with just the description string + explicit Validator(std::string validator_desc) : desc_function_([validator_desc]() { return validator_desc; }) {} + /// Construct Validator from basic information + Validator(std::function op, std::string validator_desc, std::string validator_name = "") + : desc_function_([validator_desc]() { return validator_desc; }), func_(std::move(op)), + name_(std::move(validator_name)) {} + /// Set the Validator operation function + Validator &operation(std::function op) { + func_ = std::move(op); + return *this; + } + /// This is the required operator for a Validator - provided to help + /// users (CLI11 uses the member `func` directly) + std::string operator()(std::string &str) const { + std::string retstring; + if(active_) { + if(non_modifying_) { + std::string value = str; + retstring = func_(value); + } else { + retstring = func_(str); + } + } + return retstring; + } + + /// This is the required operator for a Validator - provided to help + /// users (CLI11 uses the member `func` directly) + std::string operator()(const std::string &str) const { + std::string value = str; + return (active_) ? 
func_(value) : std::string{}; + } + + /// Specify the type string + Validator &description(std::string validator_desc) { + desc_function_ = [validator_desc]() { return validator_desc; }; + return *this; + } + /// Specify the type string + Validator description(std::string validator_desc) const { + Validator newval(*this); + newval.desc_function_ = [validator_desc]() { return validator_desc; }; + return newval; + } + /// Generate type description information for the Validator + std::string get_description() const { + if(active_) { + return desc_function_(); + } + return std::string{}; + } + /// Specify the type string + Validator &name(std::string validator_name) { + name_ = std::move(validator_name); + return *this; + } + /// Specify the type string + Validator name(std::string validator_name) const { + Validator newval(*this); + newval.name_ = std::move(validator_name); + return newval; + } + /// Get the name of the Validator + const std::string &get_name() const { return name_; } + /// Specify whether the Validator is active or not + Validator &active(bool active_val = true) { + active_ = active_val; + return *this; + } + /// Specify whether the Validator is active or not + Validator active(bool active_val = true) const { + Validator newval(*this); + newval.active_ = active_val; + return newval; + } + + /// Specify whether the Validator can be modifying or not + Validator &non_modifying(bool no_modify = true) { + non_modifying_ = no_modify; + return *this; + } + /// Specify the application index of a validator + Validator &application_index(int app_index) { + application_index_ = app_index; + return *this; + } + /// Specify the application index of a validator + Validator application_index(int app_index) const { + Validator newval(*this); + newval.application_index_ = app_index; + return newval; + } + /// Get the current value of the application index + int get_application_index() const { return application_index_; } + /// Get a boolean if the validator is active 
+ bool get_active() const { return active_; } + + /// Get a boolean if the validator is allowed to modify the input returns true if it can modify the input + bool get_modifying() const { return !non_modifying_; } + + /// Combining validators is a new validator. Type comes from left validator if function, otherwise only set if the + /// same. + Validator operator&(const Validator &other) const { + Validator newval; + + newval._merge_description(*this, other, " AND "); + + // Give references (will make a copy in lambda function) + const std::function &f1 = func_; + const std::function &f2 = other.func_; + + newval.func_ = [f1, f2](std::string &input) { + std::string s1 = f1(input); + std::string s2 = f2(input); + if(!s1.empty() && !s2.empty()) + return std::string("(") + s1 + ") AND (" + s2 + ")"; + else + return s1 + s2; + }; + + newval.active_ = (active_ & other.active_); + newval.application_index_ = application_index_; + return newval; + } + + /// Combining validators is a new validator. Type comes from left validator if function, otherwise only set if the + /// same. + Validator operator|(const Validator &other) const { + Validator newval; + + newval._merge_description(*this, other, " OR "); + + // Give references (will make a copy in lambda function) + const std::function &f1 = func_; + const std::function &f2 = other.func_; + + newval.func_ = [f1, f2](std::string &input) { + std::string s1 = f1(input); + std::string s2 = f2(input); + if(s1.empty() || s2.empty()) + return std::string(); + + return std::string("(") + s1 + ") OR (" + s2 + ")"; + }; + newval.active_ = (active_ & other.active_); + newval.application_index_ = application_index_; + return newval; + } + + /// Create a validator that fails when a given validator succeeds + Validator operator!() const { + Validator newval; + const std::function &dfunc1 = desc_function_; + newval.desc_function_ = [dfunc1]() { + auto str = dfunc1(); + return (!str.empty()) ? 
std::string("NOT ") + str : std::string{}; + }; + // Give references (will make a copy in lambda function) + const std::function &f1 = func_; + + newval.func_ = [f1, dfunc1](std::string &test) -> std::string { + std::string s1 = f1(test); + if(s1.empty()) { + return std::string("check ") + dfunc1() + " succeeded improperly"; + } + return std::string{}; + }; + newval.active_ = active_; + newval.application_index_ = application_index_; + return newval; + } + + private: + void _merge_description(const Validator &val1, const Validator &val2, const std::string &merger) { + + const std::function &dfunc1 = val1.desc_function_; + const std::function &dfunc2 = val2.desc_function_; + + desc_function_ = [=]() { + std::string f1 = dfunc1(); + std::string f2 = dfunc2(); + if((f1.empty()) || (f2.empty())) { + return f1 + f2; + } + return std::string(1, '(') + f1 + ')' + merger + '(' + f2 + ')'; + }; + } +}; // namespace CLI + +/// Class wrapping some of the accessors of Validator +class CustomValidator : public Validator { + public: +}; +// The implementation of the built in validators is using the Validator class; +// the user is only expected to use the const (static) versions (since there's no setup). +// Therefore, this is in detail. 
+namespace detail { + +/// CLI enumeration of different file types +enum class path_type { nonexistent, file, directory }; + +#if defined CLI11_HAS_FILESYSTEM && CLI11_HAS_FILESYSTEM > 0 +/// get the type of the path from a file name +inline path_type check_path(const char *file) noexcept { + std::error_code ec; + auto stat = std::filesystem::status(file, ec); + if(ec) { + return path_type::nonexistent; + } + switch(stat.type()) { + case std::filesystem::file_type::none: + case std::filesystem::file_type::not_found: + return path_type::nonexistent; + case std::filesystem::file_type::directory: + return path_type::directory; + case std::filesystem::file_type::symlink: + case std::filesystem::file_type::block: + case std::filesystem::file_type::character: + case std::filesystem::file_type::fifo: + case std::filesystem::file_type::socket: + case std::filesystem::file_type::regular: + case std::filesystem::file_type::unknown: + default: + return path_type::file; + } +} +#else +/// get the type of the path from a file name +inline path_type check_path(const char *file) noexcept { +#if defined(_MSC_VER) + struct __stat64 buffer; + if(_stat64(file, &buffer) == 0) { + return ((buffer.st_mode & S_IFDIR) != 0) ? path_type::directory : path_type::file; + } +#else + struct stat buffer; + if(stat(file, &buffer) == 0) { + return ((buffer.st_mode & S_IFDIR) != 0) ? 
path_type::directory : path_type::file; + } +#endif + return path_type::nonexistent; +} +#endif +/// Check for an existing file (returns error message if check fails) +class ExistingFileValidator : public Validator { + public: + ExistingFileValidator() : Validator("FILE") { + func_ = [](std::string &filename) { + auto path_result = check_path(filename.c_str()); + if(path_result == path_type::nonexistent) { + return "File does not exist: " + filename; + } + if(path_result == path_type::directory) { + return "File is actually a directory: " + filename; + } + return std::string(); + }; + } +}; + +/// Check for an existing directory (returns error message if check fails) +class ExistingDirectoryValidator : public Validator { + public: + ExistingDirectoryValidator() : Validator("DIR") { + func_ = [](std::string &filename) { + auto path_result = check_path(filename.c_str()); + if(path_result == path_type::nonexistent) { + return "Directory does not exist: " + filename; + } + if(path_result == path_type::file) { + return "Directory is actually a file: " + filename; + } + return std::string(); + }; + } +}; + +/// Check for an existing path +class ExistingPathValidator : public Validator { + public: + ExistingPathValidator() : Validator("PATH(existing)") { + func_ = [](std::string &filename) { + auto path_result = check_path(filename.c_str()); + if(path_result == path_type::nonexistent) { + return "Path does not exist: " + filename; + } + return std::string(); + }; + } +}; + +/// Check for an non-existing path +class NonexistentPathValidator : public Validator { + public: + NonexistentPathValidator() : Validator("PATH(non-existing)") { + func_ = [](std::string &filename) { + auto path_result = check_path(filename.c_str()); + if(path_result != path_type::nonexistent) { + return "Path already exists: " + filename; + } + return std::string(); + }; + } +}; + +/// Validate the given string is a legal ipv4 address +class IPV4Validator : public Validator { + public: + 
IPV4Validator() : Validator("IPV4") { + func_ = [](std::string &ip_addr) { + auto result = CLI::detail::split(ip_addr, '.'); + if(result.size() != 4) { + return std::string("Invalid IPV4 address must have four parts (") + ip_addr + ')'; + } + int num; + for(const auto &var : result) { + bool retval = detail::lexical_cast(var, num); + if(!retval) { + return std::string("Failed parsing number (") + var + ')'; + } + if(num < 0 || num > 255) { + return std::string("Each IP number must be between 0 and 255 ") + var; + } + } + return std::string(); + }; + } +}; + +} // namespace detail + +// Static is not needed here, because global const implies static. + +/// Check for existing file (returns error message if check fails) +const detail::ExistingFileValidator ExistingFile; + +/// Check for an existing directory (returns error message if check fails) +const detail::ExistingDirectoryValidator ExistingDirectory; + +/// Check for an existing path +const detail::ExistingPathValidator ExistingPath; + +/// Check for an non-existing path +const detail::NonexistentPathValidator NonexistentPath; + +/// Check for an IP4 address +const detail::IPV4Validator ValidIPV4; + +/// Validate the input as a particular type +template class TypeValidator : public Validator { + public: + explicit TypeValidator(const std::string &validator_name) : Validator(validator_name) { + func_ = [](std::string &input_string) { + auto val = DesiredType(); + if(!detail::lexical_cast(input_string, val)) { + return std::string("Failed parsing ") + input_string + " as a " + detail::type_name(); + } + return std::string(); + }; + } + TypeValidator() : TypeValidator(detail::type_name()) {} +}; + +/// Check for a number +const TypeValidator Number("NUMBER"); + +/// Produce a range (factory). Min and max are inclusive. +class Range : public Validator { + public: + /// This produces a range with min and max inclusive. 
+ /// + /// Note that the constructor is templated, but the struct is not, so C++17 is not + /// needed to provide nice syntax for Range(a,b). + template + Range(T min_val, T max_val, const std::string &validator_name = std::string{}) : Validator(validator_name) { + if(validator_name.empty()) { + std::stringstream out; + out << detail::type_name() << " in [" << min_val << " - " << max_val << "]"; + description(out.str()); + } + + func_ = [min_val, max_val](std::string &input) { + T val; + bool converted = detail::lexical_cast(input, val); + if((!converted) || (val < min_val || val > max_val)) + return std::string("Value ") + input + " not in range " + std::to_string(min_val) + " to " + + std::to_string(max_val); + + return std::string{}; + }; + } + + /// Range of one value is 0 to value + template + explicit Range(T max_val, const std::string &validator_name = std::string{}) + : Range(static_cast(0), max_val, validator_name) {} +}; + +/// Check for a non negative number +const Range NonNegativeNumber((std::numeric_limits::max)(), "NONNEGATIVE"); + +/// Check for a positive valued number (val>0.0), min() her is the smallest positive number +const Range PositiveNumber((std::numeric_limits::min)(), (std::numeric_limits::max)(), "POSITIVE"); + +/// Produce a bounded range (factory). Min and max are inclusive. +class Bound : public Validator { + public: + /// This bounds a value with min and max inclusive. + /// + /// Note that the constructor is templated, but the struct is not, so C++17 is not + /// needed to provide nice syntax for Range(a,b). 
+ template Bound(T min_val, T max_val) { + std::stringstream out; + out << detail::type_name() << " bounded to [" << min_val << " - " << max_val << "]"; + description(out.str()); + + func_ = [min_val, max_val](std::string &input) { + T val; + bool converted = detail::lexical_cast(input, val); + if(!converted) { + return std::string("Value ") + input + " could not be converted"; + } + if(val < min_val) + input = detail::to_string(min_val); + else if(val > max_val) + input = detail::to_string(max_val); + + return std::string{}; + }; + } + + /// Range of one value is 0 to value + template explicit Bound(T max_val) : Bound(static_cast(0), max_val) {} +}; + +namespace detail { +template ::type>::value, detail::enabler> = detail::dummy> +auto smart_deref(T value) -> decltype(*value) { + return *value; +} + +template < + typename T, + enable_if_t::type>::value, detail::enabler> = detail::dummy> +typename std::remove_reference::type &smart_deref(T &value) { + return value; +} +/// Generate a string representation of a set +template std::string generate_set(const T &set) { + using element_t = typename detail::element_type::type; + using iteration_type_t = typename detail::pair_adaptor::value_type; // the type of the object pair + std::string out(1, '{'); + out.append(detail::join( + detail::smart_deref(set), + [](const iteration_type_t &v) { return detail::pair_adaptor::first(v); }, + ",")); + out.push_back('}'); + return out; +} + +/// Generate a string representation of a map +template std::string generate_map(const T &map, bool key_only = false) { + using element_t = typename detail::element_type::type; + using iteration_type_t = typename detail::pair_adaptor::value_type; // the type of the object pair + std::string out(1, '{'); + out.append(detail::join( + detail::smart_deref(map), + [key_only](const iteration_type_t &v) { + std::string res{detail::to_string(detail::pair_adaptor::first(v))}; + + if(!key_only) { + res.append("->"); + res += 
detail::to_string(detail::pair_adaptor::second(v)); + } + return res; + }, + ",")); + out.push_back('}'); + return out; +} + +template struct has_find { + template + static auto test(int) -> decltype(std::declval().find(std::declval()), std::true_type()); + template static auto test(...) -> decltype(std::false_type()); + + static const auto value = decltype(test(0))::value; + using type = std::integral_constant; +}; + +/// A search function +template ::value, detail::enabler> = detail::dummy> +auto search(const T &set, const V &val) -> std::pair { + using element_t = typename detail::element_type::type; + auto &setref = detail::smart_deref(set); + auto it = std::find_if(std::begin(setref), std::end(setref), [&val](decltype(*std::begin(setref)) v) { + return (detail::pair_adaptor::first(v) == val); + }); + return {(it != std::end(setref)), it}; +} + +/// A search function that uses the built in find function +template ::value, detail::enabler> = detail::dummy> +auto search(const T &set, const V &val) -> std::pair { + auto &setref = detail::smart_deref(set); + auto it = setref.find(val); + return {(it != std::end(setref)), it}; +} + +/// A search function with a filter function +template +auto search(const T &set, const V &val, const std::function &filter_function) + -> std::pair { + using element_t = typename detail::element_type::type; + // do the potentially faster first search + auto res = search(set, val); + if((res.first) || (!(filter_function))) { + return res; + } + // if we haven't found it do the longer linear search with all the element translations + auto &setref = detail::smart_deref(set); + auto it = std::find_if(std::begin(setref), std::end(setref), [&](decltype(*std::begin(setref)) v) { + V a{detail::pair_adaptor::first(v)}; + a = filter_function(a); + return (a == val); + }); + return {(it != std::end(setref)), it}; +} + +// the following suggestion was made by Nikita Ofitserov(@himikof) +// done in templates to prevent compiler warnings on negation 
of unsigned numbers + +/// Do a check for overflow on signed numbers +template +inline typename std::enable_if::value, T>::type overflowCheck(const T &a, const T &b) { + if((a > 0) == (b > 0)) { + return ((std::numeric_limits::max)() / (std::abs)(a) < (std::abs)(b)); + } else { + return ((std::numeric_limits::min)() / (std::abs)(a) > -(std::abs)(b)); + } +} +/// Do a check for overflow on unsigned numbers +template +inline typename std::enable_if::value, T>::type overflowCheck(const T &a, const T &b) { + return ((std::numeric_limits::max)() / a < b); +} + +/// Performs a *= b; if it doesn't cause integer overflow. Returns false otherwise. +template typename std::enable_if::value, bool>::type checked_multiply(T &a, T b) { + if(a == 0 || b == 0 || a == 1 || b == 1) { + a *= b; + return true; + } + if(a == (std::numeric_limits::min)() || b == (std::numeric_limits::min)()) { + return false; + } + if(overflowCheck(a, b)) { + return false; + } + a *= b; + return true; +} + +/// Performs a *= b; if it doesn't equal infinity. Returns false otherwise. +template +typename std::enable_if::value, bool>::type checked_multiply(T &a, T b) { + T c = a * b; + if(std::isinf(c) && !std::isinf(a) && !std::isinf(b)) { + return false; + } + a = c; + return true; +} + +} // namespace detail +/// Verify items are in a set +class IsMember : public Validator { + public: + using filter_fn_t = std::function; + + /// This allows in-place construction using an initializer list + template + IsMember(std::initializer_list values, Args &&...args) + : IsMember(std::vector(values), std::forward(args)...) {} + + /// This checks to see if an item is in a set (empty function) + template explicit IsMember(T &&set) : IsMember(std::forward(set), nullptr) {} + + /// This checks to see if an item is in a set: pointer or copy version. You can pass in a function that will filter + /// both sides of the comparison before computing the comparison. 
+ template explicit IsMember(T set, F filter_function) { + + // Get the type of the contained item - requires a container have ::value_type + // if the type does not have first_type and second_type, these are both value_type + using element_t = typename detail::element_type::type; // Removes (smart) pointers if needed + using item_t = typename detail::pair_adaptor::first_type; // Is value_type if not a map + + using local_item_t = typename IsMemberType::type; // This will convert bad types to good ones + // (const char * to std::string) + + // Make a local copy of the filter function, using a std::function if not one already + std::function filter_fn = filter_function; + + // This is the type name for help, it will take the current version of the set contents + desc_function_ = [set]() { return detail::generate_set(detail::smart_deref(set)); }; + + // This is the function that validates + // It stores a copy of the set pointer-like, so shared_ptr will stay alive + func_ = [set, filter_fn](std::string &input) { + local_item_t b; + if(!detail::lexical_cast(input, b)) { + throw ValidationError(input); // name is added later + } + if(filter_fn) { + b = filter_fn(b); + } + auto res = detail::search(set, b, filter_fn); + if(res.first) { + // Make sure the version in the input string is identical to the one in the set + if(filter_fn) { + input = detail::value_string(detail::pair_adaptor::first(*(res.second))); + } + + // Return empty error string (success) + return std::string{}; + } + + // If you reach this point, the result was not found + return input + " not in " + detail::generate_set(detail::smart_deref(set)); + }; + } + + /// You can pass in as many filter functions as you like, they nest (string only currently) + template + IsMember(T &&set, filter_fn_t filter_fn_1, filter_fn_t filter_fn_2, Args &&...other) + : IsMember( + std::forward(set), + [filter_fn_1, filter_fn_2](std::string a) { return filter_fn_2(filter_fn_1(a)); }, + other...) 
{} +}; + +/// definition of the default transformation object +template using TransformPairs = std::vector>; + +/// Translate named items to other or a value set +class Transformer : public Validator { + public: + using filter_fn_t = std::function; + + /// This allows in-place construction + template + Transformer(std::initializer_list> values, Args &&...args) + : Transformer(TransformPairs(values), std::forward(args)...) {} + + /// direct map of std::string to std::string + template explicit Transformer(T &&mapping) : Transformer(std::forward(mapping), nullptr) {} + + /// This checks to see if an item is in a set: pointer or copy version. You can pass in a function that will filter + /// both sides of the comparison before computing the comparison. + template explicit Transformer(T mapping, F filter_function) { + + static_assert(detail::pair_adaptor::type>::value, + "mapping must produce value pairs"); + // Get the type of the contained item - requires a container have ::value_type + // if the type does not have first_type and second_type, these are both value_type + using element_t = typename detail::element_type::type; // Removes (smart) pointers if needed + using item_t = typename detail::pair_adaptor::first_type; // Is value_type if not a map + using local_item_t = typename IsMemberType::type; // Will convert bad types to good ones + // (const char * to std::string) + + // Make a local copy of the filter function, using a std::function if not one already + std::function filter_fn = filter_function; + + // This is the type name for help, it will take the current version of the set contents + desc_function_ = [mapping]() { return detail::generate_map(detail::smart_deref(mapping)); }; + + func_ = [mapping, filter_fn](std::string &input) { + local_item_t b; + if(!detail::lexical_cast(input, b)) { + return std::string(); + // there is no possible way we can match anything in the mapping if we can't convert so just return + } + if(filter_fn) { + b = filter_fn(b); + 
} + auto res = detail::search(mapping, b, filter_fn); + if(res.first) { + input = detail::value_string(detail::pair_adaptor::second(*res.second)); + } + return std::string{}; + }; + } + + /// You can pass in as many filter functions as you like, they nest + template + Transformer(T &&mapping, filter_fn_t filter_fn_1, filter_fn_t filter_fn_2, Args &&...other) + : Transformer( + std::forward(mapping), + [filter_fn_1, filter_fn_2](std::string a) { return filter_fn_2(filter_fn_1(a)); }, + other...) {} +}; + +/// translate named items to other or a value set +class CheckedTransformer : public Validator { + public: + using filter_fn_t = std::function; + + /// This allows in-place construction + template + CheckedTransformer(std::initializer_list> values, Args &&...args) + : CheckedTransformer(TransformPairs(values), std::forward(args)...) {} + + /// direct map of std::string to std::string + template explicit CheckedTransformer(T mapping) : CheckedTransformer(std::move(mapping), nullptr) {} + + /// This checks to see if an item is in a set: pointer or copy version. You can pass in a function that will filter + /// both sides of the comparison before computing the comparison. 
+ template explicit CheckedTransformer(T mapping, F filter_function) { + + static_assert(detail::pair_adaptor::type>::value, + "mapping must produce value pairs"); + // Get the type of the contained item - requires a container have ::value_type + // if the type does not have first_type and second_type, these are both value_type + using element_t = typename detail::element_type::type; // Removes (smart) pointers if needed + using item_t = typename detail::pair_adaptor::first_type; // Is value_type if not a map + using local_item_t = typename IsMemberType::type; // Will convert bad types to good ones + // (const char * to std::string) + using iteration_type_t = typename detail::pair_adaptor::value_type; // the type of the object pair + + // Make a local copy of the filter function, using a std::function if not one already + std::function filter_fn = filter_function; + + auto tfunc = [mapping]() { + std::string out("value in "); + out += detail::generate_map(detail::smart_deref(mapping)) + " OR {"; + out += detail::join( + detail::smart_deref(mapping), + [](const iteration_type_t &v) { return detail::to_string(detail::pair_adaptor::second(v)); }, + ","); + out.push_back('}'); + return out; + }; + + desc_function_ = tfunc; + + func_ = [mapping, tfunc, filter_fn](std::string &input) { + local_item_t b; + bool converted = detail::lexical_cast(input, b); + if(converted) { + if(filter_fn) { + b = filter_fn(b); + } + auto res = detail::search(mapping, b, filter_fn); + if(res.first) { + input = detail::value_string(detail::pair_adaptor::second(*res.second)); + return std::string{}; + } + } + for(const auto &v : detail::smart_deref(mapping)) { + auto output_string = detail::value_string(detail::pair_adaptor::second(v)); + if(output_string == input) { + return std::string(); + } + } + + return "Check " + input + " " + tfunc() + " FAILED"; + }; + } + + /// You can pass in as many filter functions as you like, they nest + template + CheckedTransformer(T &&mapping, filter_fn_t 
filter_fn_1, filter_fn_t filter_fn_2, Args &&...other) + : CheckedTransformer( + std::forward(mapping), + [filter_fn_1, filter_fn_2](std::string a) { return filter_fn_2(filter_fn_1(a)); }, + other...) {} +}; + +/// Helper function to allow ignore_case to be passed to IsMember or Transform +inline std::string ignore_case(std::string item) { return detail::to_lower(item); } + +/// Helper function to allow ignore_underscore to be passed to IsMember or Transform +inline std::string ignore_underscore(std::string item) { return detail::remove_underscore(item); } + +/// Helper function to allow checks to ignore spaces to be passed to IsMember or Transform +inline std::string ignore_space(std::string item) { + item.erase(std::remove(std::begin(item), std::end(item), ' '), std::end(item)); + item.erase(std::remove(std::begin(item), std::end(item), '\t'), std::end(item)); + return item; +} + +/// Multiply a number by a factor using given mapping. +/// Can be used to write transforms for SIZE or DURATION inputs. +/// +/// Example: +/// With mapping = `{"b"->1, "kb"->1024, "mb"->1024*1024}` +/// one can recognize inputs like "100", "12kb", "100 MB", +/// that will be automatically transformed to 100, 14448, 104857600. +/// +/// Output number type matches the type in the provided mapping. +/// Therefore, if it is required to interpret real inputs like "0.42 s", +/// the mapping should be of a type or . +class AsNumberWithUnit : public Validator { + public: + /// Adjust AsNumberWithUnit behavior. + /// CASE_SENSITIVE/CASE_INSENSITIVE controls how units are matched. + /// UNIT_OPTIONAL/UNIT_REQUIRED throws ValidationError + /// if UNIT_REQUIRED is set and unit literal is not found. 
+ enum Options { + CASE_SENSITIVE = 0, + CASE_INSENSITIVE = 1, + UNIT_OPTIONAL = 0, + UNIT_REQUIRED = 2, + DEFAULT = CASE_INSENSITIVE | UNIT_OPTIONAL + }; + + template + explicit AsNumberWithUnit(std::map mapping, + Options opts = DEFAULT, + const std::string &unit_name = "UNIT") { + description(generate_description(unit_name, opts)); + validate_mapping(mapping, opts); + + // transform function + func_ = [mapping, opts](std::string &input) -> std::string { + Number num; + + detail::rtrim(input); + if(input.empty()) { + throw ValidationError("Input is empty"); + } + + // Find split position between number and prefix + auto unit_begin = input.end(); + while(unit_begin > input.begin() && std::isalpha(*(unit_begin - 1), std::locale())) { + --unit_begin; + } + + std::string unit{unit_begin, input.end()}; + input.resize(static_cast(std::distance(input.begin(), unit_begin))); + detail::trim(input); + + if(opts & UNIT_REQUIRED && unit.empty()) { + throw ValidationError("Missing mandatory unit"); + } + if(opts & CASE_INSENSITIVE) { + unit = detail::to_lower(unit); + } + if(unit.empty()) { + if(!detail::lexical_cast(input, num)) { + throw ValidationError(std::string("Value ") + input + " could not be converted to " + + detail::type_name()); + } + // No need to modify input if no unit passed + return {}; + } + + // find corresponding factor + auto it = mapping.find(unit); + if(it == mapping.end()) { + throw ValidationError(unit + + " unit not recognized. " + "Allowed values: " + + detail::generate_map(mapping, true)); + } + + if(!input.empty()) { + bool converted = detail::lexical_cast(input, num); + if(!converted) { + throw ValidationError(std::string("Value ") + input + " could not be converted to " + + detail::type_name()); + } + // perform safe multiplication + bool ok = detail::checked_multiply(num, it->second); + if(!ok) { + throw ValidationError(detail::to_string(num) + " multiplied by " + unit + + " factor would cause number overflow. 
Use smaller value."); + } + } else { + num = static_cast(it->second); + } + + input = detail::to_string(num); + + return {}; + }; + } + + private: + /// Check that mapping contains valid units. + /// Update mapping for CASE_INSENSITIVE mode. + template static void validate_mapping(std::map &mapping, Options opts) { + for(auto &kv : mapping) { + if(kv.first.empty()) { + throw ValidationError("Unit must not be empty."); + } + if(!detail::isalpha(kv.first)) { + throw ValidationError("Unit must contain only letters."); + } + } + + // make all units lowercase if CASE_INSENSITIVE + if(opts & CASE_INSENSITIVE) { + std::map lower_mapping; + for(auto &kv : mapping) { + auto s = detail::to_lower(kv.first); + if(lower_mapping.count(s)) { + throw ValidationError(std::string("Several matching lowercase unit representations are found: ") + + s); + } + lower_mapping[detail::to_lower(kv.first)] = kv.second; + } + mapping = std::move(lower_mapping); + } + } + + /// Generate description like this: NUMBER [UNIT] + template static std::string generate_description(const std::string &name, Options opts) { + std::stringstream out; + out << detail::type_name() << ' '; + if(opts & UNIT_REQUIRED) { + out << name; + } else { + out << '[' << name << ']'; + } + return out.str(); + } +}; + +/// Converts a human-readable size string (with unit literal) to uin64_t size. +/// Example: +/// "100" => 100 +/// "1 b" => 100 +/// "10Kb" => 10240 // you can configure this to be interpreted as kilobyte (*1000) or kibibyte (*1024) +/// "10 KB" => 10240 +/// "10 kb" => 10240 +/// "10 kib" => 10240 // *i, *ib are always interpreted as *bibyte (*1024) +/// "10kb" => 10240 +/// "2 MB" => 2097152 +/// "2 EiB" => 2^61 // Units up to exibyte are supported +class AsSizeValue : public AsNumberWithUnit { + public: + using result_t = std::uint64_t; + + /// If kb_is_1000 is true, + /// interpret 'kb', 'k' as 1000 and 'kib', 'ki' as 1024 + /// (same applies to higher order units as well). 
+ /// Otherwise, interpret all literals as factors of 1024. + /// The first option is formally correct, but + /// the second interpretation is more wide-spread + /// (see https://en.wikipedia.org/wiki/Binary_prefix). + explicit AsSizeValue(bool kb_is_1000) : AsNumberWithUnit(get_mapping(kb_is_1000)) { + if(kb_is_1000) { + description("SIZE [b, kb(=1000b), kib(=1024b), ...]"); + } else { + description("SIZE [b, kb(=1024b), ...]"); + } + } + + private: + /// Get mapping + static std::map init_mapping(bool kb_is_1000) { + std::map m; + result_t k_factor = kb_is_1000 ? 1000 : 1024; + result_t ki_factor = 1024; + result_t k = 1; + result_t ki = 1; + m["b"] = 1; + for(std::string p : {"k", "m", "g", "t", "p", "e"}) { + k *= k_factor; + ki *= ki_factor; + m[p] = k; + m[p + "b"] = k; + m[p + "i"] = ki; + m[p + "ib"] = ki; + } + return m; + } + + /// Cache calculated mapping + static std::map get_mapping(bool kb_is_1000) { + if(kb_is_1000) { + static auto m = init_mapping(true); + return m; + } else { + static auto m = init_mapping(false); + return m; + } + } +}; + +namespace detail { +/// Split a string into a program name and command line arguments +/// the string is assumed to contain a file name followed by other arguments +/// the return value contains is a pair with the first argument containing the program name and the second +/// everything else. 
+inline std::pair split_program_name(std::string commandline) { + // try to determine the programName + std::pair vals; + trim(commandline); + auto esp = commandline.find_first_of(' ', 1); + while(detail::check_path(commandline.substr(0, esp).c_str()) != path_type::file) { + esp = commandline.find_first_of(' ', esp + 1); + if(esp == std::string::npos) { + // if we have reached the end and haven't found a valid file just assume the first argument is the + // program name + if(commandline[0] == '"' || commandline[0] == '\'' || commandline[0] == '`') { + bool embeddedQuote = false; + auto keyChar = commandline[0]; + auto end = commandline.find_first_of(keyChar, 1); + while((end != std::string::npos) && (commandline[end - 1] == '\\')) { // deal with escaped quotes + end = commandline.find_first_of(keyChar, end + 1); + embeddedQuote = true; + } + if(end != std::string::npos) { + vals.first = commandline.substr(1, end - 1); + esp = end + 1; + if(embeddedQuote) { + vals.first = find_and_replace(vals.first, std::string("\\") + keyChar, std::string(1, keyChar)); + } + } else { + esp = commandline.find_first_of(' ', 1); + } + } else { + esp = commandline.find_first_of(' ', 1); + } + + break; + } + } + if(vals.first.empty()) { + vals.first = commandline.substr(0, esp); + rtrim(vals.first); + } + + // strip the program name + vals.second = (esp != std::string::npos) ? commandline.substr(esp + 1) : std::string{}; + ltrim(vals.second); + return vals; +} + +} // namespace detail +/// @} + + + + +class Option; +class App; + +/// This enum signifies the type of help requested +/// +/// This is passed in by App; all user classes must accept this as +/// the second argument. + +enum class AppFormatMode { + Normal, ///< The normal, detailed help + All, ///< A fully expanded help + Sub, ///< Used when printed as part of expanded subcommand +}; + +/// This is the minimum requirements to run a formatter. 
+/// +/// A user can subclass this is if they do not care at all +/// about the structure in CLI::Formatter. +class FormatterBase { + protected: + /// @name Options + ///@{ + + /// The width of the first column + std::size_t column_width_{30}; + + /// @brief The required help printout labels (user changeable) + /// Values are Needs, Excludes, etc. + std::map labels_{}; + + ///@} + /// @name Basic + ///@{ + + public: + FormatterBase() = default; + FormatterBase(const FormatterBase &) = default; + FormatterBase(FormatterBase &&) = default; + + /// Adding a destructor in this form to work around bug in GCC 4.7 + virtual ~FormatterBase() noexcept {} // NOLINT(modernize-use-equals-default) + + /// This is the key method that puts together help + virtual std::string make_help(const App *, std::string, AppFormatMode) const = 0; + + ///@} + /// @name Setters + ///@{ + + /// Set the "REQUIRED" label + void label(std::string key, std::string val) { labels_[key] = val; } + + /// Set the column width + void column_width(std::size_t val) { column_width_ = val; } + + ///@} + /// @name Getters + ///@{ + + /// Get the current value of a name (REQUIRED, etc.) 
+ std::string get_label(std::string key) const { + if(labels_.find(key) == labels_.end()) + return key; + else + return labels_.at(key); + } + + /// Get the current column width + std::size_t get_column_width() const { return column_width_; } + + ///@} +}; + +/// This is a specialty override for lambda functions +class FormatterLambda final : public FormatterBase { + using funct_t = std::function; + + /// The lambda to hold and run + funct_t lambda_; + + public: + /// Create a FormatterLambda with a lambda function + explicit FormatterLambda(funct_t funct) : lambda_(std::move(funct)) {} + + /// Adding a destructor (mostly to make GCC 4.7 happy) + ~FormatterLambda() noexcept override {} // NOLINT(modernize-use-equals-default) + + /// This will simply call the lambda function + std::string make_help(const App *app, std::string name, AppFormatMode mode) const override { + return lambda_(app, name, mode); + } +}; + +/// This is the default Formatter for CLI11. It pretty prints help output, and is broken into quite a few +/// overridable methods, to be highly customizable with minimal effort. 
+class Formatter : public FormatterBase { + public: + Formatter() = default; + Formatter(const Formatter &) = default; + Formatter(Formatter &&) = default; + + /// @name Overridables + ///@{ + + /// This prints out a group of options with title + /// + virtual std::string make_group(std::string group, bool is_positional, std::vector opts) const; + + /// This prints out just the positionals "group" + virtual std::string make_positionals(const App *app) const; + + /// This prints out all the groups of options + std::string make_groups(const App *app, AppFormatMode mode) const; + + /// This prints out all the subcommands + virtual std::string make_subcommands(const App *app, AppFormatMode mode) const; + + /// This prints out a subcommand + virtual std::string make_subcommand(const App *sub) const; + + /// This prints out a subcommand in help-all + virtual std::string make_expanded(const App *sub) const; + + /// This prints out all the groups of options + virtual std::string make_footer(const App *app) const; + + /// This displays the description line + virtual std::string make_description(const App *app) const; + + /// This displays the usage line + virtual std::string make_usage(const App *app, std::string name) const; + + /// This puts everything together + std::string make_help(const App * /*app*/, std::string, AppFormatMode) const override; + + ///@} + /// @name Options + ///@{ + + /// This prints out an option help line, either positional or optional form + virtual std::string make_option(const Option *opt, bool is_positional) const { + std::stringstream out; + detail::format_help( + out, make_option_name(opt, is_positional) + make_option_opts(opt), make_option_desc(opt), column_width_); + return out.str(); + } + + /// @brief This is the name part of an option, Default: left column + virtual std::string make_option_name(const Option *, bool) const; + + /// @brief This is the options part of the name, Default: combined into left column + virtual std::string 
make_option_opts(const Option *) const; + + /// @brief This is the description. Default: Right column, on new line if left column too large + virtual std::string make_option_desc(const Option *) const; + + /// @brief This is used to print the name on the USAGE line + virtual std::string make_option_usage(const Option *opt) const; + + ///@} +}; + + + + +using results_t = std::vector; +/// callback function definition +using callback_t = std::function; + +class Option; +class App; + +using Option_p = std::unique_ptr