diff --git a/aux/inc/WireCellAux/DftTools.h b/aux/inc/WireCellAux/DftTools.h new file mode 100644 index 000000000..9530bb3ef --- /dev/null +++ b/aux/inc/WireCellAux/DftTools.h @@ -0,0 +1,130 @@ +/** + This provides std::vector and Eigen::Array typed interface to an + IDFT. + */ + +#ifndef WIRECELL_AUX_DFTTOOLS +#define WIRECELL_AUX_DFTTOOLS + +#include "WireCellIface/IDFT.h" +#include +#include + +namespace WireCell::Aux { + + using complex_t = IDFT::complex_t; + + // std::vector based functions + + using real_vector_t = std::vector; + using complex_vector_t = std::vector; + + // 1D with vectors + + // Perform forward c2c transform on vector. + inline complex_vector_t fwd(const IDFT::pointer& dft, const complex_vector_t& seq) + { + complex_vector_t ret(seq.size()); + dft->fwd1d(seq.data(), ret.data(), ret.size()); + return ret; + } + + // Perform forward r2c transform on vector. + inline complex_vector_t fwd_r2c(const IDFT::pointer& dft, const real_vector_t& vec) + { + complex_vector_t cvec(vec.size()); + std::transform(vec.begin(), vec.end(), cvec.begin(), + [](float re) { return Aux::complex_t(re,0.0); } ); + return fwd(dft, cvec); + } + + // Perform inverse c2c transform on vector. + inline complex_vector_t inv(const IDFT::pointer& dft, const complex_vector_t& spec) + { + complex_vector_t ret(spec.size()); + dft->inv1d(spec.data(), ret.data(), ret.size()); + return ret; + } + + // Perform inverse c2r transform on vector. + inline real_vector_t inv_c2r(const IDFT::pointer& dft, const complex_vector_t& spec) + { + auto cvec = inv(dft, spec); + real_vector_t rvec(cvec.size()); + std::transform(cvec.begin(), cvec.end(), rvec.begin(), + [](const Aux::complex_t& c) { return std::real(c); }); + return rvec; + } + + // 1D high-level interface + + /// Convovle in1 and in2. Returned vecgtor has size sum of sizes + /// of in1 and in2 less one element in order to assure no periodic + /// aliasing. Caller need not (should not) pad either input. + /// Caller is free to truncate result as required. + real_vector_t convolve(const IDFT::pointer& dft, + const real_vector_t& in1, + const real_vector_t& in2); + + + /// Replace response res1 in meas with response res2. + /// + /// This will compute the FFT of all three, in frequency space will form: + /// + /// meas * resp2 / resp1 + /// + /// apply the inverse FFT and return its real part. + /// + /// The output vector is long enough to assure no periodic + /// aliasing. In general, caller should NOT pre-pad any input. + /// Any subsequent truncation of result is up to caller. + real_vector_t replace(const IDFT::pointer& dft, + const real_vector_t& meas, + const real_vector_t& res1, + const real_vector_t& res2); + + + // Eigen array based functions + + /// 2D array types. Note, use Array::cast() if you + /// need to convert rom real or arr.real() to convert to real. + using real_array_t = Eigen::ArrayXXf; + using complex_array_t = Eigen::ArrayXXcf; + + // 2D with Eigen arrays. Use eg arr.cast() to provde + // from real or arr.real()() to convert result to real. + + // Transform both dimesions. + complex_array_t fwd(const IDFT::pointer& dft, const complex_array_t& arr); + complex_array_t inv(const IDFT::pointer& dft, const complex_array_t& arr); + + // As above but internally convert input or output. These are + // just syntactic sugar hiding a .cast() or a .real() + // call. + complex_array_t fwd_r2c(const IDFT::pointer& dft, const real_array_t& arr); + real_array_t inv_c2r(const IDFT::pointer& dft, const complex_array_t& arr); + + // Transform a 2D array along one axis. + // + // The axis identifies the logical array "dimension" over which + // the transform is applied. For example, axis=1 means the + // transforms are applied along columns (ie, on a per-row basis). + // Note: this is the same convention as held by numpy.fft. + // + // The axis is interpreted in the "logical" sense Eigen arrays + // indexed as array(irow, icol). Ie, the dimension traversing + // rows is axis 0 and the dimension traversing columns is axis 1. + // Note: internal storage order of an Eigen array may differ from + // the logical order and indeed that of the array template type + // order. Neither is pertinent in setting the axis. + complex_array_t fwd(const IDFT::pointer& dft, const complex_array_t& arr, int axis); + complex_array_t inv(const IDFT::pointer& dft, const complex_array_t& arr, int axis); + + + // Fixme: possible additions + // - superposition of 2 reals for 2x speedup + // - r2c / c2r for 1b + +} + +#endif diff --git a/aux/inc/WireCellAux/FftwDFT.h b/aux/inc/WireCellAux/FftwDFT.h new file mode 100644 index 000000000..f26b5bcf9 --- /dev/null +++ b/aux/inc/WireCellAux/FftwDFT.h @@ -0,0 +1,57 @@ +#ifndef WIRECELLAUX_FFTWDFT +#define WIRECELLAUX_FFTWDFT + +#include "WireCellIface/IDFT.h" + +namespace WireCell::Aux { + + /** + The FftwDFT component provides IDFT based on FFTW3. + + All instances share a common thread-safe plan cache. There is + no benefit to using more than one instance in a process. + + See IDFT.h for important comments. + */ + class FftwDFT : public IDFT { + public: + + FftwDFT(); + virtual ~FftwDFT(); + + // 1d + + virtual + void fwd1d(const complex_t* in, complex_t* out, + int size) const; + + virtual + void inv1d(const complex_t* in, complex_t* out, + int size) const; + + virtual + void fwd1b(const complex_t* in, complex_t* out, + int nrows, int ncols, int axis) const; + + virtual + void inv1b(const complex_t* in, complex_t* out, + int nrows, int ncols, int axis) const; + + virtual + void fwd2d(const complex_t* in, complex_t* out, + int nrows, int ncols) const; + virtual + void inv2d(const complex_t* in, complex_t* out, + int nrows, int ncols) const; + + virtual + void transpose(const scalar_t* in, scalar_t* out, + int nrows, int ncols) const; + virtual + void transpose(const complex_t* in, complex_t* out, + int nrows, int ncols) const; + + }; +} + +#endif diff --git a/aux/inc/WireCellAux/Semaphore.h b/aux/inc/WireCellAux/Semaphore.h new file mode 100644 index 000000000..1385bf51d --- /dev/null +++ b/aux/inc/WireCellAux/Semaphore.h @@ -0,0 +1,34 @@ +/** Implement a semaphore component interace. */ + +#ifndef WIRECELLAUX_SEMAPHORE +#define WIRECELLAUX_SEMAPHORE + +#include "WireCellIface/IConfigurable.h" +#include "WireCellIface/ISemaphore.h" +#include "WireCellUtil/Semaphore.h" + + +namespace WireCell::Aux { + class Semaphore : public ISemaphore, + public IConfigurable + { + public: + Semaphore(); + virtual ~Semaphore(); + + // IConfigurable interface + virtual void configure(const WireCell::Configuration& config); + virtual WireCell::Configuration default_configuration() const; + + // ISemaphore + virtual void acquire() const; + virtual void release() const; + + private: + + mutable FastSemaphore m_sem; + + }; +} // namespace WireCell::Pytorch + +#endif // WIRECELLPYTORCH_TORCHSERVICE diff --git a/aux/inc/WireCellAux/SimpleTensor.h b/aux/inc/WireCellAux/SimpleTensor.h index 56c3ab905..7b267bdbe 100644 --- a/aux/inc/WireCellAux/SimpleTensor.h +++ b/aux/inc/WireCellAux/SimpleTensor.h @@ -2,7 +2,9 @@ #define WIRECELL_AUX_SIMPLETENSOR #include "WireCellIface/ITensor.h" + #include +#include namespace WireCell { @@ -13,14 +15,26 @@ namespace WireCell { public: typedef ElementType element_t; - SimpleTensor(const shape_t& shape) + // Create simple tensor, allocating space for data. If + // data given it must have at least as many elements as + // implied by shape and that span will be copied into + // allocated memory. + SimpleTensor(const shape_t& shape, + const element_t* data=nullptr, + const Configuration& md = Configuration()) { size_t nbytes = element_size(); - for (const auto& s : shape) { + m_shape = shape; + for (const auto& s : m_shape) { nbytes *= s; } - m_store.resize(nbytes); - m_shape = shape; + if (data) { + const std::byte* bytes = reinterpret_cast(data); + m_store.assign(bytes, bytes+nbytes); + } + else { + m_store.resize(nbytes); + } } virtual ~SimpleTensor() {} diff --git a/aux/inc/WireCellAux/TensorTools.h b/aux/inc/WireCellAux/TensorTools.h new file mode 100644 index 000000000..6765b457e --- /dev/null +++ b/aux/inc/WireCellAux/TensorTools.h @@ -0,0 +1,77 @@ +#ifndef WIRECELL_AUX_TENSORTOOLS +#define WIRECELL_AUX_TENSORTOOLS + +#include "WireCellIface/ITensor.h" +#include "WireCellIface/IDFT.h" +#include "WireCellUtil/Exceptions.h" + +#include +#include + +namespace WireCell::Aux { + + bool is_row_major(const ITensor::pointer& ten) { + if (ten->order().empty() or ten->order()[0] == 1) { + return true; + } + return false; + } + + template + bool is_type(const ITensor::pointer& ten) { + return (ten->element_type() == typeid(scalar_t)); + } + + + // Extract the underlying data array from the tensor as a vector. + // Caution: this ignores storage order hints and 1D or 2D will be + // flattened assuming C-ordering, aka row-major (if 2D). It + // throws ValueError on type mismatch. + template + std::vector asvec(const ITensor::pointer& ten) + { + if (ten->element_type() != typeid(element_type)) { + THROW(ValueError() << errmsg{"element type mismatch"}); + } + const element_type* data = (const element_type*)ten->data(); + const size_t nelems = ten->size()/sizeof(element_type); + return std::vector(data, data+nelems); + } + + // Extract the tensor data as an Eigen array. + template + Eigen::Array // this default is column-wise + asarray(const ITensor::pointer& tens) + { + if (tens->element_type() != typeid(element_type)) { + THROW(ValueError() << errmsg{"element type mismatch"}); + } + using ROWM = Eigen::Array; + using COLM = Eigen::Array; + + auto shape = tens->shape(); + int nrows, ncols; + if (shape.size() == 1) { + nrows = 1; + ncols = shape[0]; + } + else { + nrows = shape[0]; + ncols = shape[1]; + } + + // Eigen::Map is a non-const view of data but a copy happens + // on return. We need to temporarily break const correctness. + const element_type* cdata = reinterpret_cast(tens->data()); + element_type* mdata = const_cast(cdata); + + if (is_row_major(tens)) { + return Eigen::Map(mdata, nrows, ncols); + } + // column-major + return Eigen::Map(mdata, nrows, ncols); + } + +} + +#endif diff --git a/aux/src/DftTools.cxx b/aux/src/DftTools.cxx new file mode 100644 index 000000000..d9ebadac3 --- /dev/null +++ b/aux/src/DftTools.cxx @@ -0,0 +1,180 @@ +#include "WireCellAux/DftTools.h" +#include + +#include // debugging + + +using namespace WireCell; +using namespace WireCell::Aux; + +/* + Big fat warning to future me: Passing by reference means the input + array may carry the .IsRowMajor optimization for implementing + transpose(). An extra copy would remove that complication but this + interface tries to keep it. + */ + +using ROWM = Eigen::Array; +using COLM = Eigen::Array; + +template +Aux::complex_array_t doit(const Aux::complex_array_t& arr, trans func) +{ + const Aux::complex_t* in_data = arr.data(); + Aux::complex_vector_t out_vec(arr.rows()*arr.cols()); + + // std::cerr << "dft::doit: (" << arr.rows() << "," << arr.cols() << ") IsRowMajor:" << arr.IsRowMajor << std::endl; + + if (arr.IsRowMajor) { + func(in_data, out_vec.data(), arr.cols(), arr.rows()); + return Eigen::Map(out_vec.data(), arr.rows(), arr.cols()); + } + + func(in_data, out_vec.data(), arr.rows(), arr.cols()); + return Eigen::Map(out_vec.data(), arr.rows(), arr.cols()); +} + +Aux::complex_array_t Aux::fwd(const IDFT::pointer& dft, const Aux::complex_array_t& arr) +{ + return doit(arr, [&](const complex_t* in_data, + complex_t* out_data, + int nrows, int ncols) { + dft->fwd2d(in_data, out_data, nrows, ncols); + }); +} + +Aux::complex_array_t Aux::inv(const IDFT::pointer& dft, const Aux::complex_array_t& arr) +{ + return doit(arr, [&](const complex_t* in_data, + complex_t* out_data, + int nrows, int ncols) { + dft->inv2d(in_data, out_data, nrows, ncols); + }); +} + +// template +// Aux::complex_array_t doit1b(const Aux::complex_array_t& arr, int axis, trans func) +// { +// // We must provide a flat array with storage order such with +// // logical axis-major ordering. +// const Aux::complex_t* in_data = arr.data(); +// const int nrows = arr.rows(); // "logical" +// const int ncols = arr.cols(); // shape + +// // If storage order matches "axis-major" +// if ( (axis == 1 and arr.IsRowMajor) +// or +// (axis == 0 and not arr.IsRowMajor) ) { +// Aux::complex_vector_t out_vec(nrows*ncols); +// func(in_data, out_vec.data(), ncols, nrows); +// if (arr.IsRowMajor) { +// // note, returning makes a copy and will perform an actual +// // storage order transpose. +// return Eigen::Map(out_vec.data(), nrows, ncols); +// } +// return Eigen::Map(out_vec.data(), nrows, ncols); +// } + +// // Either we have row-major and want column-major storage order or +// // vice versa. + +// // Here, we must copy and not use "auto" to get actual storage +// // order transpose and avoid the IsRowMajor flip optimization. +// COLM flipped = arr.transpose(); +// COLM got = doit1b(flipped, (axis+1)%2, func); +// return got.transpose(); +// } + +// Implementation notes for fwd()/inv(): +// +// - We make an initial copy to get rid of any potential IsRowMajor +// optimization/confusion over storage order. This suffers a copy +// but we need to allocate return anyways. +// +// - We then have column-wise storage order but IDFT assumes row-wise +// - so we reverse (nrows, ncols) and meaning of axis. + +Aux::complex_array_t Aux::fwd(const IDFT::pointer& dft, + const Aux::complex_array_t& arr, + int axis) +{ + Aux::complex_array_t ret = arr; + dft->fwd1b(ret.data(), ret.data(), ret.cols(), ret.rows(), !axis); + return ret; +} + +Aux::complex_array_t Aux::inv(const IDFT::pointer& dft, + const Aux::complex_array_t& arr, + int axis) +{ + Aux::complex_array_t ret = arr; + dft->inv1b(ret.data(), ret.data(), ret.cols(), ret.rows(), !axis); + return ret; +} + +Aux::complex_array_t Aux::fwd_r2c(const IDFT::pointer& dft, + const real_array_t& arr) +{ + return Aux::fwd(dft, arr.cast()); +} +Aux::real_array_t Aux::inv_c2r(const IDFT::pointer& dft, + const complex_array_t& arr) +{ + return Aux::inv(dft, arr).real(); +} + + +Aux::real_vector_t Aux::convolve(const IDFT::pointer& dft, + const Aux::real_vector_t& in1, + const Aux::real_vector_t& in2) +{ + size_t size = in1.size() + in2.size() - 1; + Aux::complex_vector_t cin1(size,0), cin2(size,0); + + std::transform(in1.begin(), in1.end(), cin1.begin(), + [](float re) { return Aux::complex_t(re,0.0); } ); + std::transform(in2.begin(), in2.end(), cin2.begin(), + [](float re) { return Aux::complex_t(re,0.0); } ); + + dft->fwd1d(cin1.data(), cin1.data(), size); + dft->fwd1d(cin2.data(), cin2.data(), size); + + for (size_t ind=0; indfwd1d(cmeas.data(), cmeas.data(), size); + dft->fwd1d(cres1.data(), cres1.data(), size); + dft->fwd1d(cres2.data(), cres2.data(), size); + + for (size_t ind=0; ind +#include +#include +#include + +WIRECELL_FACTORY(FftwDFT, WireCell::Aux::FftwDFT, WireCell::IDFT) + + +using namespace WireCell; + +using plan_key_t = int64_t; +using plan_type = fftwf_plan; +using plan_map_t = std::unordered_map; +using plan_val_t = fftwf_complex; + +// Make a key by which a plan is known. dir should be FFTW_FORWARD or +// FFTW_BACKWARD and "axis" is -1 for all or in {0,1} for one of 2D. +// For 1D, use the default axis=-1. +// +// Imp note: The key is slightly over-specified as we keep one +// independent cache for each of the six methods. The "dir" is +// thus redundant. +static +plan_key_t make_key(const void * src, void * dst, int nrows, int ncols, int dir, int axis=-1) +{ + ++axis; // need three positive values, default is both axis + bool inverse = dir == FFTW_BACKWARD; + bool inplace = (dst==src); + bool aligned = ( (reinterpret_cast(src)&15) | (reinterpret_cast(dst)&15) ) == 0; + int64_t key = ( ( (((int64_t)nrows) << 32)| (ncols<<5 ) | (axis<<3) | (inverse<<2) | (inplace<<1) | aligned ) << 1 ) + 1; + return key; +} + +// Look up a plan by key or return NULL +static +plan_type get_plan(std::shared_mutex& mutex, plan_map_t& plans, plan_key_t key) +{ + std::shared_lock lock(mutex); + auto it = plans.find(key); + if (it == plans.end()) { + return NULL; + } + return it->second; +} + + +// #include // debugging + +using planner_function = std::function; + +// This wraps plan lookup, possible plan creation and subsequent plan +// execution so that we get thread-safe plan caching. +template +void doit(std::shared_mutex& mutex, plan_map_t& plans, plan_key_t key, + ValueType* src, ValueType* dst, + planner_function make_plan, + std::function exec_plan) +{ + auto plan = get_plan(mutex, plans, key); + if (!plan) { + std::unique_lock lock(mutex); + // Check again in case another thread snakes us. + auto it = plans.find(key); + if (it == plans.end()) { + //std::cerr << "make plan for " << key << std::endl; + plan = make_plan(); + plans[key] = plan; + } + else { + plan = it->second; + } + } + //fftwf_execute_dft(plan, src, dst); + exec_plan(plan, src, dst); +} + + +static +plan_val_t* pval_cast( const IDFT::complex_t * p) +{ + return const_cast( reinterpret_cast(p) ); +} + + +void Aux::FftwDFT::fwd1d(const complex_t* in, complex_t* out, int ncols) const +{ + static std::shared_mutex mutex; + static plan_map_t plans; + static const int dir = FFTW_FORWARD; + auto src = pval_cast(in); + auto dst = pval_cast(out); + auto key = make_key(src, dst, 1, ncols, dir); + doit(mutex, plans, key, src, dst, [&]( ) { + return fftwf_plan_dft_1d(ncols, src, dst, dir, FFTW_ESTIMATE|FFTW_PRESERVE_INPUT); + }, fftwf_execute_dft); +} +void Aux::FftwDFT::inv1d(const complex_t* in, complex_t* out, int ncols) const +{ + static std::shared_mutex mutex; + static plan_map_t plans; + static const int dir = FFTW_BACKWARD; + auto src = pval_cast(in); + auto dst = pval_cast(out); + auto key = make_key(src, dst, 1, ncols, dir); + + doit(mutex, plans, key, src, dst, [&]( ) { + return fftwf_plan_dft_1d(ncols, src, dst, dir, FFTW_ESTIMATE|FFTW_PRESERVE_INPUT); + }, fftwf_execute_dft); + + // Apply 1/n normalization + for (int ind=0; ind(mutex, plans, key, src, dst, [&]( ) { + return plan_1b(src, dst, nrows, ncols, dir, axis); + }, fftwf_execute_dft); +} + + +void Aux::FftwDFT::inv1b(const complex_t* in, complex_t* out, int nrows, int ncols, int axis) const +{ + static std::shared_mutex mutex; + static plan_map_t plans; + static const int dir = FFTW_BACKWARD; + auto src = pval_cast(in); + auto dst = pval_cast(out); + auto key = make_key(src, dst, nrows, ncols, dir, axis); + + doit(mutex, plans, key, src, dst, [&]( ) { + return plan_1b(src, dst, nrows, ncols, dir, axis); + }, fftwf_execute_dft); + + // 1/n normalization + const int norm = axis ? ncols : nrows; + const int ntot = ncols*nrows; + for (int ind=0; ind(mutex, plans, key, src, dst, [&]( ) { + return fftwf_plan_dft_2d(ncols, nrows, src, dst, dir, FFTW_ESTIMATE|FFTW_PRESERVE_INPUT); + }, fftwf_execute_dft); +} + + +void Aux::FftwDFT::inv2d(const complex_t* in, complex_t* out, int nrows, int ncols) const +{ + static std::shared_mutex mutex; + static plan_map_t plans; + static const int dir = FFTW_BACKWARD; + auto src = pval_cast(in); + auto dst = pval_cast(out); + auto key = make_key(src, dst, nrows, ncols, dir); + doit(mutex, plans, key, src, dst, [&]( ) { + return fftwf_plan_dft_2d(ncols, nrows, src, dst, dir, FFTW_ESTIMATE|FFTW_PRESERVE_INPUT); + }, fftwf_execute_dft); + + // reverse normalization + const int ntot = ncols*nrows; + for (int ind=0; ind(mutex, plans, key, src, dst, [&]( ) { + return transpose_plan_complex(src, dst, nrows, ncols); + }, fftwf_execute_dft); +} + +static +plan_type transpose_plan_real(float *in, float *out, int rows, int cols) +{ + const unsigned flags = FFTW_ESTIMATE; /* other flags are possible */ + fftw_iodim howmany_dims[2]; + + howmany_dims[0].n = rows; + howmany_dims[0].is = cols; + howmany_dims[0].os = 1; + + howmany_dims[1].n = cols; + howmany_dims[1].is = 1; + howmany_dims[1].os = rows; + + return fftwf_plan_guru_r2r(/*rank=*/ 0, /*dims=*/ NULL, + /*howmany_rank=*/ 2, howmany_dims, + in, out, /*kind=*/ NULL, flags); +} +void Aux::FftwDFT::transpose(const scalar_t* in, scalar_t* out, + int nrows, int ncols) const +{ + static std::shared_mutex mutex; + static plan_map_t plans; + static const int dir = 0; + auto src = const_cast(in); + auto dst = out; + auto key = make_key(src, dst, nrows, ncols, dir); + doit(mutex, plans, key, src, dst, [&]( ) { + return transpose_plan_real(src, dst, nrows, ncols); + }, fftwf_execute_r2r); +} + +Aux::FftwDFT::FftwDFT() +{ +} +Aux::FftwDFT::~FftwDFT() +{ +} + diff --git a/aux/src/Semaphore.cxx b/aux/src/Semaphore.cxx new file mode 100644 index 000000000..841debbb7 --- /dev/null +++ b/aux/src/Semaphore.cxx @@ -0,0 +1,51 @@ +#include "WireCellAux/Semaphore.h" + +#include "WireCellUtil/NamedFactory.h" +#include "WireCellUtil/Semaphore.h" + +WIRECELL_FACTORY(Semaphore, + WireCell::Aux::Semaphore, + WireCell::ISemaphore, + WireCell::IConfigurable) + +using namespace WireCell; + +Aux::Semaphore::Semaphore() + : m_sem(1) +{ +} +Aux::Semaphore::~Semaphore() +{ +} + +WireCell::Configuration Aux::Semaphore::default_configuration() const +{ + Configuration cfg; + + // The maximum allowed number concurrent calls to forward(). A + // value of unity means all calls will be serialized. When made + // smaller than the number of threads, the difference gives the + // number of threads that may block on the semaphore. + cfg["concurrency"] = 1; + + return cfg; +} + +void Aux::Semaphore::configure(const WireCell::Configuration& cfg) +{ + auto count = get(cfg, "concurrency", 1); + if (count < 1 ) { + count = 1; + } + m_sem.set_count(count); +} + +void Aux::Semaphore::acquire() const +{ + m_sem.acquire(); +} + +void Aux::Semaphore::release() const +{ + m_sem.release(); +} diff --git a/aux/src/TaggedTensorSetFrame.cxx b/aux/src/TaggedTensorSetFrame.cxx index 0a4c4db80..4cbbe7e23 100644 --- a/aux/src/TaggedTensorSetFrame.cxx +++ b/aux/src/TaggedTensorSetFrame.cxx @@ -4,6 +4,8 @@ #include "WireCellIface/SimpleFrame.h" #include "WireCellUtil/NamedFactory.h" +#include + WIRECELL_FACTORY(TaggedTensorSetFrame, WireCell::Aux::TaggedTensorSetFrame, WireCell::ITensorSetFrame, WireCell::IConfigurable) diff --git a/aux/test/aux_test_dft_helpers.h b/aux/test/aux_test_dft_helpers.h new file mode 100644 index 000000000..33d9564ec --- /dev/null +++ b/aux/test/aux_test_dft_helpers.h @@ -0,0 +1,249 @@ +// This is only for sharing some common code betweeen different +// aux/test/*.cxx tests. Not for "real" use. + +#include "WireCellUtil/NamedFactory.h" +#include "WireCellUtil/PluginManager.h" +#include "WireCellUtil/Exceptions.h" +#include "WireCellUtil/Persist.h" + +#include "WireCellIface/IConfigurable.h" +#include "WireCellIface/IDFT.h" + +#include + +#include // std::clock +#include + +#include +#include +#include +#include + +// note: likely will move in the future. +#include "custard/nlohmann/json.hpp" + +namespace WireCell::Aux::Test { + + using object_t = nlohmann::json; + + // probably move this to util + struct Stopwatch { + using clock = std::chrono::high_resolution_clock; + using time_point = clock::time_point; + using function_type = std::function; + + std::clock_t c_ini = std::clock(); + time_point t_ini = clock::now(); + + object_t results; + + Stopwatch(const object_t& first = object_t{}) { + (*this)([](){}, first); + } + + // Run the func, add timing info to a "stopwatch" attribute of + // data and save data to results. A pair of clock objects are + // saved, "clock" (std::clock) and "time" (std::chrono). Each + // have "start" and "elapsed" which are the number of + // nanoseconds from creation of stopwatch and for just this + // job, respectively. + void operator()(function_type func, object_t data = object_t{}) + { + auto c_now =std::clock(); + auto t_now = clock::now(); + func(); + auto c_fin =std::clock(); + auto t_fin = clock::now(); + + double dc_now = 1e9 * (c_now - c_ini) / ((double) CLOCKS_PER_SEC); + double dc_fin = 1e9 * (c_fin - c_now) / ((double) CLOCKS_PER_SEC); + double dt_now = std::chrono::duration_cast(t_now - t_ini).count(); + double dt_fin = std::chrono::duration_cast(t_fin - t_now).count(); + + data["stopwatch"]["clock"]["start"] = dc_now; + data["stopwatch"]["clock"]["elapsed"] = dc_fin; + data["stopwatch"]["time"]["start"] = dt_now; + data["stopwatch"]["time"]["elapsed"] = dt_fin; + + results.push_back(data); + } + + void save(const std::string& jsonfile) { + std::ofstream fp(jsonfile.c_str()); + fp << results.dump(4) << std::endl; + } + + + }; + + // fixme: add support for config + IDFT::pointer make_dft(const std::string& tn="FftwDFT", + const std::string& pi="WireCellAux", + Configuration cfg = Configuration()) + { + std::cerr << "Making DFT " << tn << " from plugin " << pi << std::endl; + + PluginManager& pm = PluginManager::instance(); + pm.add(pi); + + // create first + auto idft = Factory::lookup_tn(tn); + assert(idft); + // configure before use if configurable + auto icfg = Factory::find_maybe_tn(tn); + if (icfg) { + auto def_cfg = icfg->default_configuration(); + def_cfg = update(def_cfg, cfg); + icfg->configure(def_cfg); + } + return idft; + } + struct DftArgs { + std::string tn{"FftwDFT"}; + std::string pi{"WireCellAux"}; + std::string cfg_name{""}; + std::string output{""}; + std::vector positional; + Configuration cfg; + }; + + // remove command name from main()'s argc/argv[0] + int make_dft_args(DftArgs& args, int argc, char** argv) + { + // compilation times: po:19s, cli11:26s + namespace po = boost::program_options; + + po::options_description desc("Options"); + desc.add_options()("help,h", "IDFT tests [options] [arguments]") + ("output,o", po::value< std::string >(), "output file") + ("plugin,p", po::value< std::string >(), "plugin holding a IDFT") + ("typename,t", po::value< std::string >(), "type[:name] of the IDFT to use") + ("config,c", po::value< std::string >(), "configuration file") + ("args", po::value< std::vector >(), "positional arguments") + ; + po::positional_options_description pos_desc; + pos_desc.add("args", -1); + + auto parsed = po::command_line_parser(argc, argv) + .options(desc) + .positional(pos_desc) + .run(); + po::variables_map opts; + po::store(parsed, opts); + po::notify(opts); + + if (opts.count("help")) { + std::cout << desc << "\n"; + return 1; + } + + if (opts.count("output")) { + args.output = opts["output"].as< std::string> (); + } + if (opts.count("plugin") ) { + args.pi = opts["plugin"].as< std::string >(); + } + if (opts.count("typename")) { + args.tn = opts["typename"].as< std::string> (); + } + if (opts.count("args")) { + args.positional = opts["args"].as< std::vector >(); + } + if (opts.count("config")) { + args.cfg_name = opts["config"].as< std::string> (); + auto cfg = Persist::load(args.cfg_name); + + if (cfg.isArray()) { + for (auto one : cfg) { + std::string tn = get(one, "type"); + std::string n = get(one, "name", ""); + if (not n.empty()) { + tn = tn + ":" + n; + } + if (tn == args.tn) { + cfg = one["data"]; + break; + } + } + } + args.cfg = cfg; + } + return 0; + } + + const double default_eps = 1e-8; + const std::complex czero = 0.0; + const std::complex cone = 1.0; + + void assert_small(double val, double eps = default_eps) { + if (val < eps) { + return; + } + std::stringstream ss; + ss << "value " << val << " >= " << eps; + std::cerr << ss.str() << std::endl; + THROW(WireCell::ValueError() << errmsg{ss.str()}); + } + + // Assert the array has only value val at index and near zero elsewhere + template + void assert_impulse_at_index(const ValueType* vec, size_t size, + size_t index=0, ValueType val = 1.0) + { + ValueType tot = 0; + for (size_t ind=0; ind + void assert_impulse_at_index(const VectorType& vec, + size_t index=0, const typename VectorType::value_type& val = 1.0) + { + assert_impulse_at_index(vec.data(), vec.size(), index, val); + } + + // Assert all values in array are near given val + template + void assert_flat_value(const ValueType* vec, size_t size, ValueType val = 1.0) + { + ValueType tot = 0; + for (size_t ind=0; ind + void assert_flat_value(const VectorType& vec, const typename VectorType::value_type& val = 1.0) + { + assert_flat_value(vec.data(), vec.size(), val); + } + + // Print eigen array + template + void dump(std::string name, const array_type& arr) + { + std::cerr << name << ":(" << arr.rows() << "," << arr.cols() << ") row-major:" << arr.IsRowMajor << "\n"; + } + + + // Like std::iota, but dummer + template + void iota(ValueType* vec, size_t size, ValueType start = 0) + { + for (size_t ind=0; ind + +#include +#include +#include +#include + +using namespace WireCell; +using namespace WireCell::Stream; +using namespace WireCell::Aux::Test; + +using scalar_t = float; +using array_xxf = Eigen::Array; +using complex_t = std::complex; +using array_xxc = Eigen::Array; + +// may hold any dtype and shape +using pig_array = pigenc::File; +using array_store = std::map; + +using dft_op = std::function; +using op_lu_t = std::map; + +using vector_xf = std::vector; +using vector_xc = std::vector; + +template +static std::vector p2v(const pig_array& pa) +{ + if (pa.header().shape().size() != 1) { + throw std::runtime_error("p2v rank mismatch"); + } + auto vec = pa.as_vec(); + if (vec.empty()) { + throw std::runtime_error("p2v type mismatch"); + } + return vec; +} +template +pig_array v2p(const std::vector& vec) +{ + std::vector data((const char*)vec.data(), + (const char*)vec.data() + sizeof(Scalar)*vec.size()); + pig_array pa; + pa.set(data, {vec.size()}); + return pa; +} + + +template +Eigen::Array p2a(const pig_array& pa) +{ + if (pa.header().shape().size() != 2) { + throw std::runtime_error("p2a rank mismatch"); + } + Eigen::Array arr; + bool ok = pigenc::eigen::load(pa, arr); + if (!ok) { + throw std::runtime_error("p2a type mismatch"); + } + return arr; +} +template +pig_array a2p(const Eigen::Array& arr) +{ + pig_array pa; + pigenc::eigen::dump(pa, arr); + return pa; +} + + + +pig_array dispatch(const IDFT::pointer& dft, const pig_array& pa, const std::string& op) +{ + // vector + + if (op == "fwd1d") + return v2p(Aux::fwd(dft, p2v(pa))); + + if (op == "inv1d") + return v2p(Aux::inv(dft, p2v(pa))); + + if (op == "fwd1d_r2c") + return v2p(Aux::fwd_r2c(dft, p2v(pa))); + + if (op == "inv1d_c2r") + return v2p(Aux::inv_c2r(dft, p2v(pa))); + + // array + + if (op == "fwd2d") + return a2p(Aux::fwd(dft, p2a(pa))); + + if (op == "inv2d") + return a2p(Aux::inv(dft, p2a(pa))); + + if (op == "fwd2d_r2c") + return a2p(Aux::fwd_r2c(dft, p2a(pa))); + + if (op == "inv2d_c2r") + return a2p(Aux::inv_c2r(dft, p2a(pa))); + + if (op == "fwd1b0") + return a2p(Aux::fwd(dft, p2a(pa), 0)); + + if (op == "fwd1b1") + return a2p(Aux::fwd(dft, p2a(pa), 1)); + + if (op == "inv1b0") + return a2p(Aux::inv(dft, p2a(pa), 0)); + + if (op == "inv1b1") + return a2p(Aux::inv(dft, p2a(pa), 1)); + + if (op == "" or op == "noop" or op == "no-op") { + return pa; + } + + throw std::runtime_error("unsupported op: " + op); +} + +int main(int argc, char* argv[]) +{ + DftArgs args; + int rc = make_dft_args(args, argc, argv); + if (rc) { return rc; } + + if (args.positional.empty()) { + std::cerr << "need at least one input file" << std::endl; + return 0; + } + if (args.output.empty()) { + std::cerr << "need output file" << std::endl; + return 0; + } + if (args.cfg.empty()) { + std::cerr << "need configuration" << std::endl; + return 0; + } + std::cerr << args.cfg << std::endl; + + auto idft = make_dft(args.tn, args.pi, args.cfg); + + array_store arrs; + + // Slurp in arrays. + for (const auto& sname : args.positional) { + boost::iostreams::filtering_istream ins; + std::cerr << "openning: "< " << dst << std::endl; + auto darr = dispatch(idft, it->second, op); + + auto siz = darr.header().array_size(); + if (siz == 0) { + std::cerr << "failed: " << op << "(" << src << ") -> " << dst << " (zero size)\n"; + continue; + } + + + auto fsiz = darr.header().file_size(); + auto npy = dst.find(".npy"); + if (npy == std::string::npos) { + dst = dst + ".npy"; + } + std::cerr << "\twrite " << dst + << " with dtype=" << darr.header().dtype() + << " shape: ("; + for (auto dim : darr.header().shape()) { + std::cerr << " " << dim; + } + std::cerr << " ) to " << args.output << std::endl; + + custard::write(outs, dst, fsiz); + if (!outs) { + std::cerr << "stream error: " << strerror(errno) << std::endl; + std::cerr << "failed to write " << dst + << "(" << fsiz << ") to " + << args.output << "\n" << one << std::endl; + continue; + } + darr.write(outs); + outs.flush(); + } + + outs.pop(); + + return 0; +} diff --git a/aux/test/check_idft_bench.cxx b/aux/test/check_idft_bench.cxx new file mode 100644 index 000000000..a31029354 --- /dev/null +++ b/aux/test/check_idft_bench.cxx @@ -0,0 +1,134 @@ +/** + A simple benchmark of IDFT for payloads relevant to WCT + */ + +#include "aux_test_dft_helpers.h" + +using namespace WireCell; +using namespace WireCell::Aux::Test; + +using benchmark_function = std::function; +using complex_t = std::complex; + + +// benchmarks span outer product of: +// - in-place / out-place +// - 1d, 1b, 2d +// - sizes: perfect powers of 2 and with larger prime factors +// - use repitition numbers to keep each test roughly same runtime + +using transform_function = std::function; + +void ignore_exception(const complex_t* in, complex_t* out, transform_function func) +{ + try { + func(in, out); + } + catch (...) { + std::cerr << "exception ignored\n"; + } +} + +const int nominal = 100'000'000; +void doit(Stopwatch& sw, const std::string& name, int nrows, int ncols, transform_function func) +{ + const int size = nrows*ncols; + const int ntimes = std::max(1, nominal / size); + std::cerr << name << ": (" << nrows << "," << ncols << ") x "< in(size), out(size); + + sw([&](){ignore_exception(in.data(), in.data(), func);}, { + {"nrows",nrows}, {"ncols",ncols}, {"func",name}, {"ntimes",1}, {"first",true}, {"in-place",true}, + }); + + sw([&](){ + for (int count=0; count oned_sizes{128, 256, 500, 512, 1000, 1024, 2000, + 2048, 3000, 4096, 6000, 8192, 9375, 9503, 9592, 9595, 9600, + 10000, 16384}; + for (auto size : oned_sizes) { + doit(sw, "fwd1d", 1, size, [&](const complex_t* in, complex_t* out) { + idft->fwd1d(in, out, size); + }); + + doit(sw, "inv1d", 1, size, [&](const complex_t* in, complex_t* out) { + idft->inv1d(in, out, size); + }); + } + + // channel count from some detectors plus powers of 2 + std::vector> twod_sizes{ + {800,6000}, {960,6000}, // protodune u/v and w 3ms + {2400, 9595}, {3456, 9595}, // uboone u/v daq size + {1024, 1024}, {2048, 2048}, {4096, 4096}, // perfect powers of 2 + }; + for (auto& [nrows,ncols] : twod_sizes) { + + doit(sw, "fwd2d", nrows, ncols, [&](const complex_t* in, complex_t* out) { + idft->fwd2d(in, out, nrows, ncols); + }); + doit(sw, "inv2d", nrows, ncols, [&](const complex_t* in, complex_t* out) { + idft->inv2d(in, out, nrows, ncols); + }); + + doit(sw, "fwd1b0", nrows, ncols, [&](const complex_t* in, complex_t* out) { + idft->fwd1b(in, out, nrows, ncols, 0); + }); + doit(sw, "inv1b0", nrows, ncols, [&](const complex_t* in, complex_t* out) { + idft->inv1b(in, out, nrows, ncols, 0); + }); + doit(sw, "fwd1b1", nrows, ncols, [&](const complex_t* in, complex_t* out) { + idft->fwd1b(in, out, nrows, ncols, 1); + }); + doit(sw, "inv1b1", nrows, ncols, [&](const complex_t* in, complex_t* out) { + idft->inv1b(in, out, nrows, ncols, 1); + }); + } + + sw.save(fname); + + return 0; +} diff --git a/aux/test/check_idft_bench.sh b/aux/test/check_idft_bench.sh new file mode 100755 index 000000000..0003c9207 --- /dev/null +++ b/aux/test/check_idft_bench.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +# A do all script for IDFT benchmark with all known IDFTs +# Note, this will almost certainly fail if a systen does not have: +# +# - wire-cell-toolkit built at least under build/ +# - wire-cell-python built and in the environment +# - run this script in-place in the source +# - host has at exactly one GPU +# - GPU has enough memory +# +# even if fails, it documents what to run + +tstdir="$(dirname $(realpath $BASH_SOURCE))" +auxdir="$(dirname $tstdir)" +topdir="$(dirname $auxdir)" +blddir="$topdir/build" +cib="$blddir/aux/check_idft_bench" + +torchcfg="$tstdir/test_idft_pytorch.jsonnet" + +set -x +wirecell-aux run-idft-bench -o idft-bench-fftw-cpu.json $cib +wirecell-aux run-idft-bench -o idft-bench-torch-cpu.json -p WireCellPytorch -t TorchDFT $cib +wirecell-aux run-idft-bench -o idft-bench-torch-gpu.json -p WireCellPytorch -t TorchDFT -c $torchcfg $cib + +wirecell-aux plot-idft-bench -o idft-bench.pdf \ + idft-bench-fftw-cpu.json \ + idft-bench-torch-cpu.json \ + idft-bench-torch-gpu.json diff --git a/aux/test/test_dfttools.cxx b/aux/test/test_dfttools.cxx new file mode 100644 index 000000000..cfa96349c --- /dev/null +++ b/aux/test/test_dfttools.cxx @@ -0,0 +1,137 @@ +#include "aux_test_dft_helpers.h" + +#include "WireCellAux/DftTools.h" +#include "WireCellAux/FftwDFT.h" +#include "WireCellUtil/Waveform.h" + +#include +#include + +using namespace WireCell; +using namespace WireCell::Aux::Test; + +using real_t = float; +using RV = std::vector; +using complex_t = std::complex; +using CV = std::vector; + +void test_1d_impulse(IDFT::pointer dft, int size = 64) +{ + RV rimp(size, 0); + rimp[0] = 1.0; + + auto cimp = Aux::fwd(dft, Waveform::complex(rimp)); + assert_flat_value(cimp.data(), cimp.size()); + + RV rimp2 = Waveform::real(Aux::inv(dft, cimp)); + assert_impulse_at_index(rimp2.data(), rimp2.size()); +} + +using FA = Eigen::Array; +using CA = Eigen::Array; +using FARM = Eigen::Array; +using CARM = Eigen::Array; + +void test_2d_impulse(IDFT::pointer dft, int nrows=16, int ncols=8) +{ + const size_t size = nrows*ncols; + FA r = FA::Zero(nrows, ncols); + r(0,0) = 1.0; + dump("r", r); + assert_impulse_at_index(r.data(), size); + + CA rc = r.cast(); + dump("rc", rc); + assert_impulse_at_index(rc.data(), size); + + CA c = Aux::fwd(dft, rc); + dump("c", c); + assert_flat_value(c.data(), size); + + FA r2 = Aux::inv(dft, c).real(); + dump("r2", r2); + assert_impulse_at_index(r2.data(), size); +} + +void test_2d_eigen_transpose(IDFT::pointer dft) +{ + const int nrows=16; + const int ncols=8; + + // where the impulse lives (off axis) + const int imp_row = 1; + const int imp_col = 10; + + FA r = FA::Zero(nrows, ncols); // shape:(16,8) + dump("r", r); + + // do not remove the auto in this next line + auto rt = r.transpose(); // shape:(8,16) + dump("rt", rt); + rt(imp_row, imp_col) = 1.0; + + auto c = Aux::fwd(dft, rt.cast()); + dump("c", c); + + auto r2 = Aux::inv(dft, c).real(); + dump("r2",r2); + + // transpose access + const int nrowst = r2.rows(); + const int ncolst = r2.cols(); + + for (int irow=0; irow(), axis); + + dump("spectra", c); + std::cerr << c << std::endl; + + if (axis) { // transform along rows + CA ct = c.transpose(); // convert to along columns (native Eigen storage order) + c = ct; + std::swap(nrows, ncols); + dump("transpose", c); + std::cerr << c << std::endl; + } + + // first column has flat abs value of 1.0. + assert_flat_value(c.data(), nrows, complex_t(1,0)); + // rest should be flat, zero value + assert_flat_value(c.data()+nrows, nrows*ncols - nrows, complex_t(0,0)); + +} + +int main(int argc, char* argv[]) +{ + DftArgs args; + int rc = make_dft_args(args, argc, argv); + if (rc) { return rc; } + auto idft = make_dft(args.tn, args.pi, args.cfg); + + test_1d_impulse(idft); + test_2d_impulse(idft); + test_2d_eigen_transpose(idft); + test_1b(idft, 0); + test_1b(idft, 1); + + return 0; +} diff --git a/aux/test/test_idft.cxx b/aux/test/test_idft.cxx new file mode 100644 index 000000000..8e016bc2f --- /dev/null +++ b/aux/test/test_idft.cxx @@ -0,0 +1,240 @@ +// Test IDFT implementations. +#include "WireCellUtil/Waveform.h" + +#include "aux_test_dft_helpers.h" + +#include +#include +#include +#include +#include + +using namespace WireCell; +using namespace WireCell::Aux::Test; + +template +void dump(ValueType* data, int nrows, int ncols, std::string msg="") +{ + std::cerr << msg << "("< inter(size,0), freq(size,0); + + dft->fwd1d(inter.data(), freq.data(), inter.size()); + assert_flat_value(freq, czero); + dft->inv1d(freq.data(), inter.data(), freq.size()); + assert_flat_value(inter, czero); +} +static +void test_1d_impulse(IDFT::pointer dft, int size=1024) +{ + std::vector inter(size,0), freq(size,0), back(size,0); + inter[0] = 1.0; + + dft->fwd1d(inter.data(), freq.data(), freq.size()); + assert_flat_value(freq); + + dft->inv1d(freq.data(), back.data(), back.size()); + assert_impulse_at_index(back); +} + + + +static +void test_2d_zero(IDFT::pointer dft, int size = 1024) +{ + int stride=size, nstrides=size; + std::vector inter(stride*nstrides,0); + std::vector freq(stride*nstrides,0); + + dft->fwd2d(inter.data(), freq.data(), nstrides, stride); + assert_flat_value(inter, 0); + dft->inv2d(freq.data(), inter.data(), nstrides, stride); + assert_flat_value(freq, 0); +} +static +void test_2d_impulse(IDFT::pointer dft, int nrows=128, int ncols=128) +{ + const int size = nrows*ncols; + std::vector inter(size,0), freq(size,0), back(size,0); + inter[0] = 1.0; + dft->fwd2d(inter.data(), freq.data(), nrows, ncols); + assert_flat_value(freq); + + dft->inv2d(freq.data(), back.data(), nrows, ncols); + assert_impulse_at_index(back); + +} + + +static void assert_on_axis(const std::vector& freq, + int axis, int nrows=128, int ncols=128) +{ + for (int irow=0; irow inter(size,0), freq(size,0), back(size,0); + inter[0] = 1.0; + + dft->fwd1b(inter.data(), freq.data(), nrows, ncols, axis); + dump(freq.data(), nrows, ncols, "freq"); + assert_on_axis(freq, axis, nrows, ncols); + + dft->inv1b(freq.data(), back.data(), nrows, ncols, axis); + dump(back.data(), nrows, ncols, "back"); + assert_impulse_at_index(back, 0); + + + std::vector inplace(size,0); + inplace[0] = 1.0; + dft->fwd1b(inplace.data(), inplace.data(), nrows, ncols, axis); + assert_on_axis(inplace, axis, nrows, ncols); + + std::vector inback(inplace.begin(), inplace.end()); + dft->inv1b(inback.data(), inback.data(), nrows, ncols, axis); + assert_impulse_at_index(inback, 0); +} + +void fwdrev(IDFT::pointer dft, int id, int ntimes, int size) +{ + int stride=size, nstrides=size; + std::vector inter(stride*nstrides,0); + std::vector freq(stride*nstrides,0); + + // std::cerr << "running " << id << std::endl; + + while (ntimes) { + //std::cerr << ntimes << "\n"; + dft->fwd2d(inter.data(), freq.data(), nstrides, stride); + dft->inv2d(freq.data(), inter.data(), nstrides, stride); + + --ntimes; + auto tot = Waveform::sum(inter); + assert(std::real(tot) == 0); + } + //std::cerr << "finished " << id << std::endl; +} + +static +void test_2d_threads(IDFT::pointer dft, int nthreads, int nloops, int size = 1024) +{ + using namespace std::chrono; + + steady_clock::time_point t1 = steady_clock::now(); + + std::vector workers; + + //std::cerr << "Starting workers\n"; + for (int ind=0; ind dt1 = duration_cast>(t2 - t1); + std::cerr << "ndfts: " << nthreads*nloops + << " " << nthreads << " " << nloops + << " " << dt1.count() << std::endl; +} + + +template +void test_2d_transpose(IDFT::pointer dft, int nrows, int ncols) +{ + std::vector arr(nrows*ncols); + std::iota(arr.begin(), arr.end(), 0); + + std::vector arr2(nrows*ncols, 0); + std::vector arr3(arr.begin(), arr.end()); + + dft->transpose(arr.data(), arr2.data(), nrows, ncols); + dft->transpose(arr3.data(), arr3.data(), nrows, ncols); + + for (int irow=0; irow(idft, 2, 8); + test_2d_transpose(idft, 8, 2); + test_2d_transpose(idft, 2, 8); + test_2d_transpose(idft, 8, 2); + + std::vector sizes = {128,256,512,1024}; + for (auto size : sizes) { + int ndouble=3, ntot=2*16384/size; + while (ndouble) { + int nthread = 1< + using namespace WireCell; int main() diff --git a/aux/test/test_tensor_tools.cxx b/aux/test/test_tensor_tools.cxx new file mode 100644 index 000000000..933a777e7 --- /dev/null +++ b/aux/test/test_tensor_tools.cxx @@ -0,0 +1,102 @@ +#include "WireCellAux/TensorTools.h" +#include "WireCellAux/SimpleTensor.h" + + +#include +#include + +using real_t = float; +using RV = std::vector; +using complex_t = std::complex; +using CV = std::vector; +using RT = WireCell::Aux::SimpleTensor; +using CT = WireCell::Aux::SimpleTensor; + +// test fodder +const RV real_vector{0,1,2,3,4,5}; +const RV real_vector_cw{0,3,1,4,2,5}; +const CV complex_vector{{0,0},{1,1},{2,2},{3,3},{4,4},{5,5}}; +const WireCell::ITensor::shape_t shape{2,3}; + +using namespace WireCell; + +void test_is_type() +{ + auto rt = std::make_shared(shape, real_vector.data()); + assert (Aux::is_type(rt)); + assert (!Aux::is_type(rt)); +} + +void test_is_row_major() +{ + // ST actually does not let us do anything but C-order/row-major + auto rm = std::make_shared(shape, real_vector.data()); + assert(Aux::is_row_major(rm)); +} + +template +void assert_equal(const VectorType& v1, const VectorType& v2) +{ + assert(v1.size() == v2.size()); + for (size_t ind=0; ind(shape, real_vector.data()); + auto ct = std::make_shared(shape, complex_vector.data()); + auto got_rt = Aux::asvec(rt); + auto got_ct = Aux::asvec(ct); + assert_equal(real_vector, got_rt); + assert_equal(complex_vector, got_ct); + + try { + auto oops = Aux::asvec(rt); + } + catch (ValueError& err) { + } +} + +void test_asarray() +{ + // as array 2x2: (1d,2d) x (rw,cw) + + // make mutable copy to test that TT returns a copy + RV my_vec(real_vector.begin(), real_vector.end()); + + // test 2d + auto rt = std::make_shared(shape, my_vec.data()); + auto ra = Aux::asarray(rt); + auto shape = rt->shape(); + for (size_t irow = 0; irow < shape[0]; ++irow) { + for (size_t icol = 0; icol < shape[1]; ++icol) { + assert(ra(irow, icol) == my_vec[irow*shape[1] + icol]); + } + } + + // test 1d + const WireCell::ITensor::shape_t shape1d{6,}; + auto rt1d = std::make_shared(shape1d, my_vec.data()); + auto ra1d = Aux::asarray(rt1d); + for (size_t ind = 0; ind < shape[0]; ++ind) { + assert(ra1d(ind) == my_vec[ind]); + } + + // Assure the internal use of Eigen::Map leads to a copy on return + my_vec[0] = 42; + assert(ra(0,0) == 0); + assert(ra1d(0) == 0); +} + +int main() +{ + test_is_type(); + test_is_row_major(); + test_asvec(); + test_asarray(); + + return 0; +} diff --git a/cfg/pgrapher/common/helpers/aux.jsonnet b/cfg/pgrapher/common/helpers/aux.jsonnet index 0aa977a4a..e83b444c3 100644 --- a/cfg/pgrapher/common/helpers/aux.jsonnet +++ b/cfg/pgrapher/common/helpers/aux.jsonnet @@ -6,6 +6,9 @@ local wc = import "wirecell.jsonnet"; { + // Default DFT uses FFTW3 + dft : { type: "FftwDFT" }, + // Configure "wire" geometry and channel map to load from file wires(filename) :: { type:"WireSchemaFile", diff --git a/cfg/pgrapher/common/helpers/gen.jsonnet b/cfg/pgrapher/common/helpers/gen.jsonnet index 9bbaa6611..c9313ffa2 100644 --- a/cfg/pgrapher/common/helpers/gen.jsonnet +++ b/cfg/pgrapher/common/helpers/gen.jsonnet @@ -4,6 +4,8 @@ local wc = import "wirecell.jsonnet"; local pg = import "pgraph.jsonnet"; local u = import "utils.jsonnet"; +local aux = import "aux.jsonnet"; + { default_seeds: [0, 1, 2, 3, 4], @@ -75,7 +77,7 @@ local u = import "utils.jsonnet"; // fr is a field response object (see fr() above). // srs is list of "short response" config objects, eg cer() // lrs is list of "long response" config objects, eg rc() - pirs(fr, srs, lrs) :: [ { + pirs(fr, srs, lrs, dft=aux.dft) :: [ { type: "PlaneImpactResponse", name : std.toString(plane), data : { @@ -87,18 +89,20 @@ local u = import "utils.jsonnet"; long_responses: [wc.tn(r) for r in lrs], // this needs to be big enough to convolve RC long_padding: 1.5*wc.ms, + dft: wc.tn(dft), }, - uses: [fr] + srs + lrs, + uses: [dft, fr] + srs + lrs, } for plane in [0,1,2]], // signal simulation - signal(anode, pirs, daq, lar, rnd=$.random()) :: + signal(anode, pirs, daq, lar, rnd=$.random(), dft=aux.dft) :: pg.pipeline([ pg.pnode({ type:'DepoTransform', name: u.idents(anode), data: { rng: wc.tn(rnd), + dft: wc.tn(dft), anode: wc.tn(anode), pirs: [wc.tn(p) for p in pirs], fluctuate: true, @@ -109,7 +113,7 @@ local u = import "utils.jsonnet"; tick: daq.tick, nsigma: 3, }, - }, nin=1, nout=1, uses=pirs + [anode, rnd]), + }, nin=1, nout=1, uses=pirs + [anode, rnd, dft]), pg.pnode({ type: 'Reframer', @@ -126,7 +130,7 @@ local u = import "utils.jsonnet"; // Return a frame filter config that will add in noise. - noise(anode, filename, daq, chstat=null, rnd=$.random()) :: + noise(anode, filename, daq, chstat=null, rnd=$.random(), dft=aux.dft) :: local cs = if std.type(chstat) == "null" then {tn:"", uses:[]} else {tn:wc.tn(chstat), uses:[chstat]}; @@ -140,8 +144,9 @@ local u = import "utils.jsonnet"; nsamples: daq.nticks, period: daq.tick, wire_length_scale: 1.0*wc.cm, // optimization binning + dft: wc.tn(dft), }, - uses: [anode] + cs.uses, + uses: [anode, dft] + cs.uses, }; pg.pnode({ @@ -152,7 +157,8 @@ local u = import "utils.jsonnet"; model: wc.tn(noise_model), nsamples: daq.nticks, replacement_percentage: 0.02, // random optimization - }}, nin=1, nout=1, uses=[rnd, noise_model]), + dft: wc.tn(dft), + }}, nin=1, nout=1, uses=[rnd, noise_model, dft]), // digitizer simulation diff --git a/cfg/pgrapher/common/helpers/nf.jsonnet b/cfg/pgrapher/common/helpers/nf.jsonnet index 9c666f1b1..e86562396 100644 --- a/cfg/pgrapher/common/helpers/nf.jsonnet +++ b/cfg/pgrapher/common/helpers/nf.jsonnet @@ -2,13 +2,17 @@ local wc = import "wirecell.jsonnet"; local pg = import "pgraph.jsonnet"; local u = import "utils.jsonnet"; -function(anode, fr, chndb, nsamples, tick=0.5*wc.us, rms_cuts=[]) +local default_dft = { type: 'FftwDFT' }; + +function(anode, fr, chndb, nsamples, tick=0.5*wc.us, rms_cuts=[], dft=default_dft) local single = { type: 'pdOneChannelNoise', name: u.idents(anode), + uses: [dft], data: { noisedb: wc.tn(chndb), anode: wc.tn(anode), + dft: wc.tn(dft), resmp: [ ], }, diff --git a/cfg/pgrapher/common/helpers/sp.jsonnet b/cfg/pgrapher/common/helpers/sp.jsonnet index 38728d496..5327e6e84 100644 --- a/cfg/pgrapher/common/helpers/sp.jsonnet +++ b/cfg/pgrapher/common/helpers/sp.jsonnet @@ -3,6 +3,7 @@ local wc = import "wirecell.jsonnet"; local pg = import "pgraph.jsonnet"; +local aux = import "aux.jsonnet"; // Signal processing. @@ -10,7 +11,7 @@ local pg = import "pgraph.jsonnet"; // Note, spfilt are a list of filter objects which MUST match // hard-wired names in the C++, sorry. See, eg // pgrapher/experiment/pdsp/sp-filters.jsonnet. -function(anode, fieldresp, elecresp, spfilt, adcpermv, perchan=null, override={}) +function(anode, fieldresp, elecresp, spfilt, adcpermv, perchan=null, dft=aux.dft, override={}) local apaid = anode.data.ident; // if perchan file name is given we need to add this to a @@ -35,6 +36,7 @@ function(anode, fieldresp, elecresp, spfilt, adcpermv, perchan=null, override={} * Associated tuning in sp-filters.jsonnet */ anode: wc.tn(anode), + dft: wc.tn(dft), field_response: wc.tn(fieldresp), elecresponse: wc.tn(elecresp), ftoffset: 0.0, // default 0.0 @@ -81,4 +83,4 @@ function(anode, fieldresp, elecresp, spfilt, adcpermv, perchan=null, override={} isWrapped: false, // process_planes: [0, 2], } + override - }, nin=1, nout=1, uses=[anode, fieldresp, elecresp] + pc.uses + spfilt) + }, nin=1, nout=1, uses=[anode, dft, fieldresp, elecresp] + pc.uses + spfilt) diff --git a/cfg/pgrapher/common/helpers/utils.jsonnet b/cfg/pgrapher/common/helpers/utils.jsonnet index 19e31618b..08ed7781f 100644 --- a/cfg/pgrapher/common/helpers/utils.jsonnet +++ b/cfg/pgrapher/common/helpers/utils.jsonnet @@ -18,7 +18,7 @@ local pg = import "pgraph.jsonnet"; local app_plugins = { 'TbbFlow': ["WireCellTbb"], - 'PGrapher': ["WireCellPgraph"], + 'PGrapher': ["WireCellPgrapher"], }, main(graph, app='Pgrapher', extra_plugins = []) :: diff --git a/cfg/pgrapher/common/params.jsonnet b/cfg/pgrapher/common/params.jsonnet index 92ba7c797..d83fd95a2 100644 --- a/cfg/pgrapher/common/params.jsonnet +++ b/cfg/pgrapher/common/params.jsonnet @@ -1,19 +1,22 @@ -// This file is part of wire-cell-cfg. +// This file is part of wire-cell-toolkit/cfg/. // // This file provides a base data structure to define parameters that // span all currently supported WCT functionality. Not every -// parameter will be used and not all value here is valid. The -// parameters are named and factored into sub-objects in order to be -// sympathetic to how the C++ components are structured and name their -// configuration paramters. As such it's often possible to build a -// component configuration object by inheriting from one or more -// sub-objects in the parameter structure. For most jobs, this -// structure should be derived and overriden before being passed to -// functions that produce other configuration structures. +// parameter will be used and not every value here may be valid for +// your use and should be overridden. The parameters are named and +// factored into sub-objects in order to be sympathetic to how the C++ +// components are structured and name their configuration paramters. +// As such it's often possible to build a component configuration +// object by inheriting from one or more sub-objects in the parameter +// structure. For most jobs, this structure should be derived and +// overriden before being passed to functions that produce other +// configuration structures. +// local wc = import "wirecell.jsonnet"; { + // Parameters relevant to the bulk liquid argon volume. lar : { // Longitudinal diffusion constant diff --git a/cfg/pgrapher/common/sim/nodes.jsonnet b/cfg/pgrapher/common/sim/nodes.jsonnet index ee34b45b1..15d1c641d 100644 --- a/cfg/pgrapher/common/sim/nodes.jsonnet +++ b/cfg/pgrapher/common/sim/nodes.jsonnet @@ -69,6 +69,7 @@ function(params, tools) name:name, data: { rng: wc.tn(tools.random), + dft: wc.tn(tools.dft), anode: wc.tn(anode), pirs: std.map(function(pir) wc.tn(pir), pirs), fluctuate: params.sim.fluctuate, @@ -79,7 +80,7 @@ function(params, tools) tick: params.daq.tick, nsigma: 3, }, - }, nin=1, nout=1, uses=[anode, tools.random] + pirs), + }, nin=1, nout=1, uses=[anode, tools.random, tools.dft] + pirs), // This may look similar to above but above is expected to diverge make_depozipper :: function(name, anode, pirs) g.pnode({ @@ -261,9 +262,10 @@ function(params, tools) // fixme: these should probably be set from params. nsamples: 50, // number of samples of the response - truncate:true // result is extended by nsamples, tuncate clips that off + truncate:true, // result is extended by nsamples, tuncate clips that off + dft: wc.tn(tools.dft), } - }, nin=1, nout=1), + }, nin=1, nout=1, uses=[tools.dft]), local merge = g.pnode({ type: "FrameMerger", diff --git a/cfg/pgrapher/common/tools.jsonnet b/cfg/pgrapher/common/tools.jsonnet index cd9a6e4e5..f47c19a8c 100644 --- a/cfg/pgrapher/common/tools.jsonnet +++ b/cfg/pgrapher/common/tools.jsonnet @@ -1,13 +1,22 @@ - // This file provides a function which takes a params object (see // ../params/) and returns a data structure with a number of // sub-objects that may configure various WCT "tool" type componets // which are not INodes. +// Some attributes are merely default and you may wish to override +// them. For example, the default IDFT FftwDFT and to instead ues +// TorchDFT you may do something like: +// +// local default_tools = tools_maker(params) +// local tools = std.mergePatch(default_tools, +// {dft: {type: "TorchDFT", data: {device: "gpu"}}}); +// + local wc = import "wirecell.jsonnet"; function(params) { + // The IRandom pRNG random : { type: "Random", data: { @@ -15,6 +24,10 @@ function(params) seeds: [0,1,2,3,4], } }, + // The IDFT FFT implementation + dft : { + type: "FftwDFT", + }, // One FR per field file. fields : std.mapWithIndex(function (n, fname) { @@ -90,13 +103,15 @@ function(params) }, // there is one trio of PIRs (one per wire plane in a face) for - // each field response. + // each field response. WARNING/fixme: this sets the default DFT + // with no way to override! This config structure needs a redo! pirs : std.mapWithIndex(function (n, fr) [ { type: "PlaneImpactResponse", name : "PIR%splane%d" % [fr.name, plane], data : sim_response_binning { plane: plane, + dft: wc.tn($.dft), field_response: wc.tn(fr), // note twice we give rc so we have rc^2 in the final convolution short_responses: if params.sys_status == false @@ -112,7 +127,7 @@ function(params) else [wc.tn($.rc_resp), wc.tn($.rc_resp)], long_padding: 1.5*wc.ms, }, - uses: [fr, $.elec_resp, $.rc_resp, $.sys_resp], + uses: [$.dft, fr, $.elec_resp, $.rc_resp, $.sys_resp], } for plane in [0,1,2]], $.fields), // One anode per detector "volume" diff --git a/cfg/pgrapher/experiment/dune-vd/sim.jsonnet b/cfg/pgrapher/experiment/dune-vd/sim.jsonnet index 4e5b043c5..e58622135 100644 --- a/cfg/pgrapher/experiment/dune-vd/sim.jsonnet +++ b/cfg/pgrapher/experiment/dune-vd/sim.jsonnet @@ -33,13 +33,14 @@ function(params, tools) { name: "empericalnoise%s"% anode.name, data: { anode: wc.tn(anode), + dft: wc.tn(tools.dft), chanstat: if std.type(csdb) == "null" then "" else wc.tn(csdb), spectra_file: params.files.noise, nsamples: params.daq.nticks, period: params.daq.tick, wire_length_scale: 1.0*wc.cm, // optimization binning }, - uses: [anode] + if std.type(csdb) == "null" then [] else [csdb], + uses: [anode, tools.dft] + if std.type(csdb) == "null" then [] else [csdb], }, local noise_models = [make_noise_model(anode) for anode in tools.anodes], @@ -49,10 +50,11 @@ function(params, tools) { name: "addnoise%s"%[model.name], data: { rng: wc.tn(tools.random), + dft: wc.tn(tools.dft), model: wc.tn(model), nsamples: params.daq.nticks, replacement_percentage: 0.02, // random optimization - }}, nin=1, nout=1, uses=[model]), + }}, nin=1, nout=1, uses=[tools.random, tools.dft, model]), local noises = [add_noise(model) for model in noise_models], diff --git a/cfg/pgrapher/experiment/dune-vd/sp.jsonnet b/cfg/pgrapher/experiment/dune-vd/sp.jsonnet index 625ed9e99..5e5b708f8 100644 --- a/cfg/pgrapher/experiment/dune-vd/sp.jsonnet +++ b/cfg/pgrapher/experiment/dune-vd/sp.jsonnet @@ -47,6 +47,7 @@ function(params, tools, override = {}) { * Associated tuning in sp-filters.jsonnet */ anode: wc.tn(anode), + dft: wc.tn(tools.dft), field_response: wc.tn(tools.field), elecresponse: wc.tn(tools.elec_resp), ftoffset: 0.0, // default 0.0 @@ -77,6 +78,6 @@ function(params, tools, override = {}) { wiener_threshold_tag: 'threshold%d' % anode.data.ident, gauss_tag: 'gauss%d' % anode.data.ident, } + override, - }, nin=1, nout=1, uses=[anode, tools.field, tools.elec_resp] + pc.uses + spfilt), + }, nin=1, nout=1, uses=[anode, tools.dft, tools.field, tools.elec_resp] + pc.uses + spfilt), } diff --git a/cfg/pgrapher/experiment/dune-vd/wcls-nf-sp.jsonnet b/cfg/pgrapher/experiment/dune-vd/wcls-nf-sp.jsonnet index a0438bbf9..7a46327b5 100644 --- a/cfg/pgrapher/experiment/dune-vd/wcls-nf-sp.jsonnet +++ b/cfg/pgrapher/experiment/dune-vd/wcls-nf-sp.jsonnet @@ -131,9 +131,9 @@ local base = import 'chndb-base.jsonnet'; local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, - // data: perfect(params, tools.anodes[n], tools.field, n), - data: base(params, tools.anodes[n], tools.field, n), - uses: [tools.anodes[n], tools.field], // pnode extension + // data: perfect(params, tools.anodes[n], tools.field, n) { dft:wc.tn(tools.dft) }, + data: base(params, tools.anodes[n], tools.field, n) { dft:wc.tn(tools.dft) }, + uses: [tools.anodes[n], tools.field, tools.dft], } for n in std.range(0, std.length(tools.anodes) - 1)]; // local nf_maker = import 'pgrapher/experiment/dune10kt-1x2x6/nf.jsonnet'; diff --git a/cfg/pgrapher/experiment/dune-vd/wcls-sim-drift-simchannel.jsonnet b/cfg/pgrapher/experiment/dune-vd/wcls-sim-drift-simchannel.jsonnet index 06aba27b3..59bf42d03 100644 --- a/cfg/pgrapher/experiment/dune-vd/wcls-sim-drift-simchannel.jsonnet +++ b/cfg/pgrapher/experiment/dune-vd/wcls-sim-drift-simchannel.jsonnet @@ -120,8 +120,8 @@ local perfect = import 'pgrapher/experiment/dune10kt-1x2x6/chndb-perfect.jsonnet local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, - data: perfect(params, tools.anodes[n], tools.field, n), - uses: [tools.anodes[n], tools.field], // pnode extension + data: perfect(params, tools.anodes[n], tools.field, n) {dft:wc.tn(tools.dft)}, + uses: [tools.anodes[n], tools.field, tools.dft], } for n in anode_iota]; //local chndb_maker = import 'pgrapher/experiment/pdsp/chndb.jsonnet'; diff --git a/cfg/pgrapher/experiment/dune10kt-1x2x6/nf.jsonnet b/cfg/pgrapher/experiment/dune10kt-1x2x6/nf.jsonnet index 5dc74f593..8a24d3f4a 100644 --- a/cfg/pgrapher/experiment/dune10kt-1x2x6/nf.jsonnet +++ b/cfg/pgrapher/experiment/dune10kt-1x2x6/nf.jsonnet @@ -6,7 +6,7 @@ local wc = import 'wirecell.jsonnet'; function(params, anode, chndbobj, n, name='') { local status = { - type: 'mbOneChannelStatus', + type: std.trace("Warning MB in DUNE?", 'mbOneChannelStatus'), name: name, data: { Threshold: 3.5, @@ -17,7 +17,7 @@ function(params, anode, chndbobj, n, name='') }, }, local single = { - type: 'mbOneChannelNoise', + type: std.trace("Warning MB in DUNE?", 'mbOneChannelNoise'), name: name, data: { noisedb: wc.tn(chndbobj), @@ -25,7 +25,7 @@ function(params, anode, chndbobj, n, name='') }, }, local grouped = { - type: 'mbCoherentNoiseSub', + type: std.trace("Warning MB in DUNE?", 'mbCoherentNoiseSub'), name: name, data: { noisedb: wc.tn(chndbobj), diff --git a/cfg/pgrapher/experiment/dune10kt-1x2x6/sim.jsonnet b/cfg/pgrapher/experiment/dune10kt-1x2x6/sim.jsonnet index 27bcefb35..97fb8f680 100644 --- a/cfg/pgrapher/experiment/dune10kt-1x2x6/sim.jsonnet +++ b/cfg/pgrapher/experiment/dune10kt-1x2x6/sim.jsonnet @@ -46,13 +46,14 @@ function(params, tools) { name: "empericalnoise%s"% anode.name, data: { anode: wc.tn(anode), + dft: wc.tn(tools.dft), chanstat: if std.type(csdb) == "null" then "" else wc.tn(csdb), spectra_file: params.files.noise, nsamples: params.daq.nticks, period: params.daq.tick, wire_length_scale: 1.0*wc.cm, // optimization binning }, - uses: [anode] + if std.type(csdb) == "null" then [] else [csdb], + uses: [anode, tools.dft] + if std.type(csdb) == "null" then [] else [csdb], }, local noise_models = [make_noise_model(anode) for anode in tools.anodes], @@ -62,10 +63,11 @@ function(params, tools) { name: "addnoise%s"%[model.name], data: { rng: wc.tn(tools.random), + dft: wc.tn(tools.dft), model: wc.tn(model), nsamples: params.daq.nticks, replacement_percentage: 0.02, // random optimization - }}, nin=1, nout=1, uses=[model]), + }}, nin=1, nout=1, uses=[tools.random, tools.dft, model]), local noises = [add_noise(model) for model in noise_models], diff --git a/cfg/pgrapher/experiment/dune10kt-1x2x6/sp.jsonnet b/cfg/pgrapher/experiment/dune10kt-1x2x6/sp.jsonnet index 07cd43d63..a311f09ad 100644 --- a/cfg/pgrapher/experiment/dune10kt-1x2x6/sp.jsonnet +++ b/cfg/pgrapher/experiment/dune10kt-1x2x6/sp.jsonnet @@ -47,6 +47,7 @@ function(params, tools, override = {}) { * Associated tuning in sp-filters.jsonnet */ anode: wc.tn(anode), + dft: wc.tn(tools.dft), field_response: wc.tn(tools.field), elecresponse: wc.tn(tools.elec_resp), ftoffset: 0.0, // default 0.0 @@ -77,6 +78,6 @@ function(params, tools, override = {}) { wiener_threshold_tag: 'threshold%d' % anode.data.ident, gauss_tag: 'gauss%d' % anode.data.ident, } + override, - }, nin=1, nout=1, uses=[anode, tools.field, tools.elec_resp] + pc.uses + spfilt), + }, nin=1, nout=1, uses=[anode, tools.dft, tools.field, tools.elec_resp] + pc.uses + spfilt), } diff --git a/cfg/pgrapher/experiment/dune10kt-1x2x6/wcls-blip-sim-drift-simchannel.jsonnet b/cfg/pgrapher/experiment/dune10kt-1x2x6/wcls-blip-sim-drift-simchannel.jsonnet index d908a451f..3f28a734c 100644 --- a/cfg/pgrapher/experiment/dune10kt-1x2x6/wcls-blip-sim-drift-simchannel.jsonnet +++ b/cfg/pgrapher/experiment/dune10kt-1x2x6/wcls-blip-sim-drift-simchannel.jsonnet @@ -117,8 +117,8 @@ local perfect = import 'pgrapher/experiment/dune10kt-1x2x6/chndb-perfect.jsonnet local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, - data: perfect(params, tools.anodes[n], tools.field, n), - uses: [tools.anodes[n], tools.field], // pnode extension + data: perfect(params, tools.anodes[n], tools.field, n){dft:wc.tn(tools.dft)}, + uses: [tools.anodes[n], tools.field, tools.dft], // pnode extension } for n in anode_iota]; //local chndb_maker = import 'pgrapher/experiment/pdsp/chndb.jsonnet'; diff --git a/cfg/pgrapher/experiment/dune10kt-1x2x6/wcls-nf-sp.jsonnet b/cfg/pgrapher/experiment/dune10kt-1x2x6/wcls-nf-sp.jsonnet index 0132824a5..6df82186f 100644 --- a/cfg/pgrapher/experiment/dune10kt-1x2x6/wcls-nf-sp.jsonnet +++ b/cfg/pgrapher/experiment/dune10kt-1x2x6/wcls-nf-sp.jsonnet @@ -123,8 +123,8 @@ local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, // data: perfect(params, tools.anodes[n], tools.field, n), - data: base(params, tools.anodes[n], tools.field, n), - uses: [tools.anodes[n], tools.field], // pnode extension + data: base(params, tools.anodes[n], tools.field, n){dft:wc.tn(tools.dft)}, + uses: [tools.anodes[n], tools.field, tools.dft], } for n in std.range(0, std.length(tools.anodes) - 1)]; local nf_maker = import 'pgrapher/experiment/dune10kt-1x2x6/nf.jsonnet'; diff --git a/cfg/pgrapher/experiment/dune10kt-1x2x6/wcls-sim-drift-simchannel.jsonnet b/cfg/pgrapher/experiment/dune10kt-1x2x6/wcls-sim-drift-simchannel.jsonnet index 0437e98dc..d26a1f131 100644 --- a/cfg/pgrapher/experiment/dune10kt-1x2x6/wcls-sim-drift-simchannel.jsonnet +++ b/cfg/pgrapher/experiment/dune10kt-1x2x6/wcls-sim-drift-simchannel.jsonnet @@ -116,8 +116,8 @@ local perfect = import 'pgrapher/experiment/dune10kt-1x2x6/chndb-perfect.jsonnet local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, - data: perfect(params, tools.anodes[n], tools.field, n), - uses: [tools.anodes[n], tools.field], // pnode extension + data: perfect(params, tools.anodes[n], tools.field, n){dft:wc.tn(tools.dft)}, + uses: [tools.anodes[n], tools.field, tools.dft], } for n in anode_iota]; //local chndb_maker = import 'pgrapher/experiment/pdsp/chndb.jsonnet'; diff --git a/cfg/pgrapher/experiment/dune10kt-1x2x6/wcls-sp.jsonnet b/cfg/pgrapher/experiment/dune10kt-1x2x6/wcls-sp.jsonnet index 22969c7ba..fbdf79232 100644 --- a/cfg/pgrapher/experiment/dune10kt-1x2x6/wcls-sp.jsonnet +++ b/cfg/pgrapher/experiment/dune10kt-1x2x6/wcls-sp.jsonnet @@ -130,8 +130,8 @@ local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, // data: perfect(params, tools.anodes[n], tools.field, n), - data: base(params, tools.anodes[n], tools.field, n), - uses: [tools.anodes[n], tools.field], // pnode extension + data: base(params, tools.anodes[n], tools.field, n){dft:wc.tn(tools.dft)}, + uses: [tools.anodes[n], tools.field, tools.dft], } for n in std.range(0, std.length(tools.anodes) - 1)]; // local nf_maker = import 'pgrapher/experiment/pdsp/nf.jsonnet'; diff --git a/cfg/pgrapher/experiment/dune10kt-1x2x6/wct-sim-check.jsonnet b/cfg/pgrapher/experiment/dune10kt-1x2x6/wct-sim-check.jsonnet index c8c9ba691..7f72cf3b3 100644 --- a/cfg/pgrapher/experiment/dune10kt-1x2x6/wct-sim-check.jsonnet +++ b/cfg/pgrapher/experiment/dune10kt-1x2x6/wct-sim-check.jsonnet @@ -77,8 +77,8 @@ local perfect = import 'chndb-perfect.jsonnet'; local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, - data: perfect(params, tools.anodes[n], tools.field), - uses: [tools.anodes[n], tools.field], // pnode extension + data: perfect(params, tools.anodes[n], tools.field){dft:wc.tn(tools.dft)}, + uses: [tools.anodes[n], tools.field, tools.dft], } for n in std.range(0, std.length(tools.anodes) - 1)]; //local chndb_maker = import 'pgrapher/experiment/pdsp/chndb.jsonnet'; diff --git a/cfg/pgrapher/experiment/icarus/nf.jsonnet b/cfg/pgrapher/experiment/icarus/nf.jsonnet index c41cc9743..fc92dc1d7 100644 --- a/cfg/pgrapher/experiment/icarus/nf.jsonnet +++ b/cfg/pgrapher/experiment/icarus/nf.jsonnet @@ -3,23 +3,29 @@ local g = import 'pgraph.jsonnet'; local wc = import 'wirecell.jsonnet'; -function(params, anode, chndbobj, n, name='') +local default_dft = { type: 'FftwDFT' }; + +function(params, anode, chndbobj, n, name='', dft=default_dft) { local single = { type: 'pdOneChannelNoise', name: name, + uses: [dft, chndbobj, anode], data: { noisedb: wc.tn(chndbobj), anode: wc.tn(anode), + dft: wc.tn(dft), }, }, local grouped = { type: 'mbCoherentNoiseSub', name: name, + uses: [dft, chndbobj, anode], data: { noisedb: wc.tn(chndbobj), anode: wc.tn(anode), + dft: wc.tn(dft), rms_threshold: 0.0, }, }, diff --git a/cfg/pgrapher/experiment/icarus/sim.jsonnet b/cfg/pgrapher/experiment/icarus/sim.jsonnet index 2d88cf005..bd578ea57 100644 --- a/cfg/pgrapher/experiment/icarus/sim.jsonnet +++ b/cfg/pgrapher/experiment/icarus/sim.jsonnet @@ -46,13 +46,14 @@ function(params, tools) { name: "empericalnoise-" + anode.name, data: { anode: wc.tn(anode), + dft: wc.tn(tools.dft), chanstat: if std.type(csdb) == "null" then "" else wc.tn(csdb), spectra_file: params.files.noise, nsamples: params.daq.nticks, period: params.daq.tick, wire_length_scale: 1.0*wc.cm, // optimization binning }, - uses: [anode] + if std.type(csdb) == "null" then [] else [csdb], + uses: [anode, tools.dft] + if std.type(csdb) == "null" then [] else [csdb], }, local noise_models = [make_noise_model(anode) for anode in tools.anodes], @@ -62,10 +63,11 @@ function(params, tools) { name: "addnoise-" + model.name, data: { rng: wc.tn(tools.random), + dft: wc.tn(tools.dft), model: wc.tn(model), nsamples: params.daq.nticks, replacement_percentage: 0.02, // random optimization - }}, nin=1, nout=1, uses=[model]), + }}, nin=1, nout=1, uses=[tools.random, tools.dft, model]), local noises = [add_noise(model) for model in noise_models], diff --git a/cfg/pgrapher/experiment/icarus/sp.jsonnet b/cfg/pgrapher/experiment/icarus/sp.jsonnet index 6e5c6871f..ab66d462c 100644 --- a/cfg/pgrapher/experiment/icarus/sp.jsonnet +++ b/cfg/pgrapher/experiment/icarus/sp.jsonnet @@ -21,6 +21,7 @@ function(params, tools, override = {}) { data: { // Many parameters omitted here. anode: wc.tn(anode), + dft: wc.tn(tools.dft), field_response: wc.tn(tools.field), ftoffset: 0.0, // default 0.0 ctoffset: 0.0*wc.microsecond, // default -8.0 @@ -67,6 +68,6 @@ function(params, tools, override = {}) { process_planes: [0, util.anode_split(anode.data.ident)], // balance the left and right split } + override, - }, nin=1, nout=1, uses=[anode, tools.field, tools.elec_resp] + pc.uses + spfilt), + }, nin=1, nout=1, uses=[anode, tools.dft, tools.field, tools.elec_resp] + pc.uses + spfilt), } diff --git a/cfg/pgrapher/experiment/icarus/wcls-decode-to-sig.jsonnet b/cfg/pgrapher/experiment/icarus/wcls-decode-to-sig.jsonnet index a950670b2..bf7e838de 100644 --- a/cfg/pgrapher/experiment/icarus/wcls-decode-to-sig.jsonnet +++ b/cfg/pgrapher/experiment/icarus/wcls-decode-to-sig.jsonnet @@ -122,8 +122,8 @@ local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, // data: perfect(params, tools.anodes[n], tools.field, n), - data: base(params, tools.anodes[n], tools.field, n), - uses: [tools.anodes[n], tools.field], // pnode extension + data: base(params, tools.anodes[n], tools.field, n){dft:wc.tn(tools.dft)}, + uses: [tools.anodes[n], tools.field, tools.dft], } for n in std.range(0, std.length(tools.anodes) - 1)]; local nf_maker = import 'pgrapher/experiment/icarus/nf.jsonnet'; diff --git a/cfg/pgrapher/experiment/icarus/wcls-multitpc-sim-drift-simchannel-omit-noise.jsonnet b/cfg/pgrapher/experiment/icarus/wcls-multitpc-sim-drift-simchannel-omit-noise.jsonnet index 6545f6d63..69c60da92 100644 --- a/cfg/pgrapher/experiment/icarus/wcls-multitpc-sim-drift-simchannel-omit-noise.jsonnet +++ b/cfg/pgrapher/experiment/icarus/wcls-multitpc-sim-drift-simchannel-omit-noise.jsonnet @@ -123,8 +123,8 @@ local perfect = import 'pgrapher/experiment/icarus/chndb-base.jsonnet'; local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, - data: perfect(params, tools.anodes[n], tools.field, n), - uses: [tools.anodes[n], tools.field], // pnode extension + data: perfect(params, tools.anodes[n], tools.field, n){dft:wc.tn(tools.dft)}, + uses: [tools.anodes[n], tools.field, tools.dft], } for n in anode_iota]; diff --git a/cfg/pgrapher/experiment/icarus/wcls-multitpc-sim-drift-simchannel.jsonnet b/cfg/pgrapher/experiment/icarus/wcls-multitpc-sim-drift-simchannel.jsonnet index 4817f0000..cc7949d39 100644 --- a/cfg/pgrapher/experiment/icarus/wcls-multitpc-sim-drift-simchannel.jsonnet +++ b/cfg/pgrapher/experiment/icarus/wcls-multitpc-sim-drift-simchannel.jsonnet @@ -123,8 +123,8 @@ local perfect = import 'pgrapher/experiment/icarus/chndb-base.jsonnet'; local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, - data: perfect(params, tools.anodes[n], tools.field, n), - uses: [tools.anodes[n], tools.field], // pnode extension + data: perfect(params, tools.anodes[n], tools.field, n){dft:wc.tn(tools.dft)}, + uses: [tools.anodes[n], tools.field, tools.dft], } for n in anode_iota]; @@ -165,13 +165,14 @@ local make_noise_model = function(anode, csdb=null) { name: "empericalnoise-" + anode.name, data: { anode: wc.tn(anode), + dft: wc.tn(tools.dft), chanstat: if std.type(csdb) == "null" then "" else wc.tn(csdb), spectra_file: params.files.noise, nsamples: params.daq.nticks, period: params.daq.tick, wire_length_scale: 1.0*wc.cm, // optimization binning }, - uses: [anode] + if std.type(csdb) == "null" then [] else [csdb], + uses: [anode, tools.dft] + if std.type(csdb) == "null" then [] else [csdb], }; local noise_model = make_noise_model(mega_anode); local add_noise = function(model, n) g.pnode({ @@ -179,10 +180,11 @@ local add_noise = function(model, n) g.pnode({ name: "addnoise%d-" %n + model.name, data: { rng: wc.tn(tools.random), + dfg: wc.tn(tools.dft), model: wc.tn(model), nsamples: params.daq.nticks, replacement_percentage: 0.02, // random optimization - }}, nin=1, nout=1, uses=[model]); + }}, nin=1, nout=1, uses=[tools.random, tools.dft, model]); local noises = [add_noise(noise_model, n) for n in std.range(0,3)]; local add_coherent_noise = function(n) g.pnode({ @@ -191,11 +193,12 @@ local add_coherent_noise = function(n) g.pnode({ data: { spectra_file: params.files.coherent_noise, rng: wc.tn(tools.random), + dft: wc.tn(tools.dft), nsamples: params.daq.nticks, random_fluctuation_amplitude: 0.1, period: params.daq.tick, normalization: 1 - }}, nin=1, nout=1, uses=[]); + }}, nin=1, nout=1, uses=[tools.random, tools.dft]); local coherent_noises = [add_coherent_noise(n) for n in std.range(0,3)]; // local digitizer = sim.digitizer(mega_anode, name="digitizer", tag="orig"); diff --git a/cfg/pgrapher/experiment/icarus/wcls-sim-drift-simchannel.jsonnet b/cfg/pgrapher/experiment/icarus/wcls-sim-drift-simchannel.jsonnet index 68591c7bc..960e2666c 100644 --- a/cfg/pgrapher/experiment/icarus/wcls-sim-drift-simchannel.jsonnet +++ b/cfg/pgrapher/experiment/icarus/wcls-sim-drift-simchannel.jsonnet @@ -140,13 +140,14 @@ local make_noise_model = function(anode, csdb=null) { name: "empericalnoise-" + anode.name, data: { anode: wc.tn(anode), + dft: wc.tn(tools.dft), chanstat: if std.type(csdb) == "null" then "" else wc.tn(csdb), spectra_file: params.files.noise, nsamples: params.daq.nticks, period: params.daq.tick, wire_length_scale: 1.0*wc.cm, // optimization binning }, - uses: [anode] + if std.type(csdb) == "null" then [] else [csdb], + uses: [anode, tools.dft] + if std.type(csdb) == "null" then [] else [csdb], }; local noise_model = make_noise_model(mega_anode); local add_noise = function(model, n) g.pnode({ @@ -154,10 +155,11 @@ local add_noise = function(model, n) g.pnode({ name: "addnoise%d-" %n + model.name, data: { rng: wc.tn(tools.random), + dft: wc.tn(tools.dft), model: wc.tn(model), nsamples: params.daq.nticks, replacement_percentage: 0.02, // random optimization - }}, nin=1, nout=1, uses=[model]); + }}, nin=1, nout=1, uses=[tools.random, tools.dft, model]); local noises = [add_noise(noise_model, n) for n in std.range(0,3)]; // local digitizer = sim.digitizer(mega_anode, name="digitizer", tag="orig"); diff --git a/cfg/pgrapher/experiment/icarus/wct-coherent-noise.jsonnet b/cfg/pgrapher/experiment/icarus/wct-coherent-noise.jsonnet index 29f0a5895..73f3cb66d 100644 --- a/cfg/pgrapher/experiment/icarus/wct-coherent-noise.jsonnet +++ b/cfg/pgrapher/experiment/icarus/wct-coherent-noise.jsonnet @@ -48,8 +48,8 @@ local perfect = import 'pgrapher/experiment/icarus/chndb-base.jsonnet'; local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, - data: perfect(params, tools.anodes[n], tools.field, n), - uses: [tools.anodes[n], tools.field], // pnode extension + data: perfect(params, tools.anodes[n], tools.field, n){dft:wc.tn(tools.dft)}, + uses: [tools.anodes[n], tools.field, tools.dft], } for n in anode_iota]; local nf_maker = import 'pgrapher/experiment/icarus/nf.jsonnet'; @@ -72,13 +72,14 @@ local make_noise_model = function(anode, csdb=null) { name: "empericalnoise-" + anode.name, data: { anode: wc.tn(anode), + dft: wc.tn(tools.dft), chanstat: if std.type(csdb) == "null" then "" else wc.tn(csdb), spectra_file: params.files.noise, nsamples: params.daq.nticks, period: params.daq.tick, wire_length_scale: 1.0*wc.cm, // optimization binning }, - uses: [anode] + if std.type(csdb) == "null" then [] else [csdb], + uses: [anode, tools.dft] + if std.type(csdb) == "null" then [] else [csdb], }; local noise_model = make_noise_model(mega_anode); local add_noise = function(model, n) g.pnode({ @@ -98,11 +99,12 @@ local add_coherent_noise = function(n) g.pnode({ data: { spectra_file: params.files.coherent_noise, rng: wc.tn(tools.random), + dft: wc.tn(tools.dft), nsamples: params.daq.nticks, random_fluctuation_amplitude: 0.1, period: params.daq.tick, normalization: 1 - }}, nin=1, nout=1, uses=[]); + }}, nin=1, nout=1, uses=[tools.random, tools.dft]); local coherent_noises = [add_coherent_noise(n) for n in std.range(0,3)]; // local digitizer = sim.digitizer(mega_anode, name="digitizer", tag="orig"); diff --git a/cfg/pgrapher/experiment/icarus/wct-sim-check.jsonnet b/cfg/pgrapher/experiment/icarus/wct-sim-check.jsonnet index 82548c030..0bf8b74f1 100644 --- a/cfg/pgrapher/experiment/icarus/wct-sim-check.jsonnet +++ b/cfg/pgrapher/experiment/icarus/wct-sim-check.jsonnet @@ -48,8 +48,8 @@ local perfect = import 'pgrapher/experiment/icarus/chndb-base.jsonnet'; local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, - data: perfect(params, tools.anodes[n], tools.field, n), - uses: [tools.anodes[n], tools.field], // pnode extension + data: perfect(params, tools.anodes[n], tools.field, n){dft:wc.tn(tools.dft)}, + uses: [tools.anodes[n], tools.field, tools.dft], } for n in anode_iota]; local nf_maker = import 'pgrapher/experiment/icarus/nf.jsonnet'; @@ -73,13 +73,14 @@ local make_noise_model = function(anode, csdb=null) { name: "empericalnoise-" + anode.name, data: { anode: wc.tn(anode), + dft: wc.tn(tools.dft), chanstat: if std.type(csdb) == "null" then "" else wc.tn(csdb), spectra_file: params.files.noise, nsamples: params.daq.nticks, period: params.daq.tick, wire_length_scale: 1.0*wc.cm, // optimization binning }, - uses: [anode] + if std.type(csdb) == "null" then [] else [csdb], + uses: [anode, tools.dft] + if std.type(csdb) == "null" then [] else [csdb], }; local noise_model = make_noise_model(mega_anode); local add_noise = function(model, n) g.pnode({ @@ -87,10 +88,11 @@ local add_noise = function(model, n) g.pnode({ name: "addnoise%d-" %n + model.name, data: { rng: wc.tn(tools.random), + dft: wc.tn(tools.dft), model: wc.tn(model), nsamples: params.daq.nticks, replacement_percentage: 0.02, // random optimization - }}, nin=1, nout=1, uses=[model]); + }}, nin=1, nout=1, uses=[tools.random, tools.dft, model]); local noises = [add_noise(noise_model, n) for n in std.range(0,3)]; // local digitizer = sim.digitizer(mega_anode, name="digitizer", tag="orig"); diff --git a/cfg/pgrapher/experiment/iceberg/nf.jsonnet b/cfg/pgrapher/experiment/iceberg/nf.jsonnet index 5dc74f593..8a24d3f4a 100644 --- a/cfg/pgrapher/experiment/iceberg/nf.jsonnet +++ b/cfg/pgrapher/experiment/iceberg/nf.jsonnet @@ -6,7 +6,7 @@ local wc = import 'wirecell.jsonnet'; function(params, anode, chndbobj, n, name='') { local status = { - type: 'mbOneChannelStatus', + type: std.trace("Warning MB in DUNE?", 'mbOneChannelStatus'), name: name, data: { Threshold: 3.5, @@ -17,7 +17,7 @@ function(params, anode, chndbobj, n, name='') }, }, local single = { - type: 'mbOneChannelNoise', + type: std.trace("Warning MB in DUNE?", 'mbOneChannelNoise'), name: name, data: { noisedb: wc.tn(chndbobj), @@ -25,7 +25,7 @@ function(params, anode, chndbobj, n, name='') }, }, local grouped = { - type: 'mbCoherentNoiseSub', + type: std.trace("Warning MB in DUNE?", 'mbCoherentNoiseSub'), name: name, data: { noisedb: wc.tn(chndbobj), diff --git a/cfg/pgrapher/experiment/iceberg/sim.jsonnet b/cfg/pgrapher/experiment/iceberg/sim.jsonnet index 27bcefb35..97fb8f680 100644 --- a/cfg/pgrapher/experiment/iceberg/sim.jsonnet +++ b/cfg/pgrapher/experiment/iceberg/sim.jsonnet @@ -46,13 +46,14 @@ function(params, tools) { name: "empericalnoise%s"% anode.name, data: { anode: wc.tn(anode), + dft: wc.tn(tools.dft), chanstat: if std.type(csdb) == "null" then "" else wc.tn(csdb), spectra_file: params.files.noise, nsamples: params.daq.nticks, period: params.daq.tick, wire_length_scale: 1.0*wc.cm, // optimization binning }, - uses: [anode] + if std.type(csdb) == "null" then [] else [csdb], + uses: [anode, tools.dft] + if std.type(csdb) == "null" then [] else [csdb], }, local noise_models = [make_noise_model(anode) for anode in tools.anodes], @@ -62,10 +63,11 @@ function(params, tools) { name: "addnoise%s"%[model.name], data: { rng: wc.tn(tools.random), + dft: wc.tn(tools.dft), model: wc.tn(model), nsamples: params.daq.nticks, replacement_percentage: 0.02, // random optimization - }}, nin=1, nout=1, uses=[model]), + }}, nin=1, nout=1, uses=[tools.random, tools.dft, model]), local noises = [add_noise(model) for model in noise_models], diff --git a/cfg/pgrapher/experiment/iceberg/sp.jsonnet b/cfg/pgrapher/experiment/iceberg/sp.jsonnet index ffd9fee47..717231089 100644 --- a/cfg/pgrapher/experiment/iceberg/sp.jsonnet +++ b/cfg/pgrapher/experiment/iceberg/sp.jsonnet @@ -43,6 +43,7 @@ function(params, tools, override = {}) { * Associated tuning in sp-filters.jsonnet */ anode: wc.tn(anode), + dft: wc.tn(tools.dft), field_response: wc.tn(tools.field), elecresponse: wc.tn(tools.elec_resp), ftoffset: 0.0, // default 0.0 @@ -73,6 +74,6 @@ function(params, tools, override = {}) { wiener_threshold_tag: 'threshold%d' % anode.data.ident, gauss_tag: 'gauss%d' % anode.data.ident, } + override, - }, nin=1, nout=1, uses=[anode, tools.field, tools.elec_resp] + pc.uses + spfilt), + }, nin=1, nout=1, uses=[anode, tools.dft, tools.field, tools.elec_resp] + pc.uses + spfilt), } diff --git a/cfg/pgrapher/experiment/iceberg/wcls-nf-sp.jsonnet b/cfg/pgrapher/experiment/iceberg/wcls-nf-sp.jsonnet index 7c4e4b70a..e8a82b5f6 100644 --- a/cfg/pgrapher/experiment/iceberg/wcls-nf-sp.jsonnet +++ b/cfg/pgrapher/experiment/iceberg/wcls-nf-sp.jsonnet @@ -120,8 +120,8 @@ local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, // data: perfect(params, tools.anodes[n], tools.field, n), - data: base(params, tools.anodes[n], tools.field, n), - uses: [tools.anodes[n], tools.field], // pnode extension + data: base(params, tools.anodes[n], tools.field, n){dft:wc.tn(tools.dft)}, + uses: [tools.anodes[n], tools.field, tools.dft], } for n in std.range(0, std.length(tools.anodes) - 1)]; local nf_maker = import 'pgrapher/experiment/iceberg/nf.jsonnet'; diff --git a/cfg/pgrapher/experiment/iceberg/wcls-sp.jsonnet b/cfg/pgrapher/experiment/iceberg/wcls-sp.jsonnet index 31de7f7a7..c1ba88b98 100644 --- a/cfg/pgrapher/experiment/iceberg/wcls-sp.jsonnet +++ b/cfg/pgrapher/experiment/iceberg/wcls-sp.jsonnet @@ -126,8 +126,8 @@ local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, // data: perfect(params, tools.anodes[n], tools.field, n), - data: base(params, tools.anodes[n], tools.field, n), - uses: [tools.anodes[n], tools.field], // pnode extension + data: base(params, tools.anodes[n], tools.field, n){dft:wc.tn(tools.dft)}, + uses: [tools.anodes[n], tools.field, tools.dft], } for n in std.range(0, std.length(tools.anodes) - 1)]; // an empty omnibus noise filter diff --git a/cfg/pgrapher/experiment/pdsp/chndb.jsonnet b/cfg/pgrapher/experiment/pdsp/chndb.jsonnet index f8cafaa6b..5c7ce3a95 100644 --- a/cfg/pgrapher/experiment/pdsp/chndb.jsonnet +++ b/cfg/pgrapher/experiment/pdsp/chndb.jsonnet @@ -8,8 +8,8 @@ function(params, tools) { perfect(anode) :: { type:'OmniChannelNoiseDB', name: 'ocndbperfect-' + anode.name, - data: base(params, anode, tools.field, anode.data.ident), - uses: [anode, tools.field], + data: base(params, anode, tools.field, anode.data.ident){dft:wc.tn(tools.dft)}, + uses: [anode, tools.field, tools.dft], }, } diff --git a/cfg/pgrapher/experiment/pdsp/nf.jsonnet b/cfg/pgrapher/experiment/pdsp/nf.jsonnet index c9945ed2d..aff05f7d7 100644 --- a/cfg/pgrapher/experiment/pdsp/nf.jsonnet +++ b/cfg/pgrapher/experiment/pdsp/nf.jsonnet @@ -4,87 +4,95 @@ local g = import 'pgraph.jsonnet'; local wc = import 'wirecell.jsonnet'; local gainmap = import 'pgrapher/experiment/pdsp/chndb-rel-gain.jsonnet'; -function(params, anode, chndbobj, n, name='') - { +local default_dft = { type: 'FftwDFT' }; + +function(params, anode, chndbobj, n, name='', dft=default_dft) { local single = { - type: 'pdOneChannelNoise', - name: name, - data: { - noisedb: wc.tn(chndbobj), - anode: wc.tn(anode), - resmp: [ - {channels: std.range(2128, 2175), sample_from: 5996}, - {channels: std.range(1520, 1559), sample_from: 5996}, - {channels: std.range( 440, 479), sample_from: 5996}, - ], - }, + type: 'pdOneChannelNoise', + name: name, + uses: [dft, chndbobj, anode], + data: { + noisedb: wc.tn(chndbobj), + anode: wc.tn(anode), + dft: wc.tn(dft), + resmp: [ + {channels: std.range(2128, 2175), sample_from: 5996}, + {channels: std.range(1520, 1559), sample_from: 5996}, + {channels: std.range( 440, 479), sample_from: 5996}, + ], + }, }, local grouped = { - type: 'mbCoherentNoiseSub', - name: name, - data: { - noisedb: wc.tn(chndbobj), - anode: wc.tn(anode), - rms_threshold: 0.0, - }, + type: 'mbCoherentNoiseSub', + name: name, + uses: [dft, chndbobj, anode], + data: { + noisedb: wc.tn(chndbobj), + anode: wc.tn(anode), + dft: wc.tn(dft), + rms_threshold: 0.0, + }, }, local sticky = { - type: 'pdStickyCodeMitig', - name: name, - data: { - extra_stky: [ - {channels: std.range(n * 2560, (n + 1) * 2560 - 1), bits: [0,1,63]}, - {channels: [4], bits: [6] }, - {channels: [159], bits: [6] }, - {channels: [164], bits: [36] }, - {channels: [168], bits: [7] }, - {channels: [323], bits: [24] }, - {channels: [451], bits: [25] }, - ], - noisedb: wc.tn(chndbobj), - anode: wc.tn(anode), - stky_sig_like_val: 15.0, - stky_sig_like_rms: 2.0, - stky_max_len: 10, - }, + type: 'pdStickyCodeMitig', + name: name, + uses: [dft, chndbobj, anode], + data: { + extra_stky: [ + {channels: std.range(n * 2560, (n + 1) * 2560 - 1), bits: [0,1,63]}, + {channels: [4], bits: [6] }, + {channels: [159], bits: [6] }, + {channels: [164], bits: [36] }, + {channels: [168], bits: [7] }, + {channels: [323], bits: [24] }, + {channels: [451], bits: [25] }, + ], + noisedb: wc.tn(chndbobj), + anode: wc.tn(anode), + dft: wc.tn(dft), + stky_sig_like_val: 15.0, + stky_sig_like_rms: 2.0, + stky_max_len: 10, + }, }, local gaincalib = { - type: 'pdRelGainCalib', - name: name, - data: { - noisedb: wc.tn(chndbobj), - anode: wc.tn(anode), - rel_gain: gainmap.rel_gain, - }, + type: 'pdRelGainCalib', + name: name, + uses: [chndbobj, anode], + data: { + noisedb: wc.tn(chndbobj), + anode: wc.tn(anode), + rel_gain: gainmap.rel_gain, + }, }, local obnf = g.pnode({ - type: 'OmnibusNoiseFilter', - name: name, - data: { + type: 'OmnibusNoiseFilter', + name: name, + data: { - // Nonzero forces the number of ticks in the waveform - nticks: 0, + // Nonzero forces the number of ticks in the waveform + nticks: 0, - // channel bin ranges are ignored - // only when the channelmask is merged to `bad` - maskmap: {sticky: "bad", ledge: "bad", noisy: "bad"}, - channel_filters: [ - // wc.tn(sticky), - wc.tn(single), - // wc.tn(gaincalib), - ], - grouped_filters: [ - // wc.tn(grouped), - ], - channel_status_filters: [ - ], - noisedb: wc.tn(chndbobj), - intraces: 'orig%d' % n, // frame tag get all traces - outtraces: 'raw%d' % n, - }, + // channel bin ranges are ignored + // only when the channelmask is merged to `bad` + maskmap: {sticky: "bad", ledge: "bad", noisy: "bad"}, + channel_filters: [ + // wc.tn(sticky), + wc.tn(single), + // wc.tn(gaincalib), + ], + grouped_filters: [ + // wc.tn(grouped), + ], + channel_status_filters: [ + ], + noisedb: wc.tn(chndbobj), + intraces: 'orig%d' % n, // frame tag get all traces + outtraces: 'raw%d' % n, + }, }, uses=[chndbobj, anode, sticky, single, grouped, gaincalib], nin=1, nout=1), pipe: g.pipeline([obnf], name=name), - }.pipe +}.pipe diff --git a/cfg/pgrapher/experiment/pdsp/ocndb-perfect.jsonnet b/cfg/pgrapher/experiment/pdsp/ocndb-perfect.jsonnet index 05fe99752..7a048d84a 100644 --- a/cfg/pgrapher/experiment/pdsp/ocndb-perfect.jsonnet +++ b/cfg/pgrapher/experiment/pdsp/ocndb-perfect.jsonnet @@ -3,18 +3,21 @@ local wc = import "wirecell.jsonnet"; +local default_dft = { type: 'FftwDFT' }; + // The "perfect noise" database is one that is free of any // "special" considerations such as per channel variability. The // "official" perfect chndb depends on the official "chndb-base" // and that seems to be adulterated with specific settings. We // try to start fresh here. -function(anode, fr, nsamples, tick=0.5*wc.us) { +function(anode, fr, nsamples, tick=0.5*wc.us, dft=default_dft) { local apaid = anode.data.ident, type:'OmniChannelNoiseDB', name: std.toString(apaid), - uses: [anode, fr], + uses: [anode, fr, dft], data: { anode: wc.tn(anode), + dft: wc.tn(dft), field_response: wc.tn(fr), tick: tick, nsamples: nsamples, diff --git a/cfg/pgrapher/experiment/pdsp/sim.jsonnet b/cfg/pgrapher/experiment/pdsp/sim.jsonnet index d03dc5391..7ac5213ef 100644 --- a/cfg/pgrapher/experiment/pdsp/sim.jsonnet +++ b/cfg/pgrapher/experiment/pdsp/sim.jsonnet @@ -46,13 +46,14 @@ function(params, tools) { name: "empericalnoise-" + anode.name, data: { anode: wc.tn(anode), + dft: wc.tn(tools.dft), chanstat: if std.type(csdb) == "null" then "" else wc.tn(csdb), spectra_file: params.files.noise, nsamples: params.daq.nticks, period: params.daq.tick, wire_length_scale: 1.0*wc.cm, // optimization binning }, - uses: [anode] + if std.type(csdb) == "null" then [] else [csdb], + uses: [anode, tools.dft] + if std.type(csdb) == "null" then [] else [csdb], }, local noise_models = [make_noise_model(anode) for anode in tools.anodes], @@ -62,10 +63,11 @@ function(params, tools) { name: "addnoise-" + model.name, data: { rng: wc.tn(tools.random), + dft: wc.tn(tools.dft), model: wc.tn(model), nsamples: params.daq.nticks, replacement_percentage: 0.02, // random optimization - }}, nin=1, nout=1, uses=[model]), + }}, nin=1, nout=1, uses=[tools.random, tools.dft, model]), local noises = [add_noise(model) for model in noise_models], diff --git a/cfg/pgrapher/experiment/pdsp/sp.jsonnet b/cfg/pgrapher/experiment/pdsp/sp.jsonnet index 2c9a30339..37559412f 100644 --- a/cfg/pgrapher/experiment/pdsp/sp.jsonnet +++ b/cfg/pgrapher/experiment/pdsp/sp.jsonnet @@ -47,6 +47,7 @@ function(params, tools, override = {}) { * Associated tuning in sp-filters.jsonnet */ anode: wc.tn(anode), + dft: wc.tn(tools.dft), field_response: wc.tn(tools.field), elecresponse: wc.tn(tools.elec_resp), ftoffset: 0.0, // default 0.0 @@ -95,6 +96,6 @@ function(params, tools, override = {}) { // process_planes: [0, 2], } + override, - }, nin=1, nout=1, uses=[anode, tools.field, tools.elec_resp] + pc.uses + spfilt), + }, nin=1, nout=1, uses=[anode, tools.dft, tools.field, tools.elec_resp] + pc.uses + spfilt), } diff --git a/cfg/pgrapher/experiment/pdsp/wcls-nf-sp.jsonnet b/cfg/pgrapher/experiment/pdsp/wcls-nf-sp.jsonnet index a0759b370..c427d5899 100644 --- a/cfg/pgrapher/experiment/pdsp/wcls-nf-sp.jsonnet +++ b/cfg/pgrapher/experiment/pdsp/wcls-nf-sp.jsonnet @@ -123,8 +123,8 @@ local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, // data: perfect(params, tools.anodes[n], tools.field, n), - data: base(params, tools.anodes[n], tools.field, n), - uses: [tools.anodes[n], tools.field], // pnode extension + data: base(params, tools.anodes[n], tools.field, n){dft:wc.tn(tools.dft)}, + uses: [tools.anodes[n], tools.field, tools.dft], } for n in std.range(0, std.length(tools.anodes) - 1)]; local nf_maker = import 'pgrapher/experiment/pdsp/nf.jsonnet'; diff --git a/cfg/pgrapher/experiment/pdsp/wcls-raw-to-sig.jsonnet b/cfg/pgrapher/experiment/pdsp/wcls-raw-to-sig.jsonnet index 037db3ca4..1640a131a 100644 --- a/cfg/pgrapher/experiment/pdsp/wcls-raw-to-sig.jsonnet +++ b/cfg/pgrapher/experiment/pdsp/wcls-raw-to-sig.jsonnet @@ -88,8 +88,8 @@ local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, // data: perfect(params, tools.anodes[n], tools.field, n), - data: base(params, tools.anodes[n], tools.field, n), - uses: [tools.anodes[n], tools.field], // pnode extension + data: base(params, tools.anodes[n], tools.field, n){dft:wc.tn(tools.dft)}, + uses: [tools.anodes[n], tools.field, tools.dft], } for n in anode_iota]; local nf_maker = import 'pgrapher/experiment/pdsp/nf.jsonnet'; diff --git a/cfg/pgrapher/experiment/pdsp/wcls-sim-drift-simchannel.jsonnet b/cfg/pgrapher/experiment/pdsp/wcls-sim-drift-simchannel.jsonnet index 9d4e3b714..ed4b755dc 100644 --- a/cfg/pgrapher/experiment/pdsp/wcls-sim-drift-simchannel.jsonnet +++ b/cfg/pgrapher/experiment/pdsp/wcls-sim-drift-simchannel.jsonnet @@ -105,8 +105,8 @@ local perfect = import 'pgrapher/experiment/pdsp/chndb-perfect.jsonnet'; local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, - data: perfect(params, tools.anodes[n], tools.field, n), - uses: [tools.anodes[n], tools.field], // pnode extension + data: perfect(params, tools.anodes[n], tools.field, n){dft:wc.tn(tools.dft)}, + uses: [tools.anodes[n], tools.field, tools.dft], } for n in anode_iota]; //local chndb_maker = import 'pgrapher/experiment/pdsp/chndb.jsonnet'; diff --git a/cfg/pgrapher/experiment/pdsp/wcls-sp.jsonnet b/cfg/pgrapher/experiment/pdsp/wcls-sp.jsonnet index 328504c6e..c54b44405 100644 --- a/cfg/pgrapher/experiment/pdsp/wcls-sp.jsonnet +++ b/cfg/pgrapher/experiment/pdsp/wcls-sp.jsonnet @@ -129,8 +129,8 @@ local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, // data: perfect(params, tools.anodes[n], tools.field, n), - data: base(params, tools.anodes[n], tools.field, n), - uses: [tools.anodes[n], tools.field], // pnode extension + data: base(params, tools.anodes[n], tools.field, n){dft:wc.tn(tools.dft)}, + uses: [tools.anodes[n], tools.field, tools.dft], } for n in std.range(0, std.length(tools.anodes) - 1)]; // local nf_maker = import 'pgrapher/experiment/pdsp/nf.jsonnet'; diff --git a/cfg/pgrapher/experiment/pdsp/wct-sim-check.jsonnet b/cfg/pgrapher/experiment/pdsp/wct-sim-check.jsonnet index ddef2ccb2..9872f4cd6 100644 --- a/cfg/pgrapher/experiment/pdsp/wct-sim-check.jsonnet +++ b/cfg/pgrapher/experiment/pdsp/wct-sim-check.jsonnet @@ -99,8 +99,8 @@ local perfect = import 'pgrapher/experiment/pdsp/chndb-base.jsonnet'; local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, - data: perfect(params, tools.anodes[n], tools.field, n), - uses: [tools.anodes[n], tools.field], // pnode extension + data: perfect(params, tools.anodes[n], tools.field, n){dft:wc.tn(tools.dft)}, + uses: [tools.anodes[n], tools.field, tools.dft], } for n in std.range(0, std.length(tools.anodes) - 1)]; //local chndb_maker = import 'pgrapher/experiment/pdsp/chndb.jsonnet'; diff --git a/cfg/pgrapher/experiment/sbnd/chndb.jsonnet b/cfg/pgrapher/experiment/sbnd/chndb.jsonnet index f8cafaa6b..5c7ce3a95 100644 --- a/cfg/pgrapher/experiment/sbnd/chndb.jsonnet +++ b/cfg/pgrapher/experiment/sbnd/chndb.jsonnet @@ -8,8 +8,8 @@ function(params, tools) { perfect(anode) :: { type:'OmniChannelNoiseDB', name: 'ocndbperfect-' + anode.name, - data: base(params, anode, tools.field, anode.data.ident), - uses: [anode, tools.field], + data: base(params, anode, tools.field, anode.data.ident){dft:wc.tn(tools.dft)}, + uses: [anode, tools.field, tools.dft], }, } diff --git a/cfg/pgrapher/experiment/sbnd/nf.jsonnet b/cfg/pgrapher/experiment/sbnd/nf.jsonnet index 0b8d95072..af61fc6b7 100644 --- a/cfg/pgrapher/experiment/sbnd/nf.jsonnet +++ b/cfg/pgrapher/experiment/sbnd/nf.jsonnet @@ -4,14 +4,18 @@ local g = import 'pgraph.jsonnet'; local wc = import 'wirecell.jsonnet'; local gainmap = import 'pgrapher/experiment/sbnd/chndb-rel-gain.jsonnet'; -function(params, anode, chndbobj, n, name='') +local default_dft = { type: 'FftwDFT' }; + +function(params, anode, chndbobj, n, name='', dft=default_dft) { local single = { type: 'pdOneChannelNoise', name: name, + uses: [dft, chndbobj, anode], data: { noisedb: wc.tn(chndbobj), anode: wc.tn(anode), + dft: wc.tn(dft), resmp: [ {channels: std.range(2128, 2175), sample_from: 5996}, {channels: std.range(1520, 1559), sample_from: 5996}, @@ -22,15 +26,18 @@ function(params, anode, chndbobj, n, name='') local grouped = { type: 'mbCoherentNoiseSub', name: name, + uses: [dft, chndbobj, anode], data: { noisedb: wc.tn(chndbobj), anode: wc.tn(anode), + dft: wc.tn(dft), rms_threshold: 0.0, }, }, local sticky = { type: 'pdStickyCodeMitig', name: name, + uses: [dft, chndbobj, anode], data: { extra_stky: [ {channels: std.range(n * 2560, (n + 1) * 2560 - 1), bits: [0,1,63]}, @@ -43,6 +50,7 @@ function(params, anode, chndbobj, n, name='') ], noisedb: wc.tn(chndbobj), anode: wc.tn(anode), + dft: wc.tn(dft), stky_sig_like_val: 15.0, stky_sig_like_rms: 2.0, stky_max_len: 10, diff --git a/cfg/pgrapher/experiment/sbnd/sim.jsonnet b/cfg/pgrapher/experiment/sbnd/sim.jsonnet index d03dc5391..7ac5213ef 100644 --- a/cfg/pgrapher/experiment/sbnd/sim.jsonnet +++ b/cfg/pgrapher/experiment/sbnd/sim.jsonnet @@ -46,13 +46,14 @@ function(params, tools) { name: "empericalnoise-" + anode.name, data: { anode: wc.tn(anode), + dft: wc.tn(tools.dft), chanstat: if std.type(csdb) == "null" then "" else wc.tn(csdb), spectra_file: params.files.noise, nsamples: params.daq.nticks, period: params.daq.tick, wire_length_scale: 1.0*wc.cm, // optimization binning }, - uses: [anode] + if std.type(csdb) == "null" then [] else [csdb], + uses: [anode, tools.dft] + if std.type(csdb) == "null" then [] else [csdb], }, local noise_models = [make_noise_model(anode) for anode in tools.anodes], @@ -62,10 +63,11 @@ function(params, tools) { name: "addnoise-" + model.name, data: { rng: wc.tn(tools.random), + dft: wc.tn(tools.dft), model: wc.tn(model), nsamples: params.daq.nticks, replacement_percentage: 0.02, // random optimization - }}, nin=1, nout=1, uses=[model]), + }}, nin=1, nout=1, uses=[tools.random, tools.dft, model]), local noises = [add_noise(model) for model in noise_models], diff --git a/cfg/pgrapher/experiment/sbnd/sp.jsonnet b/cfg/pgrapher/experiment/sbnd/sp.jsonnet index fd767ced4..f6bfe4fe7 100644 --- a/cfg/pgrapher/experiment/sbnd/sp.jsonnet +++ b/cfg/pgrapher/experiment/sbnd/sp.jsonnet @@ -47,6 +47,7 @@ function(params, tools, override = {}) { * Associated tuning in sp-filters.jsonnet */ anode: wc.tn(anode), + dft: wc.tn(tools.dft), field_response: wc.tn(tools.field), elecresponse: wc.tn(tools.elec_resp), ftoffset: 0.0, // default 0.0 @@ -95,6 +96,6 @@ function(params, tools, override = {}) { // process_planes: [0, 2], } + override, - }, nin=1, nout=1, uses=[anode, tools.field, tools.elec_resp] + pc.uses + spfilt), + }, nin=1, nout=1, uses=[anode, tools.dft, tools.field, tools.elec_resp] + pc.uses + spfilt), } diff --git a/cfg/pgrapher/experiment/sbnd/wcls-nf-sp.jsonnet b/cfg/pgrapher/experiment/sbnd/wcls-nf-sp.jsonnet index 9b5dbcde8..fc79aa23a 100644 --- a/cfg/pgrapher/experiment/sbnd/wcls-nf-sp.jsonnet +++ b/cfg/pgrapher/experiment/sbnd/wcls-nf-sp.jsonnet @@ -123,8 +123,8 @@ local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, // data: perfect(params, tools.anodes[n], tools.field, n), - data: base(params, tools.anodes[n], tools.field, n), - uses: [tools.anodes[n], tools.field], // pnode extension + data: base(params, tools.anodes[n], tools.field, n){dft:wc.tn(tools.dft)}, + uses: [tools.anodes[n], tools.field, tools.dft], } for n in std.range(0, std.length(tools.anodes) - 1)]; // local nf_maker = import 'pgrapher/experiment/pdsp/nf.jsonnet'; diff --git a/cfg/pgrapher/experiment/sbnd/wcls-sim-drift-simchannel.jsonnet b/cfg/pgrapher/experiment/sbnd/wcls-sim-drift-simchannel.jsonnet index f9754299c..8d9aec3c6 100644 --- a/cfg/pgrapher/experiment/sbnd/wcls-sim-drift-simchannel.jsonnet +++ b/cfg/pgrapher/experiment/sbnd/wcls-sim-drift-simchannel.jsonnet @@ -100,8 +100,8 @@ local perfect = import 'pgrapher/experiment/sbnd/chndb-perfect.jsonnet'; local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, - data: perfect(params, tools.anodes[n], tools.field, n), - uses: [tools.anodes[n], tools.field], // pnode extension + data: perfect(params, tools.anodes[n], tools.field, n){wc.tn(tools.dft)}, + uses: [tools.anodes[n], tools.field, tools.dft], } for n in anode_iota]; //local chndb_maker = import 'pgrapher/experiment/sbnd/chndb.jsonnet'; diff --git a/cfg/pgrapher/experiment/sbnd/wct-sim-check.jsonnet b/cfg/pgrapher/experiment/sbnd/wct-sim-check.jsonnet index ca40e10b1..24216ec06 100644 --- a/cfg/pgrapher/experiment/sbnd/wct-sim-check.jsonnet +++ b/cfg/pgrapher/experiment/sbnd/wct-sim-check.jsonnet @@ -58,8 +58,8 @@ local perfect = import 'pgrapher/experiment/sbnd/chndb-base.jsonnet'; local chndb = [{ type: 'OmniChannelNoiseDB', name: 'ocndbperfect%d' % n, - data: perfect(params, tools.anodes[n], tools.field, n), - uses: [tools.anodes[n], tools.field], // pnode extension + data: perfect(params, tools.anodes[n], tools.field, n){dft:wc.tn(tools.dft)}, + uses: [tools.anodes[n], tools.field, tools.dft], } for n in std.range(0, std.length(tools.anodes) - 1)]; //local chndb_maker = import 'pgrapher/experiment/sbnd/chndb.jsonnet'; diff --git a/cfg/pgrapher/experiment/uboone/chndb.jsonnet b/cfg/pgrapher/experiment/uboone/chndb.jsonnet index 209559aab..ddbc73640 100644 --- a/cfg/pgrapher/experiment/uboone/chndb.jsonnet +++ b/cfg/pgrapher/experiment/uboone/chndb.jsonnet @@ -11,11 +11,11 @@ function(params, tools) wct: function(epoch="before") { type: "OmniChannelNoiseDB", name: "ocndb%s"%epoch, - data : + data : {dft: wc.tn(tools.dft)} if epoch == "perfect" then perfect(params, tools.anode, tools.field) else base(params, tools.anode, tools.field, rms_cuts[epoch]), - uses: [tools.anode, tools.field], // pnode extension + uses: [tools.anode, tools.field, tools.dft], }, wcls: function(epoch="before") { diff --git a/cfg/pgrapher/experiment/uboone/nf.jsonnet b/cfg/pgrapher/experiment/uboone/nf.jsonnet index a030091a3..081f9e7c5 100644 --- a/cfg/pgrapher/experiment/uboone/nf.jsonnet +++ b/cfg/pgrapher/experiment/uboone/nf.jsonnet @@ -23,24 +23,30 @@ function(params, tools, chndbobj, name="") Window: 5, Nbins: 250, Cut: 14, - anode: wc.tn(tools.anode) - }, + anode: wc.tn(tools.anode), + dft: wc.tn(tools.dft), + }, + uses: [tools.anode, tools.dft], }, local single = { type: "mbOneChannelNoise", name:name, data: { noisedb: wc.tn(chndbobj), - anode: wc.tn(tools.anode) - } + anode: wc.tn(tools.anode), + dft: wc.tn(tools.dft), + }, + uses: [tools.anode, tools.dft, chndbobj], }, local grouped = { type: "mbCoherentNoiseSub", name:name, data: { noisedb: wc.tn(chndbobj), - anode: wc.tn(tools.anode) - } + anode: wc.tn(tools.anode), + dft: wc.tn(tools.dft), + }, + uses: [tools.anode, tools.dft, chndbobj], }, local obnf = g.pnode({ diff --git a/cfg/pgrapher/experiment/uboone/nodes.jsonnet b/cfg/pgrapher/experiment/uboone/nodes.jsonnet index 4386e02f3..83de781d7 100644 --- a/cfg/pgrapher/experiment/uboone/nodes.jsonnet +++ b/cfg/pgrapher/experiment/uboone/nodes.jsonnet @@ -98,18 +98,19 @@ local g = import "pgraph.jsonnet"; // Make a noise model bound to an anode and a channel status - local make_noise_model = function(anode, csdb) { + local make_noise_model = function(anode, csdb, dft={type:"FftwDFT"}) { type: "EmpiricalNoiseModel", name: "empericalnoise%s"% csdb.name, data: { anode: wc.tn(anode), + dft: wc.tn(dft), chanstat: wc.tn(csdb), spectra_file: params.files.noise, nsamples: params.daq.nticks, period: params.daq.tick, wire_length_scale: 1.0*wc.cm, // optimization binning }, - uses: [anode, csdb], + uses: [anode, dft, csdb], }, @@ -118,7 +119,8 @@ local g = import "pgraph.jsonnet"; type: "NoiseSource", name: "%s%s"%[anode.name, model.name], data: params.daq { - rng: wc.tn(tools.random), + rng: wc.tn(tools.random), // this is going to fail, is this file even used? + dft: wc.tn(dft), model: wc.tn(model), anode: wc.tn(anode), @@ -127,7 +129,7 @@ local g = import "pgraph.jsonnet"; readout_time: params.daq.readout_time, sample_period: params.daq.tick, first_frame_number: params.daq.first_frame_number, - }}, nin=0, nout=1, uses=[anode, model]), + }}, nin=0, nout=1, uses=[anode, model, dft]), local noise_summer = g.pnode({ diff --git a/cfg/pgrapher/experiment/uboone/sim.jsonnet b/cfg/pgrapher/experiment/uboone/sim.jsonnet index 45d9b15c4..1a46178c8 100644 --- a/cfg/pgrapher/experiment/uboone/sim.jsonnet +++ b/cfg/pgrapher/experiment/uboone/sim.jsonnet @@ -82,13 +82,14 @@ function(params, tools) name: "empericalnoise%s"% csdb.name, data: { anode: wc.tn(anode), + dft: wc.tn(tools.dft), chanstat: wc.tn(csdb), spectra_file: params.files.noise, nsamples: params.daq.nticks, period: params.daq.tick, wire_length_scale: 1.0*wc.cm, // optimization binning }, - uses: [anode, csdb], + uses: [anode, csdb, tools.dft], }, @@ -98,10 +99,11 @@ function(params, tools) name: "addnoise%s"%[model.name], data: { rng: wc.tn(tools.random), + dft: wc.tn(tools.dft), model: wc.tn(model), nsamples: params.daq.nticks, replacement_percentage: 0.02, // random optimization - }}, nin=1, nout=1, uses=[model]), + }}, nin=1, nout=1, uses=[tools.random, tools.dft, model]), ret: { signal : signal, diff --git a/cfg/pgrapher/experiment/uboone/sp.jsonnet b/cfg/pgrapher/experiment/uboone/sp.jsonnet index 03a4ac0eb..697289547 100644 --- a/cfg/pgrapher/experiment/uboone/sp.jsonnet +++ b/cfg/pgrapher/experiment/uboone/sp.jsonnet @@ -15,6 +15,7 @@ function(params, tools) { // codes a slew of SP filter component names which MUST // correctly match what is provided in sp-filters.jsonnet. anode: wc.tn(tools.anode), + dft: wc.tn(tools.dft), field_response: wc.tn(tools.field), elecresponse: wc.tn(tools.elec_resp), postgain: 1, // default 1.2 @@ -22,11 +23,13 @@ function(params, tools) { per_chan_resp: wc.tn(tools.perchanresp), fft_flag: 0, // 1 is faster but higher memory, 0 is slightly slower but lower memory } - }, nin=1,nout=1, uses=[tools.anode, tools.field, tools.elec_resp, tools.perchanresp] + import "sp-filters.jsonnet"), -local sigproc_uniform = g.pnode({ + }, nin=1,nout=1, uses=[tools.anode, tools.dft, tools.field, tools.elec_resp, tools.perchanresp] + import "sp-filters.jsonnet"), + + local sigproc_uniform = g.pnode({ type: "OmnibusSigProc", data: { anode: wc.tn(tools.anode), + dft: wc.tn(tools.dft), field_response: wc.tn(tools.field), elecresponse: wc.tn(tools.elec_resp), postgain: 1, // default 1.2 @@ -37,9 +40,10 @@ local sigproc_uniform = g.pnode({ // r_fake_signal_low_th: 300, // r_fake_signal_high_th: 600, } - }, nin=1,nout=1,uses=[tools.anode, tools.field, tools.elec_resp] + import "sp-filters.jsonnet"), -// ch-by-ch response correction in SP turn off by setting null input -local sigproc = if std.type(params.files.chresp)=='null' + }, nin=1,nout=1,uses=[tools.anode, tools.dft, tools.field, tools.elec_resp] + import "sp-filters.jsonnet"), + + // ch-by-ch response correction in SP turn off by setting null input + local sigproc = if std.type(params.files.chresp)=='null' then sigproc_uniform else sigproc_perchan, @@ -72,6 +76,7 @@ local sigproc = if std.type(params.files.chresp)=='null' local l1spfilter = g.pnode({ type: "L1SPFilter", data: { + dft: wc.tn(tools.dft), fields: wc.tn(tools.field), filter: [0.000305453, 0.000978027, 0.00277049, 0.00694322, 0.0153945, 0.0301973, 0.0524048, 0.0804588, 0.109289, 0.131334, 0.139629, @@ -103,7 +108,7 @@ local sigproc = if std.type(params.files.chresp)=='null' sigtag: "gauss", // trace tag of input signal outtag: "l1sp", // trace tag for output signal } - }, nin=1, nout=1, uses=[tools.field]), + }, nin=1, nout=1, uses=[tools.dft, tools.field]), // merge the split output from NF ("raw" tag) and just the "gauss" // from normal SP for input to L1SP diff --git a/cfg/test/test-pdsp-sim-sp-dnnroi.jsonnet b/cfg/test/test-pdsp-sim-sp-dnnroi.jsonnet index 631dcd2e7..427fa7419 100644 --- a/cfg/test/test-pdsp-sim-sp-dnnroi.jsonnet +++ b/cfg/test/test-pdsp-sim-sp-dnnroi.jsonnet @@ -11,6 +11,10 @@ local hs = import "pgrapher/common/helpers.jsonnet"; local wires = hs.aux.wires(params.files.wires); local anodes = hs.aux.anodes(wires, params.det.volumes); +// IDFT +//local dft = {type: 'FftwDFT'}; +local dft = {type: 'TorchDFT', data: { device: 'cpu' }}; + // simulation // kinematics: ideal line source @@ -35,7 +39,7 @@ local er = hs.aux.cer(params.elec.shaping, params.elec.gain, params.elec.postgain, params.daq.nticks, params.daq.tick); local rc = hs.aux.rc(1.0*wc.ms, params.daq.nticks, params.daq.tick); -local pirs = hs.gen.pirs(sim_fr, [er], [rc]); +local pirs = hs.gen.pirs(sim_fr, [er], [rc], dft=dft); // sp fr may differ from sim fr (as it does from real fr) local sp_fr = hs.aux.fr(if std.length(params.files.fields)>1 @@ -45,7 +49,7 @@ local sp_fr = hs.aux.fr(if std.length(params.files.fields)>1 local sp_filters = import "pgrapher/experiment/pdsp/sp-filters.jsonnet"; local adcpermv = hs.utils.adcpermv(params.adc); local chndbf = import "pgrapher/experiment/pdsp/ocndb-perfect.jsonnet"; -local chndb(anode) = chndbf(anode, sp_fr, params.nf.nsamples); +local chndb(anode) = chndbf(anode, sp_fr, params.nf.nsamples, dft=dft); local dnnroi_override = { sparse: true, use_roi_debug_mode: true, @@ -79,15 +83,14 @@ local out(anode, prefix, tag_pats, digitize=false, cap=false) = local anode_pipeline(anode, prefix) = pg.pipeline([ // sim - hs.gen.signal(anode, pirs, params.daq, params.lar, rnd=random), - hs.gen.noise(anode, params.files.noise, params.daq, rnd=random), + hs.gen.signal(anode, pirs, params.daq, params.lar, rnd=random, dft=dft), + hs.gen.noise(anode, params.files.noise, params.daq, rnd=random, dft=dft), hs.gen.digi(anode, params.adc), out(anode, prefix, ["orig"], true), // nf+sp - hs.nf(anode, sp_fr, chndb(anode), params.nf.nsamples, params.daq.tick), - hs.sp(anode, sp_fr, er, sp_filters, adcpermv, - override=dnnroi_override), + hs.nf(anode, sp_fr, chndb(anode), params.nf.nsamples, params.daq.tick, dft=dft), + hs.sp(anode, sp_fr, er, sp_filters, adcpermv, override=dnnroi_override, dft=dft), out(anode, prefix, ["wiener","gauss"]), // // dnnroi diff --git a/cfg/test/test-pdsp-sim-sp.jsonnet b/cfg/test/test-pdsp-sim-sp.jsonnet new file mode 100644 index 000000000..9dbb19bf6 --- /dev/null +++ b/cfg/test/test-pdsp-sim-sp.jsonnet @@ -0,0 +1,84 @@ +// This provides a main wire-cell config file to exercise +// sim+sigproc (no dnnroi). When run it will produce tar files of frames +// data as numpy arrays. Ionization pattern is from ideal line +// source. + +local wc = import "wirecell.jsonnet"; +local pg = import "pgraph.jsonnet"; +local params = import "pgrapher/experiment/pdsp/simparams.jsonnet"; +local hs = import "pgrapher/common/helpers.jsonnet"; + +local wires = hs.aux.wires(params.files.wires); +local anodes = hs.aux.anodes(wires, params.det.volumes); + +// IDFT +local dft = {type: 'FftwDFT'}; + +// simulation + +// kinematics: ideal line source +local tracklist = [ + { + time: 0, + charge: -5000, + ray: params.det.bounds, + }, +]; +local depos = pg.pipeline([ + hs.gen.track_depos(tracklist), + hs.gen.bagger(params.daq), +]); + +local random = hs.gen.random(); +local drifter = hs.gen.drifter(params.det.volumes,params.lar,random); + +// responses +local sim_fr = hs.aux.fr(params.files.fields[0]); +local er = hs.aux.cer(params.elec.shaping, params.elec.gain, + params.elec.postgain, + params.daq.nticks, params.daq.tick); +local rc = hs.aux.rc(1.0*wc.ms, params.daq.nticks, params.daq.tick); +local pirs = hs.gen.pirs(sim_fr, [er], [rc]); + +// sp fr may differ from sim fr (as it does from real fr) +local sp_fr = hs.aux.fr(if std.length(params.files.fields)>1 + then params.files.fields[1] + else params.files.fields[0]); + +local sp_filters = import "pgrapher/experiment/pdsp/sp-filters.jsonnet"; +local adcpermv = hs.utils.adcpermv(params.adc); +local chndbf = import "pgrapher/experiment/pdsp/ocndb-perfect.jsonnet"; +local chndb(anode) = chndbf(anode, sp_fr, params.nf.nsamples); + +// little function to return a frame file tap or sink (if cap is +// true). This bakes in PDSP-specific array bounds! +local out(anode, prefix, tag_pats, digitize=false, cap=false) = + local tags = [tp + std.toString(anode.data.ident) + for tp in tag_pats]; + local fname = prefix + "-" + + std.join("-", tags) + ".tar.bz2"; + local dense = hs.io.frame_bounds(2560, 6000, + 2560 * anode.data.ident); + if cap + then hs.io.frame_file_sink(fname, tags, digitize, dense) + else hs.io.frame_file_tap(fname, tags, digitize, dense); + + +local anode_pipeline(anode, prefix) = pg.pipeline([ + // sim + hs.gen.signal(anode, pirs, params.daq, params.lar, rnd=random), + hs.gen.noise(anode, params.files.noise, params.daq, rnd=random), + hs.gen.digi(anode, params.adc), + out(anode, prefix, ["orig"], true), + + // nf+sp + hs.nf(anode, sp_fr, chndb(anode), params.nf.nsamples, params.daq.tick), + hs.sp(anode, sp_fr, er, sp_filters, adcpermv), + out(anode, prefix, ["wiener","gauss"], cap=true), +]); + +function(prefix="test-pdsp-ssd") + local pipes = [ anode_pipeline(a, prefix) for a in anodes]; + local body = pg.fan.fanout('DepoSetFanout', pipes); + local graph = pg.pipeline([depos, drifter, body]); + hs.utils.main(graph, 'TbbFlow', ['WireCellPytorch']) diff --git a/cfg/test/test-pdsp-sim.jsonnet b/cfg/test/test-pdsp-sim.jsonnet new file mode 100644 index 000000000..d28b355b3 --- /dev/null +++ b/cfg/test/test-pdsp-sim.jsonnet @@ -0,0 +1,108 @@ +// This provides a main wire-cell config file to exercise +// sim+sigproc+dnnroi. When run it will produce tar files of frames +// data as numpy arrays. Ionization pattern is from ideal line +// source. + +local wc = import "wirecell.jsonnet"; +local pg = import "pgraph.jsonnet"; +local params = import "pgrapher/experiment/pdsp/simparams.jsonnet"; +local hs = import "pgrapher/common/helpers.jsonnet"; + +local wires = hs.aux.wires(params.files.wires); +local anodes = hs.aux.anodes(wires, params.det.volumes); + +// IDFT +local dft = {type: 'FftwDFT'}; + +// simulation + +// kinematics: ideal line source +local tracklist = [ + { + time: 0, + charge: -5000, + ray: params.det.bounds, + }, +]; +local depos = pg.pipeline([ + hs.gen.track_depos(tracklist), + hs.gen.bagger(params.daq), +]); + +local random = hs.gen.random(); +local drifter = hs.gen.drifter(params.det.volumes,params.lar,random); + +// responses +local sim_fr = hs.aux.fr(params.files.fields[0]); +local er = hs.aux.cer(params.elec.shaping, params.elec.gain, + params.elec.postgain, + params.daq.nticks, params.daq.tick); +local rc = hs.aux.rc(1.0*wc.ms, params.daq.nticks, params.daq.tick); +local pirs = hs.gen.pirs(sim_fr, [er], [rc]); + +// sp fr may differ from sim fr (as it does from real fr) +local sp_fr = hs.aux.fr(if std.length(params.files.fields)>1 + then params.files.fields[1] + else params.files.fields[0]); + +local sp_filters = import "pgrapher/experiment/pdsp/sp-filters.jsonnet"; +local adcpermv = hs.utils.adcpermv(params.adc); +local chndbf = import "pgrapher/experiment/pdsp/ocndb-perfect.jsonnet"; +local chndb(anode) = chndbf(anode, sp_fr, params.nf.nsamples); +local dnnroi_override = { + sparse: true, + use_roi_debug_mode: true, + use_multi_plane_protection: true, + process_planes: [0, 1, 2] +}; + +local ts = { + type: "TorchService", + name: "dnnroi", + data: { + model: "unet-l23-cosmic500-e50.ts", + device: "cpu", + concurrency: 1, + }, +}; + +// little function to return a frame file tap or sink (if cap is +// true). This bakes in PDSP-specific array bounds! +local out(anode, prefix, tag_pats, digitize=false, cap=false) = + local tags = [tp + std.toString(anode.data.ident) + for tp in tag_pats]; + local fname = prefix + "-" + + std.join("-", tags) + ".tar.bz2"; + local dense = hs.io.frame_bounds(2560, 6000, + 2560 * anode.data.ident); + if cap + then hs.io.frame_file_sink(fname, tags, digitize, dense) + else hs.io.frame_file_tap(fname, tags, digitize, dense); + + +local anode_pipeline(anode, prefix) = pg.pipeline([ + // sim + hs.gen.signal(anode, pirs, params.daq, params.lar, rnd=random), + // hs.gen.noise(anode, params.files.noise, params.daq, rnd=random), + hs.gen.digi(anode, params.adc), + out(anode, prefix, ["orig"], true), + + // // nf+sp + // hs.nf(anode, sp_fr, chndb(anode), params.nf.nsamples, params.daq.tick), + // hs.sp(anode, sp_fr, er, sp_filters, adcpermv, + // override=dnnroi_override), + // out(anode, prefix, ["wiener","gauss"]), + + // // // dnnroi + // hs.dnnroi(anode, ts, output_scale=1.2), + // out(anode, prefix, ["dnnsp"], cap=true), +]); + +function(prefix="test-pdsp-ssd") + local pipes = [ anode_pipeline(a, prefix) for a in anodes]; + local body = pg.fan.fanout('DepoSetFanout', pipes); + local graph = pg.pipeline([depos, drifter, body]); + hs.utils.main(graph, 'TbbFlow', ['WireCellPytorch']) + + + diff --git a/cfg/test/test_multiductor.jsonnet b/cfg/test/test_multiductor.jsonnet index c90f3e7d4..f170ca358 100644 --- a/cfg/test/test_multiductor.jsonnet +++ b/cfg/test/test_multiductor.jsonnet @@ -4,6 +4,8 @@ local wc = import "wirecell.jsonnet"; +local dft = {type:'FftwDFT'}; + // special wire-cell command line configuration just to save us typing. local cmdline = { @@ -144,10 +146,12 @@ local noise_model = { // fixme: replace this with various models for DUNE, for now, // just pretend to be microboone. anode: wc.tn(anode_nominal), + dft: dft, spectra_file: "microboone-noise-spectra-v2.json.bz2", chanstat: "StaticChannelStatus", nsamples: params.daq.ticks_per_readout, - } + }, + uses: [dft], }; local noise_source = { type: "NoiseSource", diff --git a/gen/inc/WireCellGen/AddCoherentNoise.h b/gen/inc/WireCellGen/AddCoherentNoise.h index 58b93c6db..087da8bd2 100644 --- a/gen/inc/WireCellGen/AddCoherentNoise.h +++ b/gen/inc/WireCellGen/AddCoherentNoise.h @@ -7,6 +7,7 @@ #include "WireCellIface/IFrameFilter.h" #include "WireCellIface/IConfigurable.h" #include "WireCellIface/IRandom.h" +#include "WireCellIface/IDFT.h" #include "WireCellIface/IChannelSpectrum.h" #include "WireCellUtil/Waveform.h" #include "WireCellUtil/Logging.h" @@ -36,6 +37,7 @@ namespace WireCell { typedef std::map>> noise_map_t; IRandom::pointer m_rng; + IDFT::pointer m_dft; std::string m_spectra_file, m_rng_tn; int m_nsamples; diff --git a/gen/inc/WireCellGen/AddNoise.h b/gen/inc/WireCellGen/AddNoise.h index 662c2f3c8..6ad18b954 100644 --- a/gen/inc/WireCellGen/AddNoise.h +++ b/gen/inc/WireCellGen/AddNoise.h @@ -10,6 +10,7 @@ #include "WireCellIface/IFrameFilter.h" #include "WireCellIface/IConfigurable.h" #include "WireCellIface/IRandom.h" +#include "WireCellIface/IDFT.h" #include "WireCellIface/IChannelSpectrum.h" #include "WireCellUtil/Waveform.h" #include "WireCellAux/Logger.h" @@ -35,6 +36,7 @@ namespace WireCell { private: IRandom::pointer m_rng; + IDFT::pointer m_dft; IChannelSpectrum::pointer m_model; std::string m_model_tn, m_rng_tn; diff --git a/gen/inc/WireCellGen/BinnedDiffusion.h b/gen/inc/WireCellGen/BinnedDiffusion.h index dc3a161e4..6366dc665 100644 --- a/gen/inc/WireCellGen/BinnedDiffusion.h +++ b/gen/inc/WireCellGen/BinnedDiffusion.h @@ -4,7 +4,9 @@ #include "WireCellUtil/Pimpos.h" #include "WireCellUtil/Point.h" #include "WireCellUtil/Units.h" + #include "WireCellIface/IDepo.h" +#include "WireCellIface/IDFT.h" #include "WireCellGen/ImpactData.h" @@ -45,7 +47,8 @@ namespace WireCell { /// Useful to client code to mark a calculation strategy. enum ImpactDataCalculationStrategy { constant = 1, linear = 2 }; - BinnedDiffusion(const Pimpos& pimpos, const Binning& tbins, double nsigma = 3.0, + BinnedDiffusion(const Pimpos& pimpos, const IDFT::pointer& dft, + const Binning& tbins, double nsigma = 3.0, IRandom::pointer fluctuate = nullptr, ImpactDataCalculationStrategy calcstrat = linear); const Pimpos& pimpos() const { return m_pimpos; } @@ -95,6 +98,7 @@ namespace WireCell { private: const Pimpos& m_pimpos; + const IDFT::pointer& m_dft; const Binning& m_tbins; double m_nsigma; diff --git a/gen/inc/WireCellGen/DepoTransform.h b/gen/inc/WireCellGen/DepoTransform.h index d2b8f9a0b..d58434dea 100644 --- a/gen/inc/WireCellGen/DepoTransform.h +++ b/gen/inc/WireCellGen/DepoTransform.h @@ -4,14 +4,16 @@ #ifndef WIRECELLGEN_DEPOTRANSFORM #define WIRECELLGEN_DEPOTRANSFORM +#include "WireCellAux/Logger.h" + #include "WireCellIface/IDepoFramer.h" #include "WireCellIface/IConfigurable.h" #include "WireCellIface/IRandom.h" +#include "WireCellIface/IDFT.h" #include "WireCellIface/IPlaneImpactResponse.h" #include "WireCellIface/IAnodePlane.h" #include "WireCellIface/WirePlaneId.h" #include "WireCellIface/IDepo.h" -#include "WireCellAux/Logger.h" namespace WireCell { namespace Gen { @@ -35,6 +37,7 @@ namespace WireCell { private: IAnodePlane::pointer m_anode; IRandom::pointer m_rng; + IDFT::pointer m_dft; std::vector m_pirs; double m_start_time; diff --git a/gen/inc/WireCellGen/DepoZipper.h b/gen/inc/WireCellGen/DepoZipper.h deleted file mode 100644 index bd9ec50e5..000000000 --- a/gen/inc/WireCellGen/DepoZipper.h +++ /dev/null @@ -1,43 +0,0 @@ -/** Make a frame from depos using an ImpactZipper. - - See also the very similar DepoTransform which is newer and faster. - */ - -#ifndef WIRECELLGEN_DEPOZIPPER -#define WIRECELLGEN_DEPOZIPPER - -#include "WireCellIface/IDepoFramer.h" -#include "WireCellIface/IConfigurable.h" -#include "WireCellIface/IRandom.h" -#include "WireCellIface/IPlaneImpactResponse.h" -#include "WireCellIface/IAnodePlane.h" - -namespace WireCell { - namespace Gen { - - class DepoZipper : public IDepoFramer, public IConfigurable { - public: - DepoZipper(); - virtual ~DepoZipper(); - - virtual bool operator()(const input_pointer& in, output_pointer& out); - - virtual void configure(const WireCell::Configuration& cfg); - virtual WireCell::Configuration default_configuration() const; - - private: - IAnodePlane::pointer m_anode; - IRandom::pointer m_rng; - std::vector m_pirs; - - double m_start_time; - double m_readout_time; - double m_tick; - double m_drift_speed; - double m_nsigma; - int m_frame_count; - }; - } // namespace Gen -} // namespace WireCell - -#endif diff --git a/gen/inc/WireCellGen/Ductor.h b/gen/inc/WireCellGen/Ductor.h deleted file mode 100644 index 60f348b4b..000000000 --- a/gen/inc/WireCellGen/Ductor.h +++ /dev/null @@ -1,67 +0,0 @@ -#ifndef WIRECELLGEN_DUCTOR -#define WIRECELLGEN_DUCTOR - -#include "WireCellUtil/Pimpos.h" -#include "WireCellUtil/Response.h" - -#include "WireCellIface/IConfigurable.h" -#include "WireCellIface/IDuctor.h" - -#include "WireCellIface/IAnodeFace.h" -#include "WireCellIface/IAnodePlane.h" -#include "WireCellIface/IPlaneImpactResponse.h" -#include "WireCellIface/IRandom.h" -#include "WireCellUtil/Logging.h" - -#include - -namespace WireCell { - namespace Gen { - - /** This IDuctor needs a Garfield2D field calculation data - * file in compressed JSON format as produced by Python module - * wirecell.sigproc.garfield. - */ - class Ductor : public IDuctor, public IConfigurable { - public: - Ductor(); - virtual ~Ductor(){}; - - // virtual void reset(); - virtual bool operator()(const input_pointer& depo, output_queue& frames); - - virtual void configure(const WireCell::Configuration& config); - virtual WireCell::Configuration default_configuration() const; - - protected: - // The "Type:Name" of the IAnodePlane (default is "AnodePlane") - std::string m_anode_tn; - std::string m_rng_tn; - std::vector m_pir_tns; - - IAnodePlane::pointer m_anode; - IRandom::pointer m_rng; - std::vector m_pirs; - - IDepo::vector m_depos; - - double m_start_time; - double m_readout_time; - double m_tick; - double m_drift_speed; - double m_nsigma; - bool m_fluctuate; - std::string m_mode; - - int m_frame_count; - std::string m_tag; - - virtual void process(output_queue& frames); - virtual ITrace::vector process_face(IAnodeFace::pointer face, const IDepo::vector& face_depos); - bool start_processing(const input_pointer& depo); - Log::logptr_t l; - }; - } // namespace Gen -} // namespace WireCell - -#endif diff --git a/gen/inc/WireCellGen/EmpiricalNoiseModel.h b/gen/inc/WireCellGen/EmpiricalNoiseModel.h index 513c85dac..424bfe33f 100644 --- a/gen/inc/WireCellGen/EmpiricalNoiseModel.h +++ b/gen/inc/WireCellGen/EmpiricalNoiseModel.h @@ -14,6 +14,7 @@ #include "WireCellIface/IChannelSpectrum.h" #include "WireCellIface/IConfigurable.h" #include "WireCellIface/IAnodePlane.h" +#include "WireCellIface/IDFT.h" #include "WireCellIface/IChannelStatus.h" #include "WireCellUtil/Units.h" @@ -86,6 +87,7 @@ namespace WireCell { private: IAnodePlane::pointer m_anode; IChannelStatus::pointer m_chanstat; + IDFT::pointer m_dft; std::string m_spectra_file; int m_nsamples; diff --git a/gen/inc/WireCellGen/ImpactData.h b/gen/inc/WireCellGen/ImpactData.h index d775a09ff..d4d5a90af 100644 --- a/gen/inc/WireCellGen/ImpactData.h +++ b/gen/inc/WireCellGen/ImpactData.h @@ -4,6 +4,9 @@ */ #include "WireCellUtil/Waveform.h" + +#include "WireCellIface/IDFT.h" + #include "WireCellGen/GaussianDiffusion.h" #include @@ -56,7 +59,7 @@ namespace WireCell { * linear or constant (all = 0.5), * and honoring the Gaussian distribution (diffusion). */ - void calculate(int nticks) const; + void calculate(const IDFT::pointer& dft, int nticks) const; /** Return the time domain waveform of drifted/diffused * charge at this impact position. See `calculate()`. */ diff --git a/gen/inc/WireCellGen/ImpactTransform.h b/gen/inc/WireCellGen/ImpactTransform.h index 466072498..c3029e0e1 100644 --- a/gen/inc/WireCellGen/ImpactTransform.h +++ b/gen/inc/WireCellGen/ImpactTransform.h @@ -1,8 +1,11 @@ #ifndef WIRECELL_IMPACTTRANSFORM #define WIRECELL_IMPACTTRANSFORM -#include "WireCellIface/IPlaneImpactResponse.h" #include "WireCellGen/BinnedDiffusion_transform.h" + +#include "WireCellIface/IPlaneImpactResponse.h" +#include "WireCellIface/IDFT.h" + #include "WireCellUtil/Array.h" #include @@ -15,6 +18,7 @@ namespace WireCell { */ class ImpactTransform { IPlaneImpactResponse::pointer m_pir; + IDFT::pointer m_dft; BinnedDiffusion_transform& m_bd; int m_num_group; // how many 2D convolution is needed @@ -32,6 +36,7 @@ namespace WireCell { public: ImpactTransform(IPlaneImpactResponse::pointer pir, + const IDFT::pointer& dft, BinnedDiffusion_transform& bd); virtual ~ImpactTransform(); diff --git a/gen/inc/WireCellGen/ImpactZipper.h b/gen/inc/WireCellGen/ImpactZipper.h deleted file mode 100644 index ca9d8db60..000000000 --- a/gen/inc/WireCellGen/ImpactZipper.h +++ /dev/null @@ -1,38 +0,0 @@ -#ifndef WIRECELL_IMPACTZIPPER -#define WIRECELL_IMPACTZIPPER - -#include "WireCellIface/IPlaneImpactResponse.h" -#include "WireCellGen/BinnedDiffusion.h" - -namespace WireCell { - namespace Gen { - - /** An ImpactZipper "zips" up through all the impact positions - * along a wire plane convolving the response functions and - * the local drifted charge distribution producing a waveform - * on each central wire. - */ - class ImpactZipper { - IPlaneImpactResponse::pointer m_pir; - BinnedDiffusion& m_bd; - - public: - ImpactZipper(IPlaneImpactResponse::pointer pir, BinnedDiffusion& bd); - virtual ~ImpactZipper(); - - /// Return the wire's waveform. If the response functions - /// are just field response (ie, instantaneous current) - /// then the waveforms are expressed as current integrated - /// over each sample bin and thus in units of charge. If - /// the response functions include electronics response - /// then the waveforms are in units of voltage - /// representing the sampling of the output of the FEE - /// amplifiers. - - // fixme: this should be a forward iterator so that it may cal bd.erase() safely to conserve memory - Waveform::realseq_t waveform(int wire) const; - }; - - } // namespace Gen -} // namespace WireCell -#endif /* WIRECELL_IMPACTZIPPER */ diff --git a/gen/inc/WireCellGen/Misconfigure.h b/gen/inc/WireCellGen/Misconfigure.h index 8c6da84a4..1bc59f664 100644 --- a/gen/inc/WireCellGen/Misconfigure.h +++ b/gen/inc/WireCellGen/Misconfigure.h @@ -23,6 +23,8 @@ #include "WireCellIface/IFrameFilter.h" #include "WireCellIface/IConfigurable.h" +#include "WireCellIface/IDFT.h" + #include "WireCellUtil/Waveform.h" #include @@ -45,6 +47,7 @@ namespace WireCell { private: Waveform::realseq_t m_from, m_to; bool m_truncate; + IDFT::pointer m_dft; }; } // namespace Gen } // namespace WireCell diff --git a/gen/inc/WireCellGen/NoiseSource.h b/gen/inc/WireCellGen/NoiseSource.h index 2d3d8840d..50c12c3b5 100644 --- a/gen/inc/WireCellGen/NoiseSource.h +++ b/gen/inc/WireCellGen/NoiseSource.h @@ -11,6 +11,7 @@ #include "WireCellIface/IFrameSource.h" #include "WireCellIface/IConfigurable.h" #include "WireCellIface/IRandom.h" +#include "WireCellIface/IDFT.h" #include "WireCellIface/IAnodePlane.h" #include "WireCellIface/IChannelSpectrum.h" #include "WireCellUtil/Waveform.h" @@ -36,6 +37,7 @@ namespace WireCell { private: IRandom::pointer m_rng; + IDFT::pointer m_dft; IAnodePlane::pointer m_anode; IChannelSpectrum::pointer m_model; double m_time, m_stop, m_readout, m_tick; diff --git a/gen/inc/WireCellGen/PerChannelVariation.h b/gen/inc/WireCellGen/PerChannelVariation.h index 756f3087d..3ce79573e 100644 --- a/gen/inc/WireCellGen/PerChannelVariation.h +++ b/gen/inc/WireCellGen/PerChannelVariation.h @@ -21,6 +21,8 @@ #include "WireCellIface/IFrameFilter.h" #include "WireCellIface/IConfigurable.h" #include "WireCellIface/IChannelResponse.h" +#include "WireCellIface/IDFT.h" + #include "WireCellUtil/Waveform.h" #include @@ -46,6 +48,8 @@ namespace WireCell { int m_nsamples; WireCell::Waveform::realseq_t m_from; bool m_truncate; + IDFT::pointer m_dft; + }; } // namespace Gen } // namespace WireCell diff --git a/gen/inc/WireCellGen/PlaneImpactResponse.h b/gen/inc/WireCellGen/PlaneImpactResponse.h index 8f230e636..d1cb968e9 100644 --- a/gen/inc/WireCellGen/PlaneImpactResponse.h +++ b/gen/inc/WireCellGen/PlaneImpactResponse.h @@ -28,15 +28,18 @@ namespace WireCell { int m_long_waveform_pad; public: - ImpactResponse(int impact, const Waveform::realseq_t& wf, int waveform_pad, + ImpactResponse(int impact, + const Waveform::compseq_t& spectrum, + const Waveform::realseq_t& wf, int waveform_pad, const Waveform::realseq_t& long_wf, int long_waveform_pad) : m_impact(impact) + , m_spectrum(spectrum) , m_waveform(wf) , m_waveform_pad(waveform_pad) , m_long_waveform(long_wf) , m_long_waveform_pad(long_waveform_pad) { - m_spectrum = Waveform::dft(m_waveform); + // m_spectrum = Waveform::dft(m_waveform); } /// Frequency-domain spectrum of response @@ -103,6 +106,7 @@ namespace WireCell { const std::vector& irs() const { return m_ir; } private: + std::string m_frname; std::vector m_short; double m_overall_short_padding; @@ -118,6 +122,7 @@ namespace WireCell { std::vector m_ir; double m_half_extent, m_pitch, m_impact; + std::string m_dftname{"FftwDFT"}; void build_responses(); }; diff --git a/gen/inc/WireCellGen/TruthSmearer.h b/gen/inc/WireCellGen/TruthSmearer.h index 3e35166d7..be700aafa 100644 --- a/gen/inc/WireCellGen/TruthSmearer.h +++ b/gen/inc/WireCellGen/TruthSmearer.h @@ -6,6 +6,7 @@ #include "WireCellIface/IConfigurable.h" #include "WireCellIface/IDuctor.h" +#include "WireCellIface/IDFT.h" #include "WireCellIface/IAnodePlane.h" #include "WireCellIface/IRandom.h" @@ -30,6 +31,7 @@ namespace WireCell { IAnodePlane::pointer m_anode; IRandom::pointer m_rng; + IDFT::pointer m_dft; IDepo::vector m_depos; double m_start_time; diff --git a/gen/inc/WireCellGen/TruthTraceID.h b/gen/inc/WireCellGen/TruthTraceID.h index b6c1c4608..a08af1e7b 100644 --- a/gen/inc/WireCellGen/TruthTraceID.h +++ b/gen/inc/WireCellGen/TruthTraceID.h @@ -10,6 +10,7 @@ #include "WireCellIface/IAnodePlane.h" #include "WireCellIface/IRandom.h" +#include "WireCellIface/IDFT.h" namespace WireCell { namespace Gen { @@ -28,6 +29,7 @@ namespace WireCell { IAnodePlane::pointer m_anode; IRandom::pointer m_rng; + IDFT::pointer m_dft; IDepo::vector m_depos; double m_start_time; diff --git a/gen/src/AddCoherentNoise.cxx b/gen/src/AddCoherentNoise.cxx index 90fdde2a1..9c6a02082 100644 --- a/gen/src/AddCoherentNoise.cxx +++ b/gen/src/AddCoherentNoise.cxx @@ -7,6 +7,8 @@ #include "WireCellUtil/NamedFactory.h" #include "WireCellUtil/FFTBestLength.h" +#include "WireCellAux/DftTools.h" + #include "Noise.h" #include @@ -54,6 +56,7 @@ WireCell::Configuration Gen::AddCoherentNoise::default_configuration() const cfg["random_fluctuation_amplitude"] = m_fluctuation; cfg["period"] = m_period; cfg["normalization"] = m_normalization; + cfg["dft"] = "FftwDFT"; // type-name for the DFT to use return cfg; } @@ -68,6 +71,9 @@ void Gen::AddCoherentNoise::configure(const WireCell::Configuration& cfg) m_fluctuation = get(cfg, "random_fluctuation_amplitude", m_fluctuation); m_normalization = get(cfg, "normalization", m_normalization); + std::string dft_tn = get(cfg, "dft", "FftwDFT"); + m_dft = Factory::find_tn(dft_tn); + m_fft_length = fft_best_length(m_nsamples); gen_elec_resp_default(); @@ -141,7 +147,7 @@ bool Gen::AddCoherentNoise::operator()(const input_pointer& inframe, output_poin noise_freq[i] = tc; } - Waveform::realseq_t wave = WireCell::Waveform::idft(noise_freq); + auto wave = Waveform::real(Aux::inv(m_dft, noise_freq)); // Add signal (be careful to double counting with the incoherent noise) Waveform::increase(wave, intrace->charge()); diff --git a/gen/src/AddNoise.cxx b/gen/src/AddNoise.cxx index c56fea8a0..db306db5f 100644 --- a/gen/src/AddNoise.cxx +++ b/gen/src/AddNoise.cxx @@ -1,5 +1,7 @@ #include "WireCellGen/AddNoise.h" +#include "WireCellAux/DftTools.h" + #include "WireCellIface/SimpleTrace.h" #include "WireCellIface/SimpleFrame.h" @@ -36,6 +38,7 @@ WireCell::Configuration Gen::AddNoise::default_configuration() const cfg["model"] = m_model_tn; cfg["rng"] = m_rng_tn; + cfg["dft"] = "FftwDFT"; // type-name for the DFT to use cfg["nsamples"] = m_nsamples; cfg["replacement_percentage"] = m_rep_percent; return cfg; @@ -45,6 +48,8 @@ void Gen::AddNoise::configure(const WireCell::Configuration& cfg) { m_rng_tn = get(cfg, "rng", m_rng_tn); m_rng = Factory::find_tn(m_rng_tn); + std::string dft_tn = get(cfg, "dft", "FftwDFT"); + m_dft = Factory::find_tn(dft_tn); m_model_tn = get(cfg, "model", m_model_tn); m_model = Factory::find_tn(m_model_tn); m_nsamples = get(cfg, "nsamples", m_nsamples); @@ -66,7 +71,9 @@ bool Gen::AddNoise::operator()(const input_pointer& inframe, output_pointer& out for (const auto& intrace : *inframe->traces()) { int chid = intrace->channel(); const auto& spec = (*m_model)(chid); - Waveform::realseq_t wave = Gen::Noise::generate_waveform(spec, m_rng, m_rep_percent); + auto cspec = Gen::Noise::generate_spectrum(spec, m_rng, m_rep_percent); + auto wave = Waveform::real(Aux::inv(m_dft, cspec)); + // Waveform::realseq_t wave = Gen::Noise::generate_waveform(spec, m_rng, m_rep_percent); wave.resize(m_nsamples, 0); Waveform::increase(wave, intrace->charge()); diff --git a/gen/src/BinnedDiffusion.cxx b/gen/src/BinnedDiffusion.cxx index 4127ba1ee..a855b9c97 100644 --- a/gen/src/BinnedDiffusion.cxx +++ b/gen/src/BinnedDiffusion.cxx @@ -7,9 +7,11 @@ using namespace std; using namespace WireCell; -Gen::BinnedDiffusion::BinnedDiffusion(const Pimpos& pimpos, const Binning& tbins, double nsigma, +Gen::BinnedDiffusion::BinnedDiffusion(const Pimpos& pimpos, const IDFT::pointer& dft, + const Binning& tbins, double nsigma, IRandom::pointer fluctuate, ImpactDataCalculationStrategy calcstrat) : m_pimpos(pimpos) + , m_dft(dft) , m_tbins(tbins) , m_nsigma(nsigma) , m_fluctuate(fluctuate) @@ -127,7 +129,7 @@ Gen::ImpactData::pointer Gen::BinnedDiffusion::impact_data(int bin) const // diff->set_sampling(m_tbins, ib, m_nsigma, 0, m_calcstrat); } - idptr->calculate(m_tbins.nbins()); + idptr->calculate(m_dft, m_tbins.nbins()); return idptr; } diff --git a/gen/src/DepoSplat.cxx b/gen/src/DepoSplat.cxx index eb1ff16d6..47456cb9f 100644 --- a/gen/src/DepoSplat.cxx +++ b/gen/src/DepoSplat.cxx @@ -9,7 +9,7 @@ // from ductor #include "WireCellGen/BinnedDiffusion.h" -#include "WireCellGen/ImpactZipper.h" + #include "WireCellUtil/Units.h" #include "WireCellUtil/Point.h" #include "WireCellUtil/NamedFactory.h" diff --git a/gen/src/DepoTransform.cxx b/gen/src/DepoTransform.cxx index a8beda4d6..023551199 100644 --- a/gen/src/DepoTransform.cxx +++ b/gen/src/DepoTransform.cxx @@ -77,6 +77,8 @@ void Gen::DepoTransform::configure(const WireCell::Configuration& cfg) auto rng_tn = get(cfg, "rng", ""); m_rng = Factory::find_tn(rng_tn); } + std::string dft_tn = get(cfg, "dft", "FftwDFT"); + m_dft = Factory::find_tn(dft_tn); m_readout_time = get(cfg, "readout_time", m_readout_time); m_tick = get(cfg, "tick", m_tick); @@ -132,6 +134,9 @@ WireCell::Configuration Gen::DepoTransform::default_configuration() const /// Plane impact responses cfg["pirs"] = Json::arrayValue; + // type-name for the DFT to use + cfg["dft"] = "FftwDFT"; + return cfg; } @@ -203,7 +208,7 @@ bool Gen::DepoTransform::operator()(const input_pointer& in, output_pointer& out auto& wires = plane->wires(); auto pir = m_pirs.at(iplane); - Gen::ImpactTransform transform(pir, bindiff); + Gen::ImpactTransform transform(pir, m_dft, bindiff); const int nwires = pimpos->region_binning().nbins(); for (int iwire = 0; iwire < nwires; ++iwire) { diff --git a/gen/src/DepoZipper.cxx b/gen/src/DepoZipper.cxx deleted file mode 100644 index 07c0eeb5e..000000000 --- a/gen/src/DepoZipper.cxx +++ /dev/null @@ -1,187 +0,0 @@ -#include "WireCellGen/DepoZipper.h" -#include "WireCellGen/ImpactZipper.h" -#include "WireCellUtil/NamedFactory.h" -#include "WireCellIface/IAnodePlane.h" -#include "WireCellIface/SimpleTrace.h" -#include "WireCellIface/SimpleFrame.h" -#include "WireCellGen/BinnedDiffusion.h" -#include "WireCellGen/ImpactZipper.h" -#include "WireCellUtil/Units.h" -#include "WireCellUtil/Point.h" - -WIRECELL_FACTORY(DepoZipper, WireCell::Gen::DepoZipper, WireCell::IDepoFramer, WireCell::IConfigurable) - -using namespace WireCell; -using namespace std; - -Gen::DepoZipper::DepoZipper() - : m_start_time(0.0 * units::ns) - , m_readout_time(5.0 * units::ms) - , m_tick(0.5 * units::us) - , m_drift_speed(1.0 * units::mm / units::us) - , m_nsigma(3.0) - , m_frame_count(0) -{ -} - -Gen::DepoZipper::~DepoZipper() {} - -void Gen::DepoZipper::configure(const WireCell::Configuration& cfg) -{ - auto anode_tn = get(cfg, "anode", ""); - m_anode = Factory::find_tn(anode_tn); - - m_nsigma = get(cfg, "nsigma", m_nsigma); - bool fluctuate = get(cfg, "fluctuate", false); - m_rng = nullptr; - if (fluctuate) { - auto rng_tn = get(cfg, "rng", ""); - m_rng = Factory::find_tn(rng_tn); - } - - m_readout_time = get(cfg, "readout_time", m_readout_time); - m_tick = get(cfg, "tick", m_tick); - m_start_time = get(cfg, "start_time", m_start_time); - m_drift_speed = get(cfg, "drift_speed", m_drift_speed); - m_frame_count = get(cfg, "first_frame_number", m_frame_count); - - auto jpirs = cfg["pirs"]; - if (jpirs.isNull() or jpirs.empty()) { - THROW(ValueError() << errmsg{"Gen::Ductor: must configure with some plane impact response components"}); - } - m_pirs.clear(); - for (auto jpir : jpirs) { - auto tn = jpir.asString(); - auto pir = Factory::find_tn(tn); - m_pirs.push_back(pir); - } -} -WireCell::Configuration Gen::DepoZipper::default_configuration() const -{ - Configuration cfg; - - /// How many Gaussian sigma due to diffusion to keep before truncating. - put(cfg, "nsigma", m_nsigma); - - /// Whether to fluctuate the final Gaussian deposition. - put(cfg, "fluctuate", false); - - /// The open a gate. This is actually a "readin" time measured at - /// the input ("reference") plane. - put(cfg, "start_time", m_start_time); - - /// The time span for each readout. This is actually a "readin" - /// time span measured at the input ("reference") plane. - put(cfg, "readout_time", m_readout_time); - - /// The sample period - put(cfg, "tick", m_tick); - - /// The nominal speed of drifting electrons - put(cfg, "drift_speed", m_drift_speed); - - /// Allow for a custom starting frame number - put(cfg, "first_frame_number", m_frame_count); - - /// Name of component providing the anode plane. - put(cfg, "anode", ""); - /// Name of component providing the anode pseudo random number generator. - put(cfg, "rng", ""); - - /// Plane impact responses - cfg["pirs"] = Json::arrayValue; - - return cfg; -} - -bool Gen::DepoZipper::operator()(const input_pointer& in, output_pointer& out) -{ - if (!in) { - out = nullptr; - cerr << "Gen::DepoZipper: EOS\n"; - return true; - } - - auto depos = in->depos(); - - Binning tbins(m_readout_time / m_tick, m_start_time, m_start_time + m_readout_time); - ITrace::vector traces; - for (auto face : m_anode->faces()) { - // Select the depos which are in this face's sensitive volume - IDepo::vector face_depos, dropped_depos; - auto bb = face->sensitive(); - if (bb.empty()) { - cerr << "Gen::DepoZipper anode:" << m_anode->ident() << " face:" << face->ident() - << " is marked insensitive, skipping\n"; - continue; - } - - for (auto depo : (*depos)) { - if (bb.inside(depo->pos())) { - face_depos.push_back(depo); - } - else { - dropped_depos.push_back(depo); - } - } - - if (face_depos.size()) { - auto ray = bb.bounds(); - cerr << "Gen::Ductor: anode:" << m_anode->ident() << " face:" << face->ident() << ": processing " - << face_depos.size() << " depos spanning: t:[" << face_depos.front()->time() / units::ms << ", " - << face_depos.back()->time() / units::ms << "]ms, bb: " << ray.first / units::cm << " --> " - << ray.second / units::cm << "cm\n"; - } - if (dropped_depos.size()) { - auto ray = bb.bounds(); - cerr << "Gen::Ductor: anode:" << m_anode->ident() << " face:" << face->ident() << ": dropped " - << dropped_depos.size() << " depos spanning: t:[" << dropped_depos.front()->time() / units::ms << ", " - << dropped_depos.back()->time() / units::ms << "]ms, outside bb: " << ray.first / units::cm << " --> " - << ray.second / units::cm << "cm\n"; - } - - int iplane = -1; - for (auto plane : face->planes()) { - ++iplane; - - const Pimpos* pimpos = plane->pimpos(); - - Binning tbins(m_readout_time / m_tick, m_start_time, m_start_time + m_readout_time); - - Gen::BinnedDiffusion bindiff(*pimpos, tbins, m_nsigma, m_rng); - for (auto depo : face_depos) { - bindiff.add(depo, depo->extent_long() / m_drift_speed, depo->extent_tran()); - } - - auto& wires = plane->wires(); - - auto pir = m_pirs.at(iplane); - Gen::ImpactZipper zipper(pir, bindiff); - - const int nwires = pimpos->region_binning().nbins(); - for (int iwire = 0; iwire < nwires; ++iwire) { - auto wave = zipper.waveform(iwire); - - auto mm = Waveform::edge(wave); - if (mm.first == (int) wave.size()) { // all zero - continue; - } - - int chid = wires[iwire]->channel(); - int tbin = mm.first; - - // std::cout << mm.first << " "<< mm.second << std::endl; - - ITrace::ChargeSequence charge(wave.begin() + mm.first, wave.begin() + mm.second); - auto trace = make_shared(chid, tbin, charge); - traces.push_back(trace); - } - } - } - - auto frame = make_shared(m_frame_count, m_start_time, traces, m_tick); - cerr << "Gen::DepoZipper: make frame " << m_frame_count << "\n"; - ++m_frame_count; - out = frame; - return true; -} diff --git a/gen/src/Ductor.cxx b/gen/src/Ductor.cxx deleted file mode 100644 index 9dcd8f264..000000000 --- a/gen/src/Ductor.cxx +++ /dev/null @@ -1,289 +0,0 @@ -#include "WireCellGen/Ductor.h" -#include "WireCellGen/BinnedDiffusion.h" -#include "WireCellGen/ImpactZipper.h" -#include "WireCellUtil/Units.h" -#include "WireCellUtil/Point.h" -#include "WireCellUtil/NamedFactory.h" -#include "WireCellIface/SimpleTrace.h" -#include "WireCellIface/SimpleFrame.h" - -#include - -WIRECELL_FACTORY(Ductor, WireCell::Gen::Ductor, WireCell::IDuctor, WireCell::IConfigurable) - -using namespace std; -using namespace WireCell; - -Gen::Ductor::Ductor() - : m_anode_tn("AnodePlane") - , m_rng_tn("Random") - , m_start_time(0.0 * units::ns) - , m_readout_time(5.0 * units::ms) - , m_tick(0.5 * units::us) - , m_drift_speed(1.0 * units::mm / units::us) - , m_nsigma(3.0) - , m_fluctuate(true) - , m_mode("continuous") - , m_frame_count(0) - , l(Log::logger("sim")) -{ -} - -WireCell::Configuration Gen::Ductor::default_configuration() const -{ - Configuration cfg; - - /// How many Gaussian sigma due to diffusion to keep before truncating. - put(cfg, "nsigma", m_nsigma); - - /// Whether to fluctuate the final Gaussian deposition. - put(cfg, "fluctuate", m_fluctuate); - - /// The initial time for this ductor - put(cfg, "start_time", m_start_time); - - /// The time span for each readout. - put(cfg, "readout_time", m_readout_time); - - /// The sample period - put(cfg, "tick", m_tick); - - /// If false then determine start time of each readout based on the - /// input depos. This option is useful when running WCT sim on a - /// source of depos which have already been "chunked" in time. If - /// true then this Ductor will continuously simulate all time in - /// "readout_time" frames leading to empty frames in the case of - /// some readout time with no depos. - put(cfg, "continuous", true); - - /// Fixed mode simply reads out the same time window all the time. - /// It implies discontinuous (continuous == false). - put(cfg, "fixed", false); - - /// The nominal speed of drifting electrons - put(cfg, "drift_speed", m_drift_speed); - - /// Allow for a custom starting frame number - put(cfg, "first_frame_number", m_frame_count); - - /// Name of component providing the anode plane. - put(cfg, "anode", m_anode_tn); - put(cfg, "rng", m_rng_tn); - - cfg["pirs"] = Json::arrayValue; - /// don't set here so user must, but eg: - // cfg["pirs"][0] = "PlaneImpactResponseU"; - // cfg["pirs"][1] = "PlaneImpactResponseV"; - // cfg["pirs"][2] = "PlaneImpactResponseW"; - - // Tag to use for frame and traces will get this tag + the anode - // ID. - cfg["tag"] = "ductor"; - - return cfg; -} - -void Gen::Ductor::configure(const WireCell::Configuration& cfg) -{ - m_anode_tn = get(cfg, "anode", m_anode_tn); - m_anode = Factory::find_tn(m_anode_tn); - - m_nsigma = get(cfg, "nsigma", m_nsigma); - bool continuous = get(cfg, "continuous", true); - bool fixed = get(cfg, "fixed", false); - - m_mode = "continuous"; - if (fixed) { - m_mode = "fixed"; - } - else if (!continuous) { - m_mode = "discontinuous"; - } - - m_fluctuate = get(cfg, "fluctuate", m_fluctuate); - m_rng = nullptr; - if (m_fluctuate) { - m_rng_tn = get(cfg, "rng", m_rng_tn); - m_rng = Factory::find_tn(m_rng_tn); - } - - m_readout_time = get(cfg, "readout_time", m_readout_time); - m_tick = get(cfg, "tick", m_tick); - m_start_time = get(cfg, "start_time", m_start_time); - m_drift_speed = get(cfg, "drift_speed", m_drift_speed); - m_frame_count = get(cfg, "first_frame_number", m_frame_count); - - auto jpirs = cfg["pirs"]; - if (jpirs.isNull() or jpirs.empty()) { - l->critical("must configure with some plane impace response components"); - THROW(ValueError() << errmsg{"Gen::Ductor: must configure with some plane impact response components"}); - } - m_pirs.clear(); - for (auto jpir : jpirs) { - auto tn = jpir.asString(); - auto pir = Factory::find_tn(tn); - m_pirs.push_back(pir); - } - - m_tag = get(cfg, "tag", "ductor"); - - l->debug("Ductor tagging {}, AnodePlane: {}, mode: {}, fluctuate: {}, time start: {} ms, readout time: {} ms, frame start: {}", - m_tag, - m_anode_tn, m_mode, (m_fluctuate ? "on" : "off"), m_start_time / units::ms, m_readout_time / units::ms, - m_frame_count); -} - -ITrace::vector Gen::Ductor::process_face(IAnodeFace::pointer face, const IDepo::vector& face_depos) -{ - ITrace::vector traces; - - int iplane = -1; - for (auto plane : face->planes()) { - ++iplane; - - const Pimpos* pimpos = plane->pimpos(); - - Binning tbins(m_readout_time / m_tick, m_start_time, m_start_time + m_readout_time); - - Gen::BinnedDiffusion bindiff(*pimpos, tbins, m_nsigma, m_rng); - for (auto depo : face_depos) { - bindiff.add(depo, depo->extent_long() / m_drift_speed, depo->extent_tran()); - } - - auto& wires = plane->wires(); - - auto pir = m_pirs.at(iplane); - Gen::ImpactZipper zipper(pir, bindiff); - - const int nwires = pimpos->region_binning().nbins(); - for (int iwire = 0; iwire < nwires; ++iwire) { - auto wave = zipper.waveform(iwire); - - auto mm = Waveform::edge(wave); - if (mm.first == (int) wave.size()) { // all zero - continue; - } - - int chid = wires[iwire]->channel(); - int tbin = mm.first; - - ITrace::ChargeSequence charge(wave.begin() + mm.first, wave.begin() + mm.second); - auto trace = make_shared(chid, tbin, charge); - traces.push_back(trace); - } - } - return traces; -} - -void Gen::Ductor::process(output_queue& frames) -{ - ITrace::vector traces; - - for (auto face : m_anode->faces()) { - // Select the depos which are in this face's sensitive volume - IDepo::vector face_depos, dropped_depos; - auto bb = face->sensitive(); - if (bb.empty()) { - l->debug("anode: {} face: {} is marked insensitive, skipping", m_anode->ident(), face->ident()); - continue; - } - - for (auto depo : m_depos) { - if (bb.inside(depo->pos())) { - face_depos.push_back(depo); - } - else { - dropped_depos.push_back(depo); - } - } - - if (face_depos.size()) { - auto ray = bb.bounds(); - l->debug( - "anode: {}, face: {}, processing {} depos spanning " - "t:[{},{}]ms, bb:[{}-->{}]cm", - m_anode->ident(), face->ident(), face_depos.size(), face_depos.front()->time() / units::ms, - face_depos.back()->time() / units::ms, ray.first / units::cm, ray.second / units::cm); - } - if (dropped_depos.size()) { - auto ray = bb.bounds(); - l->debug( - "anode: {}, face: {}, dropped {} depos spanning " - "t:[{},{}]ms, outside bb:[{}-->{}]cm", - m_anode->ident(), face->ident(), dropped_depos.size(), dropped_depos.front()->time() / units::ms, - dropped_depos.back()->time() / units::ms, ray.first / units::cm, ray.second / units::cm); - } - - auto newtraces = process_face(face, face_depos); - traces.insert(traces.end(), newtraces.begin(), newtraces.end()); - } - - auto frame = make_shared(m_frame_count, m_start_time, traces, m_tick); - IFrame::trace_list_t indices(traces.size()); - for (size_t ind = 0; ind < traces.size(); ++ind) { - indices[ind] = ind; - } - frame->tag_traces(m_tag + std::to_string(m_anode->ident()), indices); - frame->tag_frame(m_tag); - frames.push_back(frame); - l->debug("made frame: {} with {} traces @ {}ms", m_frame_count, traces.size(), m_start_time / units::ms); - - // fixme: what about frame overflow here? If the depos extend - // beyond the readout where does their info go? 2nd order, - // diffusion and finite field response can cause depos near the - // end of the readout to have some portion of their waveforms - // lost? - m_depos.clear(); - - if (m_mode == "continuous") { - m_start_time += m_readout_time; - } - - ++m_frame_count; -} - -// Return true if ready to start processing and capture start time if -// in continuous mode. -bool Gen::Ductor::start_processing(const input_pointer& depo) -{ - if (!depo) { - return true; - } - - if (m_mode == "fixed") { - // fixed mode waits until EOS - return false; - } - - if (m_mode == "discontinuous") { - // discontinuous mode sets start time on first depo. - if (m_depos.empty()) { - m_start_time = depo->time(); - return false; - } - } - - // continuous and discontinuous modes follow Just Enough - // Processing(TM) strategy. - - // Note: we use this depo time even if it may not actually be - // inside our sensitive volume. - bool ok = depo->time() > m_start_time + m_readout_time; - return ok; -} - -bool Gen::Ductor::operator()(const input_pointer& depo, output_queue& frames) -{ - if (start_processing(depo)) { - process(frames); - } - - if (depo) { - m_depos.push_back(depo); - } - else { - frames.push_back(nullptr); - } - - return true; -} diff --git a/gen/src/EmpiricalNoiseModel.cxx b/gen/src/EmpiricalNoiseModel.cxx index 85df675a3..cfc87dc19 100644 --- a/gen/src/EmpiricalNoiseModel.cxx +++ b/gen/src/EmpiricalNoiseModel.cxx @@ -9,6 +9,8 @@ #include "WireCellUtil/NamedFactory.h" #include "WireCellUtil/FFTBestLength.h" +#include "WireCellAux/DftTools.h" + #include // debug WIRECELL_FACTORY(EmpiricalNoiseModel, @@ -46,8 +48,6 @@ Gen::EmpiricalNoiseModel::~EmpiricalNoiseModel() {} void Gen::EmpiricalNoiseModel::gen_elec_resp_default() { - // double shaping[5]={1,1.1,2,2.2,3}; // us - // calculate the frequencies ... m_elec_resp_freq.resize(m_fft_length, 0); for (unsigned int i = 0; i != m_elec_resp_freq.size(); i++) { @@ -59,21 +59,7 @@ void Gen::EmpiricalNoiseModel::gen_elec_resp_default() m_elec_resp_freq.at(i) = (m_elec_resp_freq.size() - i) / (m_elec_resp_freq.size() * 1.0) * 1. / m_period; // the second half is useless ... } - - // if (m_elec_resp_freq.at(i) > 1./m_period / 2.){ - // m_elec_resp_freq.resize(i); - // break; - // } } - - // for (int i=0;i!=5;i++){ - // Response::ColdElec elec_resp(1, shaping[i]); // default at 1 mV/fC - // auto sig = elec_resp.generate(WireCell::Waveform::Domain(0, m_fft_length*m_period), m_fft_length); - // auto filt = Waveform::dft(sig); - // int nconfig = shaping[i]/ 0.1; - // auto ele_resp_amp = Waveform::magnitude(filt); - // m_elec_resp_cache[nconfig] = ele_resp_amp; - // } } WireCell::Configuration Gen::EmpiricalNoiseModel::default_configuration() const @@ -89,6 +75,7 @@ WireCell::Configuration Gen::EmpiricalNoiseModel::default_configuration() const // cfg["gain_scale"] = m_gres; // cfg["freq_scale"] = m_fres; cfg["anode"] = m_anode_tn; // name of IAnodePlane component + cfg["dft"] = "FftwDFT"; // type-name for the DFT to use return cfg; } @@ -161,6 +148,9 @@ void Gen::EmpiricalNoiseModel::configure(const WireCell::Configuration& cfg) m_spectra_file = get(cfg, "spectra_file", m_spectra_file); + std::string dft_tn = get(cfg, "dft", "FftwDFT"); + m_dft = Factory::find_tn(dft_tn); + m_nsamples = get(cfg, "nsamples", m_nsamples); m_fft_length = fft_best_length(m_nsamples); // m_fft_length = m_nsamples; @@ -352,7 +342,7 @@ const IChannelSpectrum::amplitude_t& Gen::EmpiricalNoiseModel::operator()(int ch if (resp1 == m_elec_resp_cache.end()) { Response::ColdElec elec_resp(10, ch_shaping); // default at 1 mV/fC auto sig = elec_resp.generate(WireCell::Waveform::Domain(0, m_fft_length * m_period), m_fft_length); - auto filt = Waveform::dft(sig); + auto filt = Aux::fwd_r2c(m_dft, sig); auto ele_resp_amp = Waveform::magnitude(filt); ele_resp_amp.resize(m_elec_resp_freq.size()); @@ -365,7 +355,7 @@ const IChannelSpectrum::amplitude_t& Gen::EmpiricalNoiseModel::operator()(int ch if (resp2 == m_elec_resp_cache.end()) { Response::ColdElec elec_resp(10, db_shaping); // default at 1 mV/fC auto sig = elec_resp.generate(WireCell::Waveform::Domain(0, m_fft_length * m_period), m_fft_length); - auto filt = Waveform::dft(sig); + auto filt = Aux::fwd_r2c(m_dft, sig); auto ele_resp_amp = Waveform::magnitude(filt); ele_resp_amp.resize(m_elec_resp_freq.size()); diff --git a/gen/src/ImpactData.cxx b/gen/src/ImpactData.cxx index a3ed125b4..09c8c093a 100644 --- a/gen/src/ImpactData.cxx +++ b/gen/src/ImpactData.cxx @@ -1,5 +1,7 @@ #include "WireCellGen/ImpactData.h" +#include "WireCellAux/DftTools.h" + #include // debugging using namespace WireCell; @@ -19,7 +21,7 @@ Waveform::realseq_t& Gen::ImpactData::weightform() const { return m_weights; } Waveform::compseq_t& Gen::ImpactData::weight_spectrum() const { return m_weight_spectrum; } -void Gen::ImpactData::calculate(int nticks) const +void Gen::ImpactData::calculate(const IDFT::pointer& dft, int nticks) const { if (m_waveform.size() > 0) { return; @@ -54,8 +56,8 @@ void Gen::ImpactData::calculate(int nticks) const } } - m_spectrum = Waveform::dft(m_waveform); - m_weight_spectrum = Waveform::dft(m_weights); + m_spectrum = Aux::fwd_r2c(dft, m_waveform); + m_weight_spectrum = Aux::fwd_r2c(dft, m_weights); } // std::pair Gen::ImpactData::strip() const diff --git a/gen/src/ImpactTransform.cxx b/gen/src/ImpactTransform.cxx index 74739b453..183a44fc8 100644 --- a/gen/src/ImpactTransform.cxx +++ b/gen/src/ImpactTransform.cxx @@ -1,33 +1,29 @@ #include "WireCellGen/ImpactTransform.h" + +#include "WireCellAux/DftTools.h" + #include "WireCellUtil/Testing.h" #include "WireCellUtil/FFTBestLength.h" #include "WireCellUtil/Exceptions.h" + #include // debugging. using namespace std; using namespace WireCell; Gen::ImpactTransform::ImpactTransform(IPlaneImpactResponse::pointer pir, + const IDFT::pointer& dft, BinnedDiffusion_transform& bd) : m_pir(pir) + , m_dft(dft) , m_bd(bd) { - // for (int i=0;i!=210;i++){ - // double pos = -31.5 + 0.3*i+1e-9;0 - // m_pir->closest(pos); - // } - // arrange the field response (210 in total, pitch_range/impact) // number of wires nwires ... m_num_group = std::round(m_pir->pitch() / m_pir->impact()) + 1; // 11 m_num_pad_wire = std::round((m_pir->nwires() - 1) / 2.); // 10 for wires, 5 for PCB strips const auto pimpos = m_bd.pimpos(); - // const int nsamples = m_bd.tbins().nbins(); - // const auto rb = pimpos.region_binning(); - // const int nwires = rb.nbins(); - - // //std::cerr << "ImpactTransform: num_group:" << m_num_group << " num_pad_wire:" << m_num_pad_wire << std::endl; for (int i = 0; i != m_num_group; i++) { @@ -43,10 +39,6 @@ Gen::ImpactTransform::ImpactTransform(IPlaneImpactResponse::pointer pir, //std::cerr << "ImpactTransform: " << rel_cen_imp_pos << std::endl; for (int j = 0; j != m_pir->nwires(); j++) { - // std::cerr << "ImpactTransform: " - // << i << " " << j << " " - // << rel_cen_imp_pos - (j-m_num_pad_wire)*m_pir->pitch()<< " " - // << std::endl; try { map_resp[j - m_num_pad_wire] = m_pir->closest(rel_cen_imp_pos - (j - m_num_pad_wire) * m_pir->pitch()); @@ -68,37 +60,21 @@ Gen::ImpactTransform::ImpactTransform(IPlaneImpactResponse::pointer pir, Waveform::compseq_t response_spectrum = map_resp[j - m_num_pad_wire]->spectrum(); - //response_spectrum.size() << std::endl; } - // std::cout << m_vec_impact.back() << std::endl; - // std::cout << rel_cen_imp_pos << std::endl; - // std::cout << map_resp.size() << std::endl; - m_vec_map_resp.push_back(map_resp); - // Eigen::SparseMatrix *mat = new Eigen::SparseMatrix(nsamples,nwires); - // mat.reserve(Eigen::VectorXi::Constant(nwires,1000)); - // m_vec_spmatrix.push_back(mat); + m_vec_map_resp.push_back(map_resp); std::vector > vec_charge; // ch, time, charge m_vec_vec_charge.push_back(vec_charge); } - // m_bd.get_charge_matrix(m_vec_spmatrix, m_vec_impact); - // std::cout << nwires << " " << nsamples << std::endl; - // now work on the charge part ... // trying to sampling ... m_bd.get_charge_vec(m_vec_vec_charge, m_vec_impact); // std::cout << nwires << " " << nsamples << std::endl; - // for (size_t i=0;i!=m_vec_vec_charge.size();i++){ - // std::cout << m_vec_vec_charge[i].size() << std::endl; - // } - // length and width ... - // - // std::cout << nwires << " " << nsamples << std::endl; std::pair impact_range = m_bd.impact_bin_range(m_bd.get_nsigma()); std::pair time_range = m_bd.time_bin_range(m_bd.get_nsigma()); @@ -120,15 +96,6 @@ Gen::ImpactTransform::ImpactTransform(IPlaneImpactResponse::pointer pir, int npad_wire = 0; const size_t ntotal_wires = fft_best_length(end_ch - start_ch + 2 * m_num_pad_wire, 1); - // pow(2,std::ceil(log(end_ch - start_ch + 2 * m_num_pad_wire)/log(2))); - // if (nwires == 2400){ - // if (ntotal_wires > 2500) - // ntotal_wires = 2500; - // }else if (nwires ==3456){ - // if (ntotal_wires > 3600) - // ntotal_wires = 3600; - // npad_wire=72; //3600 - //} npad_wire = (ntotal_wires - end_ch + start_ch) / 2; m_start_ch = start_ch - npad_wire; m_end_ch = end_ch + npad_wire; @@ -138,29 +105,14 @@ Gen::ImpactTransform::ImpactTransform(IPlaneImpactResponse::pointer pir, int npad_time = m_pir->closest(0)->waveform_pad(); const size_t ntotal_ticks = fft_best_length(end_tick - start_tick + npad_time); - // pow(2,std::ceil(log(end_tick - start_tick + npad_time)/log(2))); - // if (ntotal_ticks >9800 && nsamples <9800 && nsamples >9550) - // ntotal_ticks = 9800; npad_time = ntotal_ticks - end_tick + start_tick; m_start_tick = start_tick; m_end_tick = end_tick + npad_time; - // m_end_tick = 16384;//nsamples; - // m_start_tick = 0; - // // std::cout << m_start_tick << " " << m_end_tick << std::endl; - // int npad_time = 0; - // int ntotal_ticks = pow(2,std::ceil(log(nsamples + npad_time)/log(2))); - // if (ntotal_ticks >9800 && nsamples <9800) - // ntotal_ticks = 9800 - // npad_time = ntotal_ticks - nsamples; - // m_start_tick = 0; - // m_end_tick = ntotal_ticks; - Array::array_xxc acc_data_f_w = Array::array_xxc::Zero(end_ch - start_ch + 2 * npad_wire, m_end_tick - m_start_tick); int num_double = (m_vec_vec_charge.size() - 1) / 2; - // int num_double = (m_vec_spmatrix.size()-1)/2; // speed up version , first five for (int i = 0; i != num_double; i++) { @@ -178,15 +130,6 @@ Gen::ImpactTransform::ImpactTransform(IPlaneImpactResponse::pointer pir, m_vec_vec_charge.at(i).clear(); m_vec_vec_charge.at(i).shrink_to_fit(); - // useing matrix form ... - // for (int k=0; kouterSize(); ++k) - // for (Eigen::SparseMatrix::InnerIterator it(*m_vec_spmatrix.at(i),k); it; ++it){ - // c_data(it.col()+npad_wire-start_ch,it.row()-m_start_tick) = it.value(); - // } - // delete m_vec_spmatrix.at(i); - // //m_vec_spmatrix.at(i).setZero(); - // //m_vec_spmatrix.at(i).resize(0,0); - // fill reverse order int ii = num_double * 2 - i; for (size_t j = 0; j != m_vec_vec_charge.at(ii).size(); j++) { @@ -197,18 +140,10 @@ Gen::ImpactTransform::ImpactTransform(IPlaneImpactResponse::pointer pir, // std::cout << ii << " " << m_vec_vec_charge.at(ii).size() << std::endl; m_vec_vec_charge.at(ii).clear(); m_vec_vec_charge.at(ii).shrink_to_fit(); - // for (int k=0; kouterSize(); ++k) - // for (Eigen::SparseMatrix::InnerIterator it(*m_vec_spmatrix.at(ii),k); it; ++it){ - // c_data(it.col()+npad_wire-start_ch,it.row()-m_start_tick) = it.value(); - // } - // delete m_vec_spmatrix.at(ii); - // // m_vec_spmatrix.at(ii).setZero(); - // //m_vec_spmatrix.at(ii).resize(0,0); // Do FFT on time - c_data = Array::dft_cc(c_data, 0); // Do FFT on wire - c_data = Array::dft_cc(c_data, 1); + c_data = Aux::fwd(m_dft, c_data); // std::cout << i << std::endl; { @@ -217,7 +152,7 @@ Gen::ImpactTransform::ImpactTransform(IPlaneImpactResponse::pointer pir, { Waveform::compseq_t rs1 = m_vec_map_resp.at(i)[0]->spectrum(); // do a inverse FFT - Waveform::realseq_t rs1_t = Waveform::idft(rs1); + Waveform::realseq_t rs1_t = Aux::inv_c2r(m_dft, rs1); // pick the first xxx ticks Waveform::realseq_t rs1_reduced(m_end_tick - m_start_tick, 0); for (int icol = 0; icol != m_end_tick - m_start_tick; icol++) { @@ -225,7 +160,7 @@ Gen::ImpactTransform::ImpactTransform(IPlaneImpactResponse::pointer pir, rs1_reduced.at(icol) = rs1_t[icol]; } // do a FFT - rs1 = Waveform::dft(rs1_reduced); + rs1 = Aux::fwd_r2c(m_dft, rs1_reduced); for (int icol = 0; icol != m_end_tick - m_start_tick; icol++) { resp_f_w(0, icol) = rs1[icol]; @@ -234,21 +169,21 @@ Gen::ImpactTransform::ImpactTransform(IPlaneImpactResponse::pointer pir, for (int irow = 0; irow != m_num_pad_wire; irow++) { Waveform::compseq_t rs1 = m_vec_map_resp.at(i)[irow + 1]->spectrum(); - Waveform::realseq_t rs1_t = Waveform::idft(rs1); + Waveform::realseq_t rs1_t = Aux::inv_c2r(m_dft, rs1); Waveform::realseq_t rs1_reduced(m_end_tick - m_start_tick, 0); for (int icol = 0; icol != m_end_tick - m_start_tick; icol++) { if (icol >= int(rs1_t.size())) break; rs1_reduced.at(icol) = rs1_t[icol]; } - rs1 = Waveform::dft(rs1_reduced); + rs1 = Aux::fwd_r2c(m_dft, rs1_reduced); Waveform::compseq_t rs2 = m_vec_map_resp.at(i)[-irow - 1]->spectrum(); - Waveform::realseq_t rs2_t = Waveform::idft(rs2); + Waveform::realseq_t rs2_t = Aux::inv_c2r(m_dft, rs2); Waveform::realseq_t rs2_reduced(m_end_tick - m_start_tick, 0); for (int icol = 0; icol != m_end_tick - m_start_tick; icol++) { if (icol >= int(rs2_t.size())) break; rs2_reduced.at(icol) = rs2_t[icol]; } - rs2 = Waveform::dft(rs2_reduced); + rs2 = Aux::fwd_r2c(m_dft, rs2_reduced); for (int icol = 0; icol != m_end_tick - m_start_tick; icol++) { resp_f_w(irow + 1, icol) = rs1[icol]; resp_f_w(end_ch - start_ch - 1 - irow + 2 * npad_wire, icol) = rs2[icol]; @@ -257,13 +192,15 @@ Gen::ImpactTransform::ImpactTransform(IPlaneImpactResponse::pointer pir, // std::cout << i << std::endl; // Do FFT on wire for response // slight larger - resp_f_w = Array::dft_cc(resp_f_w, 1); // Now becomes the f and f in both time and wire domain ... + // Now becomes the f and f in both time and wire domain ... + resp_f_w = Aux::fwd(m_dft, resp_f_w, 0); + // multiply them together c_data = c_data * resp_f_w; } // Do inverse FFT on wire - c_data = Array::idft_cc(c_data, 1); + c_data = Aux::inv(m_dft, c_data, 0); // Add to wire result in frequency acc_data_f_w += c_data; @@ -290,18 +227,12 @@ Gen::ImpactTransform::ImpactTransform(IPlaneImpactResponse::pointer pir, // std::cout << i << " " << m_vec_vec_charge.at(i).size() << std::endl; m_vec_vec_charge.at(i).clear(); m_vec_vec_charge.at(i).shrink_to_fit(); - // for (int k=0; kouterSize(); ++k) - // for (Eigen::SparseMatrix::InnerIterator it(*m_vec_spmatrix.at(i),k); it; ++it){ - // data_t_w(it.col()+npad_wire-start_ch,it.row()-m_start_tick) = it.value(); - // } - // delete m_vec_spmatrix.at(i); - // // m_vec_spmatrix.at(i).setZero(); - // // m_vec_spmatrix.at(i).resize(0,0); // Do FFT on time - data_f_w = Array::dft_rc(data_t_w, 0); // Do FFT on wire - data_f_w = Array::dft_cc(data_f_w, 1); + data_f_w = data_t_w.cast(); + data_f_w = Aux::fwd(m_dft, data_f_w); + } { @@ -310,16 +241,9 @@ Gen::ImpactTransform::ImpactTransform(IPlaneImpactResponse::pointer pir, { Waveform::compseq_t rs1 = m_vec_map_resp.at(i)[0]->spectrum(); - // Array::array_xxc temp_resp_f_w = Array::array_xxc::Zero(2*m_num_pad_wire+1,nsamples); - // for (int icol = 0; icol != nsamples; icol++){ - // temp_resp_f_w(0,icol) = rs1[icol]; - // } - // Array::array_xxf temp_resp_t_w = - // Array::idft_cr(temp_resp_f_w,0).block(0,0,2*m_num_pad_wire+1,m_end_tick-m_start_tick); temp_resp_f_w - // = Array::dft_rc(temp_resp_t_w,0); // do a inverse FFT - Waveform::realseq_t rs1_t = Waveform::idft(rs1); + Waveform::realseq_t rs1_t = Aux::inv_c2r(m_dft, rs1); // pick the first xxx ticks Waveform::realseq_t rs1_reduced(m_end_tick - m_start_tick, 0); // std::cout << rs1.size() << " " << nsamples << " " << m_end_tick << " " << m_start_tick << std::endl; @@ -329,7 +253,7 @@ Gen::ImpactTransform::ImpactTransform(IPlaneImpactResponse::pointer pir, // std::cout << icol << " " << rs1_t[icol] << std::endl; } // do a FFT - rs1 = Waveform::dft(rs1_reduced); + rs1 = Aux::fwd_r2c(m_dft, rs1_reduced); for (int icol = 0; icol != m_end_tick - m_start_tick; icol++) { // std::cout << icol << " " << rs1[icol] << " " << temp_resp_f_w(0,icol) << std::endl; @@ -338,131 +262,45 @@ Gen::ImpactTransform::ImpactTransform(IPlaneImpactResponse::pointer pir, } for (int irow = 0; irow != m_num_pad_wire; irow++) { Waveform::compseq_t rs1 = m_vec_map_resp.at(i)[irow + 1]->spectrum(); - Waveform::realseq_t rs1_t = Waveform::idft(rs1); + Waveform::realseq_t rs1_t = Aux::inv_c2r(m_dft, rs1); Waveform::realseq_t rs1_reduced(m_end_tick - m_start_tick, 0); for (int icol = 0; icol != m_end_tick - m_start_tick; icol++) { if (icol >= int(rs1_t.size())) break; rs1_reduced.at(icol) = rs1_t[icol]; } - rs1 = Waveform::dft(rs1_reduced); + rs1 = Aux::fwd_r2c(m_dft, rs1_reduced); Waveform::compseq_t rs2 = m_vec_map_resp.at(i)[-irow - 1]->spectrum(); - Waveform::realseq_t rs2_t = Waveform::idft(rs2); + Waveform::realseq_t rs2_t = Aux::inv_c2r(m_dft, rs2); Waveform::realseq_t rs2_reduced(m_end_tick - m_start_tick, 0); for (int icol = 0; icol != m_end_tick - m_start_tick; icol++) { if (icol >= int(rs2_t.size())) break; rs2_reduced.at(icol) = rs2_t[icol]; } - rs2 = Waveform::dft(rs2_reduced); + rs2 = Aux::fwd_r2c(m_dft, rs2_reduced); for (int icol = 0; icol != m_end_tick - m_start_tick; icol++) { resp_f_w(irow + 1, icol) = rs1[icol]; resp_f_w(end_ch - start_ch - 1 - irow + 2 * npad_wire, icol) = rs2[icol]; } - // for (int icol = 0; icol != nsamples; icol++){ - // resp_f_w(irow+1,icol) = rs1[icol]; - // resp_f_w(end_ch-start_ch-1-irow+2*npad_wire,icol) = rs2[icol]; - // } } // Do FFT on wire for response // slight larger - resp_f_w = Array::dft_cc(resp_f_w, 1); // Now becomes the f and f in both time and wire domain ... + // Now becomes the f and f in both time and wire domain ... + resp_f_w = Aux::fwd(m_dft, resp_f_w, 0); // multiply them together data_f_w = data_f_w * resp_f_w; } // Do inverse FFT on wire - data_f_w = Array::idft_cc(data_f_w, 1); + data_f_w = Aux::inv(m_dft, data_f_w, 0); // Add to wire result in frequency acc_data_f_w += data_f_w; } - // m_decon_data = Array::array_xxc::Zero(nwires,nsamples); - // if (npad_wire!=0){ - acc_data_f_w = Array::idft_cc(acc_data_f_w, 0); //.block(npad_wire,0,nwires,nsamples); + acc_data_f_w = Aux::inv(m_dft, acc_data_f_w, 1); Array::array_xxf real_m_decon_data = acc_data_f_w.real(); Array::array_xxf img_m_decon_data = acc_data_f_w.imag().colwise().reverse(); m_decon_data = real_m_decon_data + img_m_decon_data; - // std::cout << real_m_decon_data(40,5182) << " " << img_m_decon_data(40,5182) << std::endl; - // std::cout << real_m_decon_data(40,5182-m_start_tick) << " " << img_m_decon_data(40,5182-m_start_tick) << - // std::endl; - - //}else{ - // Array::array_xxc temp_m_decon_data = Array::idft_cc(acc_data_f_w,0); - // Array::array_xxf real_m_decon_data = temp_m_decon_data.real(); - // Array::array_xxf img_m_decon_data = temp_m_decon_data.imag().rowwise().reverse(); - // m_decon_data = real_m_decon_data + img_m_decon_data; - // } - - // // prepare FFT, loop 11 of them ... (older version) - // for (size_t i=0;i!=m_vec_vec_charge.size();i++){ - // // fill response array in frequency domain - // if (i!=10) continue; - - // Array::array_xxc data_f_w; - // { - // Array::array_xxf data_t_w = Array::array_xxf::Zero(nwires+2*npad_wire,nsamples); - // // fill charge array in time-wire domain // slightly larger - // for (size_t j=0;j!=m_vec_vec_charge.at(i).size();j++){ - // data_t_w(std::get<0>(m_vec_vec_charge.at(i).at(j))+npad_wire,std::get<1>(m_vec_vec_charge.at(i).at(j))) += - // std::get<2>(m_vec_vec_charge.at(i).at(j)); - // } - // m_vec_vec_charge.at(i).clear(); - - // // Do FFT on time - // data_f_w = Array::dft_rc(data_t_w,0); - // // Do FFT on wire - // data_f_w = Array::dft_cc(data_f_w,1); - // } - - // { - // Array::array_xxc resp_f_w = Array::array_xxc::Zero(nwires+2*npad_wire,nsamples); - // { - // Waveform::compseq_t rs1 = m_vec_map_resp.at(i)[0]->spectrum(); - // for (int icol = 0; icol != nsamples; icol++){ - // resp_f_w(0,icol) = rs1[icol]; - // } - // } - // for (int irow = 0; irow!=m_num_pad_wire;irow++){ - // Waveform::compseq_t rs1 = m_vec_map_resp.at(i)[irow+1]->spectrum(); - // Waveform::compseq_t rs2 = m_vec_map_resp.at(i)[-irow-1]->spectrum(); - // for (int icol = 0; icol != nsamples; icol++){ - // resp_f_w(irow+1,icol) = rs1[icol]; - // resp_f_w(nwires-1-irow+2*npad_wire,icol) = rs2[icol]; - // } - // } - // // Do FFT on wire for response // slight larger - // resp_f_w = Array::dft_cc(resp_f_w,1); // Now becomes the f and f in both time and wire domain ... - // // multiply them together - // data_f_w = data_f_w * resp_f_w; - // } - - // // Do inverse FFT on wire - // data_f_w = Array::idft_cc(data_f_w,1); - - // // Add to wire result in frequency - // acc_data_f_w += data_f_w; - // } - // m_vec_vec_charge.clear(); - - // // do inverse FFT on time for the final results ... - - // if (npad_wire!=0){ - // Array::array_xxf temp_m_decon_data = Array::idft_cr(acc_data_f_w,0); - // m_decon_data = temp_m_decon_data.block(npad_wire,0,nwires,nsamples); - // }else{ - // m_decon_data = Array::idft_cr(acc_data_f_w,0); - // } - - // std::cout << m_decon_data(40,5195-m_start_tick)/units::mV << " " << - // m_decon_data(40,5195-m_start_tick)/units::mV << std::endl; - - // m_vec_spmatrix.clear(); - // m_vec_spmatrix.shrink_to_fit(); - - // int nrows = resp_f_w.rows(); - // int ncols = resp_f_w.cols(); - // log->debug("ImpactTransform: # of channels: {} # of ticks: {}", m_decon_data.rows(), m_decon_data.cols()); - } // constructor Gen::ImpactTransform::~ImpactTransform() {} @@ -479,9 +317,6 @@ Waveform::realseq_t Gen::ImpactTransform::waveform(int iwire) const if (i >= m_start_tick && i < m_end_tick) { wf.at(i) = m_decon_data(iwire - m_start_ch, i - m_start_tick); } - else { - // wf.at(i) = 1e-25; - } // std::cout << m_decon_data(iwire-m_start_ch,i-m_start_tick) << std::endl; } @@ -489,19 +324,17 @@ Waveform::realseq_t Gen::ImpactTransform::waveform(int iwire) const // now convolute with the long-range response ... const size_t nlength = fft_best_length(nsamples + m_pir->closest(0)->long_aux_waveform_pad()); - // nlength = nsamples; - // std::cout << nlength << " " << nsamples + m_pir->closest(0)->long_aux_waveform_pad() << std::endl; wf.resize(nlength, 0); Waveform::realseq_t long_resp = m_pir->closest(0)->long_aux_waveform(); long_resp.resize(nlength, 0); - Waveform::compseq_t spec = Waveform::dft(wf); - Waveform::compseq_t long_spec = Waveform::dft(long_resp); + Waveform::compseq_t spec = Aux::fwd_r2c(m_dft, wf); + Waveform::compseq_t long_spec = Aux::fwd_r2c(m_dft, long_resp); for (size_t i = 0; i != nlength; i++) { spec.at(i) *= long_spec.at(i); } - wf = Waveform::idft(spec); + wf = Aux::inv_c2r(m_dft, spec); wf.resize(nsamples, 0); } diff --git a/gen/src/ImpactZipper.cxx b/gen/src/ImpactZipper.cxx deleted file mode 100644 index 193d0bed1..000000000 --- a/gen/src/ImpactZipper.cxx +++ /dev/null @@ -1,127 +0,0 @@ -#include "WireCellGen/ImpactZipper.h" -#include "WireCellUtil/Testing.h" - -#include // debugging. -using namespace std; - -using namespace WireCell; -Gen::ImpactZipper::ImpactZipper(IPlaneImpactResponse::pointer pir, BinnedDiffusion& bd) - : m_pir(pir) - , m_bd(bd) -{ -} - -Gen::ImpactZipper::~ImpactZipper() {} - -Waveform::realseq_t Gen::ImpactZipper::waveform(int iwire) const -{ - const double pitch_range = m_pir->pitch_range(); - - const auto pimpos = m_bd.pimpos(); - const auto rb = pimpos.region_binning(); - const auto ib = pimpos.impact_binning(); - const double wire_pos = rb.center(iwire); - - const int min_impact = ib.edge_index(wire_pos - 0.5 * pitch_range); - const int max_impact = ib.edge_index(wire_pos + 0.5 * pitch_range); - const int nsamples = m_bd.tbins().nbins(); - Waveform::compseq_t total_spectrum(nsamples, Waveform::complex_t(0.0, 0.0)); - - int nfound = 0; - const bool share = true; - // const Waveform::complex_t complex_one_half(0.5,0.0); - - // The BinnedDiffusion is indexed by absolute impact and the - // PlaneImpactResponse relative impact. - for (int imp = min_impact; imp <= max_impact; ++imp) { - // ImpactData - auto id = m_bd.impact_data(imp); - if (!id) { - // common as we are scanning all impacts covering a wire - // fixme: is there a way to predict this to avoid the query? - // std::cerr << "ImpactZipper: no data for absolute impact number: " << imp << std::endl; - continue; - } - - const Waveform::compseq_t& charge_spectrum = id->spectrum(); - // for interpolation - const Waveform::compseq_t& weightcharge_spectrum = id->weight_spectrum(); - - if (charge_spectrum.empty()) { - // should not happen - std::cerr << "ImpactZipper: no charge for absolute impact number: " << imp << std::endl; - continue; - } - if (weightcharge_spectrum.empty()) { - // weight == 0, should not happen - std::cerr << "ImpactZipper: no weight charge for absolute impact number: " << imp << std::endl; - continue; - } - - const double imp_pos = ib.center(imp); - const double rel_imp_pos = imp_pos - wire_pos; - // std::cerr << "IZ: " << " imp=" << imp << " imp_pos=" << imp_pos << " rel_imp_pos=" << rel_imp_pos << - // std::endl; - - Waveform::compseq_t conv_spectrum(nsamples, Waveform::complex_t(0.0, 0.0)); - if (share) { // fixme: make a configurable option - TwoImpactResponses two_ir = m_pir->bounded(rel_imp_pos); - if (!two_ir.first || !two_ir.second) { - // std::cerr << "ImpactZipper: no impact response for absolute impact number: " << imp << std::endl; - continue; - } - // fixme: this is average, not interpolation. - Waveform::compseq_t rs1 = two_ir.first->spectrum(); - Waveform::compseq_t rs2 = two_ir.second->spectrum(); - - for (int ind = 0; ind < nsamples; ++ind) { - // conv_spectrum[ind] = complex_one_half*(rs1[ind]+rs2[ind])*charge_spectrum[ind]; - - // linear interpolation: wQ*rs1 + (Q-wQ)*rs2 - conv_spectrum[ind] = weightcharge_spectrum[ind] * rs1[ind] + - (charge_spectrum[ind] - weightcharge_spectrum[ind]) * rs2[ind]; - /* debugging */ - /* if(iwire == 1000 && ind>1000 && ind<2000) { */ - /* std::cerr<<"rs1 spectrum: "<closest(rel_imp_pos); - if (!ir) { - // std::cerr << "ImpactZipper: no impact response for absolute impact number: " << imp << std::endl; - continue; - } - Waveform::compseq_t response_spectrum = ir->spectrum(); - for (int ind = 0; ind < nsamples; ++ind) { - conv_spectrum[ind] = response_spectrum[ind] * charge_spectrum[ind]; - } - } - - ++nfound; - // std::cerr << "ImpactZipper: found:"<(cfg, "dft", "FftwDFT"); + m_dft = Factory::find_tn(dft_tn); } bool Gen::Misconfigure::operator()(const input_pointer& in, output_pointer& out) @@ -72,9 +82,12 @@ bool Gen::Misconfigure::operator()(const input_pointer& in, output_pointer& out) size_t ntraces = traces->size(); ITrace::vector out_traces(ntraces); for (size_t ind = 0; ind < ntraces; ++ind) { - auto trace = traces->at(ind); + const auto& trace = traces->at(ind); - auto wave = Waveform::replace_convolve(trace->charge(), m_to, m_from, m_truncate); + // auto wave = Waveform::replace_convolve(trace->charge(), m_to, m_from, m_truncate); + const auto& charge = trace->charge(); + auto wave = Aux::replace(m_dft, charge, m_to, m_from); + wave.resize(charge.size()); out_traces[ind] = std::make_shared(trace->channel(), trace->tbin(), wave); } diff --git a/gen/src/Noise.cxx b/gen/src/Noise.cxx index 8fbb86ad0..f31cdd3eb 100644 --- a/gen/src/Noise.cxx +++ b/gen/src/Noise.cxx @@ -4,7 +4,7 @@ using namespace WireCell; -Waveform::realseq_t Gen::Noise::generate_waveform(const std::vector& spec, IRandom::pointer rng, double replace) +Waveform::compseq_t Gen::Noise::generate_spectrum(const std::vector& spec, IRandom::pointer rng, double replace) { // reuse randomes a bit to optimize speed. static std::vector random_real_part; @@ -49,6 +49,5 @@ Waveform::realseq_t Gen::Noise::generate_waveform(const std::vector& spec noise_freq.at(i + int(spec.size()) - shift).imag(random_imag_part.at(i) * amplitude); } - Waveform::realseq_t noise_time = WireCell::Waveform::idft(noise_freq); - return noise_time; + return noise_freq; } diff --git a/gen/src/Noise.h b/gen/src/Noise.h index a671dad40..4f8e21d4d 100644 --- a/gen/src/Noise.h +++ b/gen/src/Noise.h @@ -1,18 +1,16 @@ // This is some "private" code shared by a couple of components in gen. -// -// fixme: this is a candidate for turning into an interface. #include "WireCellIface/IRandom.h" #include "WireCellUtil/Waveform.h" #include -namespace WireCell { - namespace Gen { - namespace Noise { - // Generate a time series waveform given a spectral amplitude - WireCell::Waveform::realseq_t generate_waveform(const std::vector& spec, IRandom::pointer rng, - double replace = 0.02); - } // namespace Noise - } // namespace Gen -} // namespace WireCell +namespace WireCell::Gen::Noise { + // Generate a time series waveform given a spectral amplitude + // WireCell::Waveform::realseq_t generate_waveform(const std::vector& spec, IRandom::pointer rng, + // double replace = 0.02); + + // Generate specific noise spectrum. Caller likely wants to Aux::inv() it and take Waveform::real(). + WireCell::Waveform::compseq_t generate_spectrum(const std::vector& spec, IRandom::pointer rng, + double replace = 0.02); +} diff --git a/gen/src/NoiseSource.cxx b/gen/src/NoiseSource.cxx index 28e511044..76f6974d5 100644 --- a/gen/src/NoiseSource.cxx +++ b/gen/src/NoiseSource.cxx @@ -1,5 +1,7 @@ #include "WireCellGen/NoiseSource.h" +#include "WireCellAux/DftTools.h" + #include "WireCellIface/SimpleTrace.h" #include "WireCellIface/SimpleFrame.h" @@ -48,6 +50,7 @@ WireCell::Configuration Gen::NoiseSource::default_configuration() const cfg["anode"] = m_anode_tn; cfg["model"] = m_model_tn; cfg["rng"] = m_rng_tn; + cfg["dft"] = "FftwDFT"; // type-name for the DFT to use cfg["nsamples"] = m_nsamples; cfg["replacement_percentage"] = m_rep_percent; return cfg; @@ -60,6 +63,8 @@ void Gen::NoiseSource::configure(const WireCell::Configuration& cfg) if (!m_rng) { THROW(KeyError() << errmsg{"failed to get IRandom: " + m_rng_tn}); } + std::string dft_tn = get(cfg, "dft", "FftwDFT"); + m_dft = Factory::find_tn(dft_tn); m_anode_tn = get(cfg, "anode", m_anode_tn); m_anode = Factory::find_tn(m_anode_tn); @@ -104,7 +109,10 @@ bool Gen::NoiseSource::operator()(IFrame::pointer& frame) for (auto chid : m_anode->channels()) { const auto& spec = (*m_model)(chid); - Waveform::realseq_t noise = Gen::Noise::generate_waveform(spec, m_rng, m_rep_percent); + //Waveform::realseq_t noise = Gen::Noise::generate_waveform(spec, m_rng, m_rep_percent); + auto cnoise = Gen::Noise::generate_spectrum(spec, m_rng, m_rep_percent); + auto noise = Aux::inv_c2r(m_dft, cnoise); + // std::cout << noise.size() << " " << nsamples << std::endl; noise.resize(m_nsamples, 0); auto trace = make_shared(chid, tbin, noise); diff --git a/gen/src/PerChannelVariation.cxx b/gen/src/PerChannelVariation.cxx index 2af90496b..af702b752 100644 --- a/gen/src/PerChannelVariation.cxx +++ b/gen/src/PerChannelVariation.cxx @@ -1,9 +1,13 @@ #include "WireCellGen/PerChannelVariation.h" + +#include "WireCellAux/DftTools.h" + +#include "WireCellIface/SimpleFrame.h" +#include "WireCellIface/SimpleTrace.h" + #include "WireCellUtil/NamedFactory.h" #include "WireCellUtil/Response.h" #include "WireCellUtil/Waveform.h" -#include "WireCellIface/SimpleFrame.h" -#include "WireCellIface/SimpleTrace.h" #include @@ -41,11 +45,15 @@ WireCell::Configuration Gen::PerChannelVariation::default_configuration() const /// ch-by-ch electronics responses by calibration cfg["per_chan_resp"] = ""; + cfg["dft"] = "FftwDFT"; // type-name for the DFT to use return cfg; } void Gen::PerChannelVariation::configure(const WireCell::Configuration& cfg) { + std::string dft_tn = get(cfg, "dft", "FftwDFT"); + m_dft = Factory::find_tn(dft_tn); + m_per_chan_resp = get(cfg, "per_chan_resp", ""); if (!m_per_chan_resp.empty()) { @@ -64,6 +72,7 @@ void Gen::PerChannelVariation::configure(const WireCell::Configuration& cfg) m_truncate = cfg["truncate"].asBool(); } + bool Gen::PerChannelVariation::operator()(const input_pointer& in, output_pointer& out) { if (!in) { @@ -86,11 +95,14 @@ bool Gen::PerChannelVariation::operator()(const input_pointer& in, output_pointe size_t ntraces = traces->size(); ITrace::vector out_traces(ntraces); for (size_t ind = 0; ind < ntraces; ++ind) { - auto trace = traces->at(ind); + const auto& trace = traces->at(ind); auto chid = trace->channel(); Waveform::realseq_t tch_resp = m_cr->channel_response(chid); - tch_resp.resize(m_nsamples, 0); - auto wave = Waveform::replace_convolve(trace->charge(), tch_resp, m_from, m_truncate); + // tch_resp.resize(m_nsamples, 0); + // auto wave = Waveform::replace_convolve(trace->charge(), tch_resp, m_from, m_truncate); + const auto& charge = trace->charge(); + auto wave = Aux::replace(m_dft, charge, tch_resp, m_from); + wave.resize(charge.size()); out_traces[ind] = std::make_shared(chid, trace->tbin(), wave); } diff --git a/gen/src/PlaneImpactResponse.cxx b/gen/src/PlaneImpactResponse.cxx index c4ce24bf5..3f15c4349 100644 --- a/gen/src/PlaneImpactResponse.cxx +++ b/gen/src/PlaneImpactResponse.cxx @@ -1,10 +1,16 @@ +#include "WireCellGen/PlaneImpactResponse.h" + +#include "WireCellAux/DftTools.h" + #include "WireCellIface/IFieldResponse.h" #include "WireCellIface/IWaveform.h" -#include "WireCellGen/PlaneImpactResponse.h" +#include "WireCellIface/IDFT.h" + #include "WireCellUtil/Testing.h" #include "WireCellUtil/NamedFactory.h" #include "WireCellUtil/FFTBestLength.h" + WIRECELL_FACTORY(PlaneImpactResponse, WireCell::Gen::PlaneImpactResponse, WireCell::INamed, @@ -41,6 +47,7 @@ WireCell::Configuration Gen::PlaneImpactResponse::default_configuration() const cfg["nticks"] = 10000; // sample period of response waveforms cfg["tick"] = 0.5 * units::us; + cfg["dft"] = m_dftname; // type-name for the DFT to use return cfg; } @@ -73,11 +80,14 @@ void Gen::PlaneImpactResponse::configure(const WireCell::Configuration& cfg) m_nbins = (size_t) get(cfg, "nticks", (int) m_nbins); m_tick = get(cfg, "tick", m_tick); + m_dftname = get(cfg, "dft", m_dftname); build_responses(); } void Gen::PlaneImpactResponse::build_responses() { + auto dft = Factory::find_tn(m_dftname); + auto ifr = Factory::find_tn(m_frname); const size_t n_short_length = fft_best_length(m_overall_short_padding / m_tick); @@ -101,7 +111,7 @@ void Gen::PlaneImpactResponse::build_responses() } // note: we are ignoring waveform_start which will introduce // an arbitrary phase shift.... - auto spec = Waveform::dft(wave); + auto spec = Aux::fwd_r2c(dft, wave); for (size_t ibin = 0; ibin < n_short_length; ++ibin) { short_spec[ibin] *= spec[ibin]; } @@ -127,14 +137,15 @@ void Gen::PlaneImpactResponse::build_responses() } // note: we are ignoring waveform_start which will introduce // an arbitrary phase shift.... - auto spec = Waveform::dft(wave); + auto spec = Aux::fwd_r2c(dft, wave); for (size_t ibin = 0; ibin < n_long_length; ++ibin) { long_spec[ibin] *= spec[ibin]; } } WireCell::Waveform::realseq_t long_wf; - if (nlong > 0) long_wf = Waveform::idft(long_spec); - + if (nlong > 0) { + long_wf = Aux::inv_c2r(dft, long_spec); + } const auto& fr = ifr->field_response(); const auto& pr = *fr.plane(m_plane_ident); const int npaths = pr.paths.size(); @@ -219,7 +230,7 @@ void Gen::PlaneImpactResponse::build_responses() // sum up over coarse ticks. wave[bin] += induced_charge; } - WireCell::Waveform::compseq_t spec = Waveform::dft(wave); + WireCell::Waveform::compseq_t spec = Aux::fwd_r2c(dft, wave); // Convolve with short responses if (nshort) { @@ -227,12 +238,15 @@ void Gen::PlaneImpactResponse::build_responses() spec[find] *= short_spec[find]; } } - Waveform::realseq_t wf = Waveform::idft(spec); + Waveform::realseq_t wf = Aux::inv_c2r(dft, spec); + wf.resize(m_nbins, 0); + spec = Aux::fwd_r2c(dft, wf); IImpactResponse::pointer ir = std::make_shared( - ipath, wf, m_overall_short_padding / m_tick, + ipath, + spec, wf, m_overall_short_padding / m_tick, long_wf, m_long_padding / m_tick); m_ir.push_back(ir); } diff --git a/gen/src/TruthSmearer.cxx b/gen/src/TruthSmearer.cxx index eb70632eb..9bc19459c 100644 --- a/gen/src/TruthSmearer.cxx +++ b/gen/src/TruthSmearer.cxx @@ -1,6 +1,6 @@ #include "WireCellGen/TruthSmearer.h" #include "WireCellGen/BinnedDiffusion.h" -#include "WireCellGen/ImpactZipper.h" + #include "WireCellUtil/Units.h" #include "WireCellUtil/Point.h" #include "WireCellUtil/NamedFactory.h" @@ -97,6 +97,9 @@ WireCell::Configuration Gen::TruthSmearer::default_configuration() const put(cfg, "anode", m_anode_tn); put(cfg, "rng", m_rng_tn); + // Name for the DFTer + cfg["dft"] = "FftwDFT"; + return cfg; } @@ -118,6 +121,9 @@ void Gen::TruthSmearer::configure(const WireCell::Configuration& cfg) m_rng = Factory::find_tn(m_rng_tn); } + std::string dft_tn = get(cfg, "dft", "FftwDFT"); + m_dft = Factory::find_tn(dft_tn); + m_readout_time = get(cfg, "readout_time", m_readout_time); m_tick = get(cfg, "tick", m_tick); m_start_time = get(cfg, "start_time", m_start_time); @@ -165,7 +171,7 @@ void Gen::TruthSmearer::process(output_queue& frames) tick = tbins.binsize(); } - Gen::BinnedDiffusion bindiff(*pimpos, tbins, m_nsigma, m_rng); + Gen::BinnedDiffusion bindiff(*pimpos, m_dft, tbins, m_nsigma, m_rng); for (auto depo : face_depos) { // time filter smearing double extent_time = depo->extent_long() / m_drift_speed; @@ -193,9 +199,6 @@ void Gen::TruthSmearer::process(output_queue& frames) const double impact = ib.binsize(); const int nwires = rb.nbins(); for (int iwire = 0; iwire < nwires; ++iwire) { - /// Similar to ImpactZipper::waveform - /// No convolution - /// m_waveform from BinnedDiffusion::impact_data() const double wire_pos = rb.center(iwire); @@ -239,7 +242,7 @@ void Gen::TruthSmearer::process(output_queue& frames) Waveform::realseq_t charge_spectrum = id->waveform(); if (charge_spectrum.empty()) { - std::cerr << "impactZipper: no charge spectrum for absolute impact number: " << imp << endl; + std::cerr << "TruthSmearer: no charge spectrum for absolute impact number: " << imp << endl; continue; } diff --git a/gen/src/TruthTraceID.cxx b/gen/src/TruthTraceID.cxx index df60ecba1..209c6c7f9 100644 --- a/gen/src/TruthTraceID.cxx +++ b/gen/src/TruthTraceID.cxx @@ -1,10 +1,14 @@ #include "WireCellGen/TruthTraceID.h" #include "WireCellGen/BinnedDiffusion.h" + +#include "WireCellAux/DftTools.h" + +#include "WireCellIface/SimpleTrace.h" +#include "WireCellIface/SimpleFrame.h" + #include "WireCellUtil/Units.h" #include "WireCellUtil/Point.h" #include "WireCellUtil/NamedFactory.h" -#include "WireCellIface/SimpleTrace.h" -#include "WireCellIface/SimpleFrame.h" #include @@ -54,6 +58,7 @@ WireCell::Configuration Gen::TruthTraceID::default_configuration() const put(cfg, "first_frame_number", m_frame_count); put(cfg, "anode", m_anode_tn); put(cfg, "rng", m_rng_tn); + put(cfg, "dft", "FftwDFT"); // type-name for the DFT to use put(cfg, "truth_type", m_truth_type); put(cfg, "number_induction_wire", m_num_ind_wire); put(cfg, "number_collection_wire", m_num_col_wire); @@ -86,6 +91,8 @@ void Gen::TruthTraceID::configure(const WireCell::Configuration& cfg) m_rng_tn = get(cfg, "rng", m_rng_tn); m_rng = Factory::find_tn(m_rng_tn); } + std::string dft_tn = get(cfg, "dft", "FftwDFT"); + m_dft = Factory::find_tn(dft_tn); m_readout_time = get(cfg, "readout_time", m_readout_time); m_tick = get(cfg, "tick", m_tick); @@ -139,7 +146,7 @@ void Gen::TruthTraceID::process(output_queue& frames) auto timeTruth = hf_time.generate(timeBins); // ### apply diffusion at wire plane ### - Gen::BinnedDiffusion bindiff(*pimpos, tbins, m_nsigma, m_rng); + Gen::BinnedDiffusion bindiff(*pimpos, m_dft, tbins, m_nsigma, m_rng); for (auto depo : m_depos) { bindiff.add(depo, depo->extent_long() / m_drift_speed, depo->extent_tran()); @@ -193,8 +200,7 @@ void Gen::TruthTraceID::process(output_queue& frames) } bindiff.erase(0, min_impact); - Waveform::realseq_t wave(nsamples, 0.0); - wave = Waveform::idft(total_spectrum); + Waveform::realseq_t wave = Aux::inv_c2r(m_dft, total_spectrum); auto mm = Waveform::edge(wave); if (mm.first == (int) wave.size()) { continue; diff --git a/gen/test/test_empnomo.cxx b/gen/test/test_empnomo.cxx index 2e7269761..0565647a6 100644 --- a/gen/test/test_empnomo.cxx +++ b/gen/test/test_empnomo.cxx @@ -5,6 +5,7 @@ #include "WireCellUtil/NamedFactory.h" #include "WireCellIface/IChannelStatus.h" #include "WireCellIface/IChannelSpectrum.h" +#include "WireCellIface/IDFT.h" #include #include @@ -16,6 +17,9 @@ using namespace WireCell; int main(int argc, char* argv[]) { + PluginManager& pm = PluginManager::instance(); + pm.add("WireCellAux"); + std::string detector = "uboone"; // In the real WCT this is done by wire-cell and driven by user @@ -25,6 +29,9 @@ int main(int argc, char* argv[]) cerr << "Using AnodePlane: \"" << anode_tns[0] << "\"\n"; { + { + Factory::lookup_tn("FftwDFT"); + } { auto icfg = Factory::lookup("StaticChannelStatus"); // In the real app this would be in a JSON or Jsonnet config diff --git a/gen/test/test_pir.cxx b/gen/test/test_pir.cxx index 0846c6a70..016aba3ac 100644 --- a/gen/test/test_pir.cxx +++ b/gen/test/test_pir.cxx @@ -1,7 +1,10 @@ +#include "WireCellGen/PlaneImpactResponse.h" + +#include "WireCellIface/IDFT.h" + #include "WireCellUtil/PluginManager.h" #include "WireCellUtil/Testing.h" #include "WireCellUtil/NamedFactory.h" -#include "WireCellGen/PlaneImpactResponse.h" #include "WireCellUtil/Units.h" #include "WireCellUtil/Logging.h" @@ -16,6 +19,7 @@ int main(int argc, char* argv[]) { Log::set_level("debug"); PluginManager& pm = PluginManager::instance(); + pm.add("WireCellAux"); pm.add("WireCellGen"); pm.add("WireCellSigProc"); @@ -24,6 +28,9 @@ int main(int argc, char* argv[]) { response_file = argv[1]; }; + { + Factory::lookup_tn("FftwDFT"); + } { auto icfg = Factory::lookup("FieldResponse"); auto cfg = icfg->default_configuration(); diff --git a/iface/README.org b/iface/README.org index 37e27b412..cf76f4926 100644 --- a/iface/README.org +++ b/iface/README.org @@ -25,3 +25,62 @@ the overall WCT dependency tree. Discussion is warranted in these cases. See the user manual for more info. https://wirecell.bnl.gov/ +* Interfaces + +** IDFT + +The ~IDFT~ class provides interface to methods to perform discrete +Fourier transforms on arrays of complex single precision floating +point values. + +The interface defines a number of methods which take a general naming +convention like: +#+begin_example +void (...); +#+end_example + +The "direction" of the transform is one of + +- fwd :: the DFT is from interval to frequency, no normalization. +- inv :: the DFT is from frequency to interval, 1/n normalization. + +The "domain" determines the dimension of array and how it is transformed + +- 1d :: a 1D array is transformed +- 1b :: a batch of equal-length 1D arrays are transformed +- 2d :: a 2D array is transformed (along both dimensions) + +The shape of 2D arrays (~1b~ or ~2d~ methods) are given in terms of two +numbers: ~stride~ and ~nstrides~. The number ~stride~ counts the number of +contiguous array elements along one dimension and ~nstrides~ counts the +number non-contiguous elements logically along the opposite dimension. +In the case of "row-major" aka "C" memory ordering of 2D arrays, the +number ~stride~ counts the number of elements in one "row" and ~nstrides~ +counts the number of rows, aka, the number of elements in one column. + +The ~1b~ transforms operate along a contiguous array of length ~stride~. +By default, these transforms are implemented in terms of ~nstrides~ +calls to the ~1d~ DFT interface method. The implementation may override +the ~1b~ default methods for example to exploit some kind of "batch +optimization". + +*** Limitations + +- The potential speed up when the input to a forward or output from + reverse is real valued is not possible to implement with ~IDFT~. It + requires the caller to take particular care in array sizes and would + double the number of methods. + +- To satisfy the low-level pointer to memory interface from higher + level objects see the ~Waveform.h~ and ~Array.h~ headers in + ~WireCellUtil~. In particular, see functions there to lift real to + complex or perform memory transforms. + +- Interface to higher order transforms, such as convolutions, are not + provided. See ~Aux::DFT~ for implementations in terms of an ~IDFT~. + +** ... + +Any interfaces not listed above, please see their header file in +[[file:inc/WireCellIface/][inc/WireCellIface/]] for more information. + diff --git a/iface/inc/WireCellIface/IDFT.h b/iface/inc/WireCellIface/IDFT.h new file mode 100644 index 000000000..e94afb391 --- /dev/null +++ b/iface/inc/WireCellIface/IDFT.h @@ -0,0 +1,131 @@ +#ifndef WIRECELL_IDFT +#define WIRECELL_IDFT + +#include "WireCellUtil/IComponent.h" +#include + +namespace WireCell { + + /** + Interface to perform discrete Fourier transforms on arrays of + signal precision, complex floating point values. + + There are 6 DFT methods which are formed as the outer product + of two lists: + + - fwd, inv + - 1d, 1b, 2d + + The "fwd" methods provide forward transform, no normalization. + The "inv" methods provide reverse/inverse transform normalized + by 1/size. + + The 1d transforms take rank=1 / 1D arrays and perform a single + transform. + + The 2d transforms take rank=2 / 2D arrays and perform nrows of + transforms along rows and ncols of transforms along columns. + The order over which each dimension is transformed is + implementation-defined (and imaterial). + + The 1b transforms take rank=1 / 2D arrays and perform + transforms along a single dimension as determined by the value + of the "axis" parameter. An axis=1 means to perform nrows + transforms along rows. Note, this is the same convention + followed by numpy.fft functions. + + There is also a special rank=0 DFT on rank=2 arrays which is + more commonly known as a "matrix transpose". + + Requirements on implementations: + + - Forward transforms SHALL NOT apply normalization. + + - Reverse transforms SHALL apply 1/n normalization. + + - The arrays SHALL be assumed to follow C-ordering aka + row-major storage order. + + - Transform methods SHALL allow the input and output array + pointers to be identical. + + - The IDFT interface provides 1b methods implemented in terms + of 1d calls and a implementation MAY override these (for + example, if implementation can exploit batch optimization). + + - Implementation SHALL allow safe concurrent calls to methods + by different threads of execution. + + Requirement on callers. + + - Input and output arrays SHALL be pre-allocated and be sized + at least as large as indicated by accompanying size arguments. + + - Input and output arrays MUST either be non-overlapping in + memory or MUST be identical. + + Notes: + + - All arrays are of type single precision complex floating + point. Functions and methods to easily convert between the + two exist. + + - Eigen arrays are column-wise by default and so their + arr.data() method can not directly supply input to this + interface. Likewise, use of arr.transpose().data() may run + afowl of Eigen's IsRowMajor optimization flag. Copy your + default array in a Eigen::RowMajor array first or use IDFT + via Aux::DftTools functions. + */ + class IDFT : public IComponent { + public: + virtual ~IDFT(); + + /// The type for the signal in each bin. + using scalar_t = float; + + /// The type for the spectrum in each bin. + using complex_t = std::complex; + + // 1d + + virtual + void fwd1d(const complex_t* in, complex_t* out, int size) const = 0; + + virtual + void inv1d(const complex_t* in, complex_t* out, int size) const = 0; + + // 1b + + virtual + void fwd1b(const complex_t* in, complex_t* out, + int nrows, int ncols, int axis) const; + + virtual + void inv1b(const complex_t* in, complex_t* out, + int nrows, int ncols, int axis) const; + + // 2d + + virtual + void fwd2d(const complex_t* in, complex_t* out, + int nrows, int ncols) const = 0; + virtual + void inv2d(const complex_t* in, complex_t* out, + int nrows, int ncols) const = 0; + + + // Fill "out" with the transpose of "in", may be in-place. + // The nrows/ncols refers to the shape of the input. + virtual + void transpose(const scalar_t* in, scalar_t* out, + int nrows, int ncols) const; + virtual + void transpose(const complex_t* in, complex_t* out, + int nrows, int ncols) const; + + }; +} + + +#endif diff --git a/iface/inc/WireCellIface/ISemaphore.h b/iface/inc/WireCellIface/ISemaphore.h new file mode 100644 index 000000000..99a55396d --- /dev/null +++ b/iface/inc/WireCellIface/ISemaphore.h @@ -0,0 +1,31 @@ +/** An interface to the semaphore pattern */ + +#ifndef WIRECELL_ISEMAPHORE +#define WIRECELL_ISEMAPHORE + +#include "WireCellUtil/IComponent.h" + +namespace WireCell { + class ISemaphore : public IComponent { + public: + virtual ~ISemaphore(); + + /// Block until available spot to hold the semaphore is + /// available. + virtual void acquire() const = 0; + + /// Release hold on the semaphore + virtual void release() const = 0; + + /// Use Construct a Context on a semaphore in a local scope to + /// automate release + struct Context { + ISemaphore::pointer sem; + Context(ISemaphore::pointer sem) : sem(sem) { sem->acquire(); } + ~Context() { sem->release(); } + }; + + }; +} // namespace WireCell + +#endif // WIRECELL_ITENSORFORWARD diff --git a/iface/inc/WireCellIface/ITensor.h b/iface/inc/WireCellIface/ITensor.h index 5154b68fd..d1cbf6adb 100644 --- a/iface/inc/WireCellIface/ITensor.h +++ b/iface/inc/WireCellIface/ITensor.h @@ -19,7 +19,22 @@ namespace WireCell { public: /// Shape gives size of each dimension. Size of shape give Ndim. typedef std::vector shape_t; - /// Storage order. Empty implies C order. + + /// Storage order. Empty implies C order. If non-empty the + /// vector holds the "majority" of the dimension. C-order + /// implies a vector of {1,0} which means if the array is + /// accessed as array[a][b] "axis" 0 (indexed by "a") is the + /// "major index" and "axis" 1 (indexed by "b") is the "minor + /// index". It is thus "row-major" ordering as the major + /// index counts rows. An array in fortran-order + /// (column-major order) would be given as {0,1}. + /// + /// A note as this can be confusing: The "logical" rows and + /// columns, eg when used in an Eigen array are independent + /// from memory order. An Eigen array is always indexed as + /// arr(r,c). Storage order only matters when, well, you + /// access the array storage such as from Eigen array's + /// .data() method - and indeed ITensor::data(). typedef std::vector order_t; /// The type of the element. diff --git a/iface/src/IDFT.cxx b/iface/src/IDFT.cxx new file mode 100644 index 000000000..abc3d6120 --- /dev/null +++ b/iface/src/IDFT.cxx @@ -0,0 +1,95 @@ +#include "WireCellIface/IDFT.h" + +#include +#include // std::swap since c++11 + +using namespace WireCell; + +IDFT::~IDFT() {} + +// Trivial default "batched" implementations. If your concrete +// implementation provides some kind of "batch optimization", such as +// with FFTW3's advanced interface or with some GPU FFT library, +// override these dumb methods for the win. + +void IDFT::fwd1b(const complex_t* in, complex_t* out, + int nrows, int ncols, int axis) const +{ + if (axis) { + for (int irow=0; irowtranspose(in, out, nrows, ncols); + this->fwd1b(out, out, ncols, nrows, 1); + this->transpose(out, out, ncols, nrows); + } +} + +void IDFT::inv1b(const complex_t* in, complex_t* out, + int nrows, int ncols, int axis) const +{ + if (axis) { + for (int irow=0; irowtranspose(in, out, nrows, ncols); + this->inv1b(out, out, ncols, nrows, 1); + this->transpose(out, out, ncols, nrows); + } +} + +// Trivial default transpose. Implementations, please override if you +// can offer something faster. + +template +void transpose_type(const ValueType* in, ValueType* out, + int nrows, int ncols) +{ + if (in != out) { + for (int irow=0; irow visited(size); + ValueType* first = out; + const ValueType* last = first + size; + ValueType* cycle = out; + while (++cycle != last) { + if (visited[cycle - first]) + continue; + int a = cycle - first; + do { + a = a == mn1 ? mn1 : (n * a) % mn1; + std::swap(*(first + a), *cycle); + visited[a] = true; + } while ((first + a) != cycle); + } + +} + + +void IDFT::transpose(const IDFT::scalar_t* in, IDFT::scalar_t* out, + int nrows, int ncols) const +{ + transpose_type(in, out, nrows, ncols); +} +void IDFT::transpose(const IDFT::complex_t* in, IDFT::complex_t* out, + int nrows, int ncols) const +{ + transpose_type(in, out, nrows, ncols); +} diff --git a/iface/src/IfaceDesctructors.cxx b/iface/src/IfaceDesctructors.cxx index 76efb5ecd..59b1597d1 100644 --- a/iface/src/IfaceDesctructors.cxx +++ b/iface/src/IfaceDesctructors.cxx @@ -74,6 +74,7 @@ #include "WireCellIface/IRandom.h" #include "WireCellIface/IRecombinationModel.h" #include "WireCellIface/IScalarFieldSink.h" +#include "WireCellIface/ISemaphore.h" #include "WireCellIface/ISequence.h" #include "WireCellIface/ISinkNode.h" #include "WireCellIface/ISlice.h" @@ -172,6 +173,7 @@ IQueuedoutNodeBase::~IQueuedoutNodeBase() {} IRandom::~IRandom() {} IRecombinationModel::~IRecombinationModel() {} IScalarFieldSink::~IScalarFieldSink() {} +ISemaphore::~ISemaphore() {} ISinkNodeBase::~ISinkNodeBase() {} ISlice::~ISlice() {} ISliceFanout::~ISliceFanout() {} diff --git a/pytorch/inc/WireCellPytorch/DFT.h b/pytorch/inc/WireCellPytorch/DFT.h new file mode 100644 index 000000000..41e154152 --- /dev/null +++ b/pytorch/inc/WireCellPytorch/DFT.h @@ -0,0 +1,62 @@ +/** + TorchDFT provides a libtorch based implementation of IDFT. + + The libtorch API is documented at: + + https://pytorch.org/cppdocs/api/namespace_torch__fft.html + */ + +#ifndef WIRECELL_PYTORCH_DFT +#define WIRECELL_PYTORCH_DFT + +#include "WireCellIface/IDFT.h" +#include "WireCellIface/IConfigurable.h" +#include "WireCellPytorch/TorchContext.h" + +namespace WireCell::Pytorch { + class DFT : public IDFT, + public IConfigurable + { + public: + DFT(); + virtual ~DFT(); + + // IConfigurable interface + virtual void configure(const WireCell::Configuration& config); + virtual WireCell::Configuration default_configuration() const; + + // 1d + + virtual + void fwd1d(const complex_t* in, complex_t* out, + int size) const; + + virtual + void inv1d(const complex_t* in, complex_t* out, + int size) const; + + // batched 1D ("1b") - rely on base implementation + virtual + void fwd1b(const complex_t* in, complex_t* out, + int nrows, int ncols, int axis) const; + virtual + void inv1b(const complex_t* in, complex_t* out, + int nrows, int ncols, int axis) const; + + // 2d + + virtual + void fwd2d(const complex_t* in, complex_t* out, + int nrows, int ncols) const; + virtual + void inv2d(const complex_t* in, complex_t* out, + int nrows, int ncols) const; + + private: + TorchContext m_ctx; + + }; + +} + +#endif diff --git a/pytorch/inc/WireCellPytorch/TorchContext.h b/pytorch/inc/WireCellPytorch/TorchContext.h new file mode 100644 index 000000000..6df72d090 --- /dev/null +++ b/pytorch/inc/WireCellPytorch/TorchContext.h @@ -0,0 +1,62 @@ +/** A mixin class to provide a torch context + + */ + +#include "WireCellIface/ISemaphore.h" +#include + +namespace WireCell::Pytorch { + + class TorchContext { + public: + + // The "devname" is "cpu" or "gpu" or "gpuN" where N is a GPU + // number. If "semname" is given, use it for semaphore, + // otherwise use canonically tn=Semaphore:torch-. + TorchContext(const std::string& devname, + const std::string& semname=""); + TorchContext(); + ~TorchContext(); + + // Default constructor makes context with no device nor + // semaphore. This will make the "connection" to them. + void connect(const std::string& devname, + const std::string& semname=""); + + torch::Device device() const { return m_dev; } + std::string devname() const { return m_devname; } + + bool is_gpu() const { return m_devname != "cpu"; } + + // Context manager methods. Caller should prefer using a + // TorchSemaphore class but if called directly, caller MUST + // balance an enter() with an exit(). These can and should be + // used in multi-thread run stage. + void enter() const { if (m_sem) m_sem->acquire(); } + void exit() const { if (m_sem) m_sem->release(); } + + private: + + torch::Device m_dev{torch::kCPU}; + std::string m_devname{""}; + ISemaphore::pointer m_sem; + }; + + /// Use like: + /// + /// void mymeth() { + /// TorchSemaphore sem(m_ctx); + /// ... more code may return/throw + /// } // end of scope + class TorchSemaphore { + const TorchContext& m_th; + public: + TorchSemaphore(const TorchContext& th) : m_th(th) { + m_th.enter(); + } + ~TorchSemaphore() { + m_th.exit(); + } + }; + +} diff --git a/pytorch/inc/WireCellPytorch/TorchService.h b/pytorch/inc/WireCellPytorch/TorchService.h index 01504724e..ca8fe4eb1 100644 --- a/pytorch/inc/WireCellPytorch/TorchService.h +++ b/pytorch/inc/WireCellPytorch/TorchService.h @@ -7,7 +7,7 @@ #include "WireCellIface/ITensorForward.h" #include "WireCellUtil/Logging.h" #include "WireCellAux/Logger.h" -#include "WireCellUtil/Semaphore.h" +#include "WireCellPytorch/TorchContext.h" #include // One-stop header. @@ -29,15 +29,14 @@ namespace WireCell::Pytorch { private: - // Mark which device is used - torch::Device m_dev; - // for read-only access, claim is that .forward() is thread // safe. However .forward() is not const so we must make this // mutable. mutable torch::jit::script::Module m_module; - mutable FastSemaphore m_sem; + // Even though thread safe, we want to honor a per device + // semaphore to give user chance ot limit us. + TorchContext m_ctx; }; } // namespace WireCell::Pytorch diff --git a/pytorch/src/DFT.cxx b/pytorch/src/DFT.cxx new file mode 100644 index 000000000..052013ded --- /dev/null +++ b/pytorch/src/DFT.cxx @@ -0,0 +1,124 @@ +#include "WireCellPytorch/DFT.h" +#include "WireCellUtil/NamedFactory.h" + +#include +#include + + +WIRECELL_FACTORY(TorchDFT, WireCell::Pytorch::DFT, + WireCell::IDFT, + WireCell::IConfigurable) + +using namespace WireCell; +using namespace WireCell::Pytorch; + +DFT::DFT() +{ +} + +DFT::~DFT() +{ +} + +Configuration DFT::default_configuration() const +{ + Configuration cfg; + + // one of: {cpu, gpu, gpuN} where "N" is a GPU number. "gpu" + // alone will use GPU 0. + cfg["device"] = "cpu"; + return cfg; +} + +void DFT::configure(const WireCell::Configuration& cfg) +{ + auto dev = get(cfg, "device", "cpu"); + m_ctx.connect(dev); +} + + +using torch_transform = std::function; + +static +void doit(const TorchContext& ctx, + const IDFT::complex_t* in, IDFT::complex_t* out, + int64_t nrows, int64_t ncols, // 1d vec should have nrows=1 + torch_transform func) +{ + TorchSemaphore sem(ctx); + torch::NoGradGuard no_grad; + + int64_t size = nrows*ncols; + + auto dtype = torch::TensorOptions().dtype(torch::kComplexFloat); + + // 1) in->src + if (in != out) { // from_blob() doesn't like const data + memcpy(out, in, sizeof(IDFT::complex_t)*size); + } + + torch::Tensor src = torch::from_blob(out, {nrows, ncols}, dtype); + + // 2) dst = func(src) + src = src.to(ctx.device()); + auto dst = func(src); + + // Making contiguous costs a copy but gets the data in row-major + // so the (2nd) copy next actually gives correct results. This + // corrects optimizations that libtorch makes for transpose (and + // others) eg when our func is 1b. Alternatively, may avoid both + // copies by iterating over indices but presumably (?) that is + // slower. Likewise we make contiguous on the device as that is + // presumably (?) faster when the device is GPU. + dst = dst.contiguous().cpu(); + + // 3) dst->out + memcpy(out, dst.data_ptr(), sizeof(IDFT::complex_t)*size); +} + + +void DFT::fwd1d(const IDFT::complex_t* in, IDFT::complex_t* out, int size) const +{ + doit(m_ctx, in, out, 1, size, + [](const torch::Tensor& src) { return torch::fft::fft(src); }); +} + + +void DFT::inv1d(const IDFT::complex_t* in, IDFT::complex_t* out, int size) const +{ + doit(m_ctx, in, out, 1, size, // fixme: check norm + [](const torch::Tensor& src) { return torch::fft::ifft(src); }); +} + + +void DFT::fwd1b(const IDFT::complex_t* in, IDFT::complex_t* out, + int nrows, int ncols, int axis) const +{ + doit(m_ctx, in, out, nrows, ncols, [&](const torch::Tensor& src) { + return torch::fft::fft2(src, torch::nullopt, {axis}); }); +} + + +void DFT::inv1b(const IDFT::complex_t* in, IDFT::complex_t* out, + int nrows, int ncols, int axis) const +{ + doit(m_ctx, in, out, nrows, ncols, [&](const torch::Tensor& src) { + return torch::fft::ifft2(src, torch::nullopt, {axis}); }); +} + + +void DFT::fwd2d(const IDFT::complex_t* in, IDFT::complex_t* out, + int nrows, int ncols) const +{ + doit(m_ctx, in, out, nrows, ncols, + [](const torch::Tensor& src) { return torch::fft::fft2(src); }); +} + + +void DFT::inv2d(const IDFT::complex_t* in, IDFT::complex_t* out, + int nrows, int ncols) const +{ + doit(m_ctx, in, out, nrows, ncols, + [](const torch::Tensor& src) { return torch::fft::ifft2(src); }); +} + diff --git a/pytorch/src/DNNROIFinding.cxx b/pytorch/src/DNNROIFinding.cxx index c875f9714..f3fda52df 100644 --- a/pytorch/src/DNNROIFinding.cxx +++ b/pytorch/src/DNNROIFinding.cxx @@ -234,7 +234,8 @@ bool Pytorch::DNNROIFinding::operator()(const IFrame::pointer& inframe, IFrame:: std::vector inputs; inputs.push_back(batch); - log->debug(tk(fmt::format("call={} calling model", m_save_count))); + log->debug(tk(fmt::format("call={} calling model \"{}\"", + m_save_count, m_cfg.forward))); // Execute the model and turn its output into a tensor. auto iitens = Pytorch::to_itensor(inputs); diff --git a/pytorch/src/TorchContext.cxx b/pytorch/src/TorchContext.cxx new file mode 100644 index 000000000..72262f24b --- /dev/null +++ b/pytorch/src/TorchContext.cxx @@ -0,0 +1,39 @@ +#include "WireCellPytorch/TorchContext.h" +#include "WireCellUtil/NamedFactory.h" + +using namespace WireCell; +using namespace WireCell::Pytorch; + +TorchContext::TorchContext() {} +TorchContext::~TorchContext() { } +TorchContext::TorchContext(const std::string& devname, + const std::string& semname) +{ + connect(devname, semname); +} +void TorchContext::connect(const std::string& devname, + const std::string& semname) +{ + // Use almost 1/2 the memory and 3/4 the time. + torch::NoGradGuard no_grad; + + if (devname == "cpu") { + m_dev = torch::Device(torch::kCPU); + } + else { + int devnum = 0; + if (devname.size() > 3) { + devnum = atoi(devname.substr(3).c_str()); + } + m_dev = torch::Device(torch::kCUDA, devnum); + } + m_devname = devname; + + std::string s_tn = "Semaphore:torch-" + devname; + if (not semname.empty()) { + s_tn = semname; + } + + m_sem = Factory::lookup_tn(s_tn); +} + diff --git a/pytorch/src/TorchService.cxx b/pytorch/src/TorchService.cxx index 76bbfe436..c09243657 100644 --- a/pytorch/src/TorchService.cxx +++ b/pytorch/src/TorchService.cxx @@ -14,8 +14,6 @@ using namespace WireCell; Pytorch::TorchService::TorchService() : Aux::Logger("TorchService", "torch") - , m_dev(torch::kCPU, 0) - , m_sem(0) { } @@ -26,30 +24,17 @@ Configuration Pytorch::TorchService::default_configuration() const // TorchScript model cfg["model"] = "model.ts"; - // one of: {cpu, gpu, gpucpu}. Latter allows fail-over to cpu - // when there is a failure to load the model. - // fixme: we may want to allow user to give a GPU index number - // here so like eg gpu:1, gpucpu:2. An index is not meaningful - // for cpu. - cfg["device"] = "gpucpu"; + // one of: {cpu, gpu, gpuN} where "N" is a GPU number. "gpu" + // alone will use GPU 0. + cfg["device"] = "cpu"; - // The maximum allowed number concurrent calls to forward(). A - // value of unity means all calls will be serialized. When made - // smaller than the number of threads, the difference gives the - // number of threads that may block on the semaphore. - cfg["concurrency"] = 1; - return cfg; } void Pytorch::TorchService::configure(const WireCell::Configuration& cfg) { - auto dev = get(cfg, "device", "gpucpu"); - auto count = get(cfg, "concurrency", 1); - if (count < 1 ) { - count = 1; - } - m_sem.set_count(count); + auto dev = get(cfg, "device", "cpu"); + m_ctx.connect(dev); auto model_path = cfg["model"].asString(); if (model_path.empty()) { @@ -58,74 +43,42 @@ void Pytorch::TorchService::configure(const WireCell::Configuration& cfg) } // Use almost 1/2 the memory and 3/4 the time. - // but, fixme: check with Haiwng that this is okay. torch::NoGradGuard no_grad; - // Maybe first try to load torch script model on GPU. - if (dev == "gpucpu") { - try { - m_dev = torch::Device(torch::kCUDA, 0); - m_module = torch::jit::load(model_path, m_dev); - log->debug("loaded model {} to {}", model_path, dev); - return; - } - catch (const c10::Error& e) { - log->warn("failed to load model: {} to GPU will try CPU: {}", - model_path, e.what()); - } - } - - if (dev == "cpu") { - m_dev = torch::Device(torch::kCPU); - } - else { - m_dev = torch::Device(torch::kCUDA, 0); - } - - // from now, we either succeed or we throw - try { - m_module = torch::jit::load(model_path, m_dev); + m_module = torch::jit::load(model_path, m_ctx.device()); } catch (const c10::Error& e) { - log->critical("error loading model: {} to {}: {}", + log->critical("error loading model: \"{}\" to device \"{}\": {}", model_path, dev, e.what()); throw; // rethrow } - log->debug("loaded model {} to {}", model_path, dev); + log->debug("loaded model \"{}\" to device \"{}\"", + model_path, m_ctx.devname()); } -#include - ITensorSet::pointer Pytorch::TorchService::forward(const ITensorSet::pointer& in) const { + TorchSemaphore sem(m_ctx); - m_sem.acquire(); - - const bool is_gpu = ! (m_dev == torch::kCPU); - log->debug("running model on {}", is_gpu ? "GPU" : "CPU"); + log->debug("running model on device: \"{}\"", m_ctx.devname()); torch::NoGradGuard no_grad; - std::vector iival = Pytorch::from_itensor(in, is_gpu); + std::vector iival = Pytorch::from_itensor(in, m_ctx.is_gpu()); torch::IValue oival; try { oival = m_module.forward(iival); } catch (const std::runtime_error& err) { - log->error("error running model on {}: {}", - is_gpu ? "GPU" : "CPU", err.what()); - m_sem.release(); + log->error("error running model on device \"{}\": {}", + m_ctx.devname(), err.what()); return nullptr; } ITensorSet::pointer ret = Pytorch::to_itensor({oival}); - // maybe needs a mutex? - c10::cuda::CUDACachingAllocator::emptyCache(); - - m_sem.release(); return ret; } diff --git a/pytorch/test/test_from_blob.cxx b/pytorch/test/test_from_blob.cxx new file mode 100644 index 000000000..b6e45e23e --- /dev/null +++ b/pytorch/test/test_from_blob.cxx @@ -0,0 +1,73 @@ +#include +#include + +#include +#include +#include + +using complex_t = std::complex; + +void dump(const std::vector& v, int nrows, int ncols, const std::string& msg="") +{ + std::cerr << msg << ": ("< v(size, 0); + //std::iota(v.begin(), v.end(), 0); + v[ncols+2] = 1.0; + dump(v, nrows, ncols, "v"); + + // Note: gpu is almost 10x SLOWER than CPU due to kernel load time! + // auto device = at::Device(at::kCPU); + auto device = at::Device(at::kCUDA); + + auto typ_options = at::TensorOptions().dtype(at::kComplexFloat); + // auto dev_options = typ_options.device(device); + + for (int axis = 0; axis < 2; ++ axis) { + at::Tensor src = at::from_blob(v.data(), {nrows, ncols}, typ_options); + dump(src, "src"); + + src = src.to(device); + // In pytorch dim=(0,) transforms along columns, ie follows + // numpy.fft convention. Both directions work on both CPU and + // CUDA. + at::Tensor dst = torch::fft::fft2(src, {}, {axis,}); + + // BUT, BEWARE that the underlying storage will NOT reflect + // logical row-major ordering. Indexing is as expected but + // memory returned by data_ptr() will reflect transpose + // optimizations. At the expense of a copy the contiguous() + // method provides expected row-major storage order. + dst = dst.contiguous().cpu(); + + dump(dst, "dst"); + + std::vector v2(size, 0); + memcpy(v2.data(), dst.data_ptr(), sizeof(complex_t)*size); + + std::cerr << "axis=" << axis << " dim=" << axis + << " shape=(" << src.size(0) << "," << src.size(1) << ")\n"; + dump(v2, nrows, ncols, "dft(v)"); + } + return 0; +} diff --git a/root/test/anode_loader.h b/root/test/anode_loader.h index 9b7d41351..cdf1d7b16 100644 --- a/root/test/anode_loader.h +++ b/root/test/anode_loader.h @@ -15,6 +15,7 @@ #include "WireCellIface/IAnodePlane.h" #include "WireCellIface/IFieldResponse.h" #include "WireCellIface/IWireSchema.h" +#include "WireCellIface/IDFT.h" #include #include @@ -59,6 +60,7 @@ std::vector anode_loader(std::string detector) PluginManager& pm = PluginManager::instance(); pm.add("WireCellSigProc"); pm.add("WireCellGen"); + pm.add("WireCellAux"); const std::string fr_tn = "FieldResponse"; const std::string ws_tn = "WireSchemaFile"; @@ -75,6 +77,11 @@ std::vector anode_loader(std::string detector) cfg["filename"] = ws_fname; icfg->configure(cfg); } + { + // If FftwDFT grows to be an IConfigurable, this needs to + // change to suit. + Factory::lookup("FftwDFT"); + } for (int ianode = 0; ianode < nanodes; ++ianode) { std::string tn = String::format("AnodePlane:%d", ianode); diff --git a/root/test/test_binneddiffusion.cxx b/root/test/test_binneddiffusion.cxx index 751a0664d..cd14504bd 100644 --- a/root/test/test_binneddiffusion.cxx +++ b/root/test/test_binneddiffusion.cxx @@ -1,3 +1,5 @@ +#include "WireCellAux/DftTools.h" + #include "WireCellGen/BinnedDiffusion.h" #include "WireCellIface/SimpleDepo.h" #include "WireCellUtil/ExecMon.h" @@ -30,11 +32,14 @@ struct Meta { ExecMon em; const char* name; - Meta(const char* name) + IDFT::pointer idft; + + Meta(const char* name, IDFT::pointer idft) //: theApp(new TApplication (name,0,0)) : canvas(new TCanvas("canvas", "canvas", 500, 500)) , em(name) , name(name) + , idft(idft) { print("["); } @@ -74,7 +79,7 @@ void test_track(Meta& meta, double charge, double track_time, const Ray& track_r const auto rbins = pimpos.region_binning(); const auto ibins = pimpos.impact_binning(); - Gen::BinnedDiffusion bd(pimpos, tbins, ndiffision_sigma, fluctuate); + Gen::BinnedDiffusion bd(pimpos, meta.idft, tbins, ndiffision_sigma, fluctuate); auto track_start = track_ray.first; auto track_dir = ray_unit(track_ray); @@ -231,18 +236,20 @@ int main(int argc, char* argv[]) { PluginManager& pm = PluginManager::instance(); pm.add("WireCellGen"); + pm.add("WireCellAux"); { auto rngcfg = Factory::lookup("Random"); auto cfg = rngcfg->default_configuration(); rngcfg->configure(cfg); } auto rng = Factory::lookup("Random"); + auto idft = Factory::lookup_tn("FftwDFT"); const char* me = argv[0]; TFile* rootfile = TFile::Open(Form("%s.root", me), "RECREATE"); - Meta meta(me); + Meta meta(me, idft); gStyle->SetOptStat(0); const double track_time = t0 + 10 * units::ns; diff --git a/root/test/test_convo.cxx b/root/test/test_convo.cxx index b6a64a7a1..d2bb9d343 100644 --- a/root/test/test_convo.cxx +++ b/root/test/test_convo.cxx @@ -1,3 +1,6 @@ +#include "WireCellAux/DftTools.h" +#include "WireCellUtil/NamedFactory.h" +#include "WireCellUtil/PluginManager.h" #include "WireCellUtil/Response.h" #include "WireCellUtil/Waveform.h" @@ -67,6 +70,10 @@ std::vector plot_wave(TCanvas& canvas, int padnum, std::string name, std: int main(int argc, char* argv[]) { + PluginManager& pm = PluginManager::instance(); + pm.add("WireCellAux"); + auto idft = Factory::lookup_tn("FftwDFT"); + if (argc < 2) { std::cerr << "This test requires an Wire Cell Field Response input file." << std::endl; return 0; @@ -126,16 +133,16 @@ int main(int argc, char* argv[]) } // frequency space - Waveform::compseq_t charge_spectrum = Waveform::dft(electrons); - Waveform::compseq_t raw_response_spectrum = Waveform::dft(raw_response); - Waveform::compseq_t response_spectrum = Waveform::dft(response); + Waveform::compseq_t charge_spectrum = Aux::fwd_r2c(idft, electrons); + Waveform::compseq_t raw_response_spectrum = Aux::fwd_r2c(idft, raw_response); + Waveform::compseq_t response_spectrum = Aux::fwd_r2c(idft, response); // convolve Waveform::compseq_t conv_spectrum(nticks, Waveform::complex_t(0.0, 0.0)); for (int ind = 0; ind < nticks; ++ind) { conv_spectrum[ind] = response_spectrum[ind] * charge_spectrum[ind]; } - Waveform::realseq_t conv = Waveform::idft(conv_spectrum); + Waveform::realseq_t conv = Aux::inv_c2r(idft, conv_spectrum); for (int ind = 0; ind < nticks; ++ind) { conv[ind] /= nticks; } diff --git a/root/test/test_convo_binning.cxx b/root/test/test_convo_binning.cxx index 46c1dba69..3a2fed9c9 100644 --- a/root/test/test_convo_binning.cxx +++ b/root/test/test_convo_binning.cxx @@ -1,11 +1,18 @@ // Test what happens with different choices of how we bin. #include "MultiPdf.h" + +#include "WireCellGen/RCResponse.h" + +#include "WireCellAux/DftTools.h" + +#include "WireCellUtil/NamedFactory.h" +#include "WireCellUtil/PluginManager.h" + #include "WireCellUtil/Units.h" #include "WireCellUtil/Waveform.h" #include "WireCellUtil/Binning.h" #include "WireCellUtil/Response.h" -#include "WireCellGen/RCResponse.h" #include "TGraph.h" #include "TH1F.h" @@ -75,6 +82,10 @@ struct Plotter { int main(int argc, char* argv[]) { + PluginManager& pm = PluginManager::instance(); + pm.add("WireCellAux"); + auto idft = Factory::lookup_tn("FftwDFT"); + Test::MultiPdf mpdf(argv[0]); Plotter p(mpdf); @@ -111,22 +122,22 @@ int main(int argc, char* argv[]) p.draw(fce, fbin_short, "fce", "Fine CE"); // convolve + rebin fine->coarse - auto fcc = linear_convolve(ffr, fce); + auto fcc = Aux::convolve(idft, ffr, fce); p.draw(fcc, fbin_long, "fcc", "Fine conv"); auto ccc2 = rebin(fcc, rebinfactor); p.draw(ccc2, cbin_long, "ccc2", "Coarse rebin conv"); - auto fccs = linear_convolve(ffrs, fce); + auto fccs = Aux::convolve(idft, ffrs, fce); p.draw(fccs, fbin_long, "fccs", "Fine conv shifted"); auto cccs2 = rebin(fccs, rebinfactor); p.draw(cccs2, cbin_long, "cccs2", "Coarse rebin conv shifted"); // rebin fine->coarse + convolve - auto ccc = linear_convolve(cfr, cce); + auto ccc = Aux::convolve(idft, cfr, cce); for (size_t ind=0; ind("FftwDFT"); + const std::vector gains = {7.8 * GUnit, 14.0 * GUnit}; const std::vector shapings = {1.0 * units::us, 2.0 * units::us}; @@ -123,7 +132,7 @@ int main(int argc, char* argv[]) const double tshape_us = shapings[ind] / units::us; auto tit = Form("Cold Electronics Response at %.0fus peaking", tshape_us); - draw_time_freq(pdf, res, tit, tbins); + draw_time_freq(pdf, idft, res, tit, tbins); } // Look at RC filter @@ -135,7 +144,7 @@ int main(int argc, char* argv[]) Waveform::realseq_t res = rc.generate(tbins); auto tit = "RC Response at 1ms time constant"; - draw_time_freq(pdf, res, tit, tbins); + draw_time_freq(pdf, idft, res, tit, tbins); } { Binning shifted(tbins.nbins(), tbins.min() + tick, tbins.max() + tick); @@ -144,7 +153,7 @@ int main(int argc, char* argv[]) Waveform::realseq_t res = rc.generate(shifted); auto tit = "RC Response at 1ms time constant (suppress delta)"; - draw_time_freq(pdf, res, tit, tbins); + draw_time_freq(pdf, idft, res, tit, tbins); } // Look at SysResp (Gaussian smear) @@ -152,7 +161,7 @@ int main(int argc, char* argv[]) Response::SysResp gaus; Waveform::realseq_t res = gaus.generate(tbins); auto tit = "Response Gaussian smear by default"; - draw_time_freq(pdf, res, tit, tbins); + draw_time_freq(pdf, idft, res, tit, tbins); } { double mag = 1.0; @@ -163,31 +172,36 @@ int main(int argc, char* argv[]) Response::SysResp gaus(tick, mag, smear); Waveform::realseq_t res = gaus.generate(ttt); auto tit = "Response Gaussian 2 us smear"; - draw_time_freq(pdf, res, tit, ttt); + draw_time_freq(pdf, idft, res, tit, ttt); } // do timing tests { TGraph* timings[4] = {new TGraph, new TGraph, new TGraph, new TGraph}; + TGraph* timings_1st[4] = {new TGraph, new TGraph, new TGraph, new TGraph}; // Some popular choices with powers-of-two sprinkled in - std::vector nsampleslist{128, 256, 400, 480, // protoDUNE U/V and W channels per plane - 512, - 800, // protoDUNE, sum of U or V channels for both faces - 960, // protoDUNE, sum of W channels (or wires) for both faces - 1024, - 1148, // N wires in U/V plane for protodune - 2048, - 2400, // number of channels in U or V in microboone - 2560, // DUNE, total APA channels - 3456, // number of channels in microboone's W - 4096, - 6000, // one choice of nticks for protoDUNE - 8192, - 8256, // total microboone channels - 9592, 9594, 9595, 9600, // various microboone readout lengths - 10000, // 5 ms at 2MHz readout - 10240, 16384}; + std::vector nsampleslist{ + 128, 256, // small powers of 2 + 400, 480, // protoDUNE U/V and W channels per plane + 512, + 800, // protoDUNE, all U or V channels + 960, // protoDUNE, all W channels + 1024, + 1148, // what's this? + 2000, // iceberg + 2048, + 2400, // number of channels in U or V in microboone + 2560, // DUNE, total APA channels + 3456, // number of channels in microboone's W + 4096, + 6000, // one choice of nticks for protoDUNE + 8192, + 8256, // total microboone channels + 9587, // prime near MB nticks + 9592, 9594, 9595, 9600, // various MB nticks + 10000, // 5 ms at 2MHz readout + 10240, 16384}; const int ntries = 1000; for (auto nsamps : nsampleslist) { Response::ColdElec ce(gains[1], shapings[1]); @@ -195,19 +209,34 @@ int main(int argc, char* argv[]) Waveform::realseq_t res = ce.generate(bins); Waveform::compseq_t spec; + + double fwd_time_1st = 0.0; + { + auto t1 = std::chrono::high_resolution_clock::now(); + spec = Aux::fwd_r2c(idft, res); + auto t2 = std::chrono::high_resolution_clock::now(); + fwd_time_1st += std::chrono::duration_cast(t2 - t1).count(); + } double fwd_time = 0.0; for (int itry = 0; itry < ntries; ++itry) { auto t1 = std::chrono::high_resolution_clock::now(); - spec = Waveform::dft(res); + spec = Aux::fwd_r2c(idft, res); auto t2 = std::chrono::high_resolution_clock::now(); fwd_time += std::chrono::duration_cast(t2 - t1).count(); } fwd_time /= ntries; + double rev_time_1st = 0.0; + { + auto t1 = std::chrono::high_resolution_clock::now(); + res = Aux::inv_c2r(idft, spec); + auto t2 = std::chrono::high_resolution_clock::now(); + rev_time_1st = std::chrono::duration_cast(t2 - t1).count(); + } double rev_time = 0.0; for (int itry = 0; itry < ntries; ++itry) { auto t1 = std::chrono::high_resolution_clock::now(); - res = Waveform::idft(spec); + res = Aux::inv_c2r(idft, spec); auto t2 = std::chrono::high_resolution_clock::now(); rev_time += std::chrono::duration_cast(t2 - t1).count(); } @@ -224,45 +253,93 @@ int main(int argc, char* argv[]) timings[1]->SetPoint(timings[1]->GetN(), nsamps, fwd_time / nsamps); timings[2]->SetPoint(timings[2]->GetN(), nsamps, rev_time); timings[3]->SetPoint(timings[3]->GetN(), nsamps, rev_time / nsamps); + + timings_1st[0]->SetPoint(timings_1st[0]->GetN(), nsamps, fwd_time_1st); + timings_1st[1]->SetPoint(timings_1st[1]->GetN(), nsamps, fwd_time_1st / (fwd_time)); + timings_1st[2]->SetPoint(timings_1st[2]->GetN(), nsamps, rev_time_1st); + timings_1st[3]->SetPoint(timings_1st[3]->GetN(), nsamps, rev_time_1st / (rev_time)); + } pdf.canvas.Clear(); pdf.canvas.Divide(1, 2); - auto text = new TText; { - auto pad = pdf.canvas.cd(1); - pad->SetGridx(); - pad->SetGridy(); - pad->SetLogx(); - auto graph = timings[0]; - auto frame = graph->GetHistogram(); - frame->SetTitle("Fwd/rev DFT timing (absolute)"); - frame->GetXaxis()->SetTitle("number of samples"); - frame->GetYaxis()->SetTitle("time (ns)"); - timings[0]->Draw("AL"); - timings[2]->Draw("L"); - for (int ind = 0; ind < graph->GetN(); ++ind) { - auto x = graph->GetX()[ind]; - auto y = graph->GetY()[ind]; - text->DrawText(x, y, Form("%.0f", x)); + auto text = new TText; + { + auto pad = pdf.canvas.cd(1); + pad->SetGridx(); + pad->SetGridy(); + pad->SetLogx(); + pad->SetLogy(); + auto graph = timings[0]; + auto frame = graph->GetHistogram(); + frame->SetTitle("Fwd/rev DFT timing (absolute)"); + frame->GetXaxis()->SetTitle("number of samples"); + frame->GetYaxis()->SetTitle("time [ns]"); + timings[0]->Draw("AL"); + timings[2]->Draw("L"); + for (int ind = 0; ind < graph->GetN(); ++ind) { + auto x = graph->GetX()[ind]; + auto y = graph->GetY()[ind]; + text->DrawText(x, y, Form("%.0f", x)); + } + } + + { + auto pad = pdf.canvas.cd(2); + pad->SetGridx(); + pad->SetGridy(); + pad->SetLogx(); + auto frame = timings[1]->GetHistogram(); + frame->SetTitle("Fwd/rev DFT timing (relative to size)"); + frame->GetXaxis()->SetTitle("number of samples"); + frame->GetYaxis()->SetTitle("time per sample [ns/samp]"); + timings[1]->Draw("AL"); + timings[3]->Draw("L"); } + pdf(); } + pdf.canvas.Clear(); + pdf.canvas.Divide(1, 2); + { - auto pad = pdf.canvas.cd(2); - pad->SetGridx(); - pad->SetGridy(); - pad->SetLogx(); - auto frame = timings[1]->GetHistogram(); - frame->SetTitle("Fwd/rev DFT timing (relative)"); - frame->GetXaxis()->SetTitle("number of samples"); - frame->GetYaxis()->SetTitle("time per sample (ns/samp)"); - timings[1]->Draw("AL"); - timings[3]->Draw("L"); + auto text = new TText; + { + auto pad = pdf.canvas.cd(1); + pad->SetGridx(); + pad->SetGridy(); + pad->SetLogx(); + pad->SetLogy(); + auto graph = timings_1st[0]; + auto frame = graph->GetHistogram(); + frame->SetTitle("fwd/rev DFT timing, ''cold'' (plan+exec) (absolute time)"); + frame->GetXaxis()->SetTitle("number of samples"); + frame->GetYaxis()->SetTitle("time [ns]"); + timings_1st[0]->Draw("AL"); + timings_1st[2]->Draw("L"); + for (int ind = 0; ind < graph->GetN(); ++ind) { + auto x = graph->GetX()[ind]; + auto y = graph->GetY()[ind]; + text->DrawText(x, y, Form("%.0f", x)); + } + } + + { + auto pad = pdf.canvas.cd(2); + pad->SetGridx(); + pad->SetGridy(); + pad->SetLogx(); + auto frame = timings_1st[1]->GetHistogram(); + frame->SetTitle("Fwd/rev DFT timing (cold/warm relative)"); + frame->GetXaxis()->SetTitle("number of samples"); + frame->GetYaxis()->SetTitle("relative time [ns / ns]"); + timings_1st[1]->Draw("AL"); + timings_1st[3]->Draw("L"); + } + pdf(); } - pdf(); } - return 0; } diff --git a/root/test/test_fft_speed.cxx b/root/test/test_fft_speed.cxx index 0dff0b06f..6ee66534c 100644 --- a/root/test/test_fft_speed.cxx +++ b/root/test/test_fft_speed.cxx @@ -1,3 +1,7 @@ +#include "WireCellAux/DftTools.h" +#include "WireCellUtil/NamedFactory.h" +#include "WireCellUtil/PluginManager.h" + #include "WireCellUtil/Waveform.h" #include "WireCellUtil/Units.h" #include "WireCellUtil/Response.h" @@ -25,6 +29,10 @@ const double GUnit = units::mV / units::fC; int main(int argc, char** argv) { + PluginManager& pm = PluginManager::instance(); + pm.add("WireCellAux"); + auto idft = Factory::lookup_tn("FftwDFT"); + int nInputs = 0; int nBegin = 0; int nEnd = 0; @@ -123,7 +131,8 @@ int main(int argc, char** argv) // fwd_time /= ntries; auto t1 = std::chrono::high_resolution_clock::now(); - Array::dft_cc(test_array, 0); + // Array::dft_cc(test_array, 0); + Aux::fwd(idft, test_array, 1); auto t2 = std::chrono::high_resolution_clock::now(); fwd_time = std::chrono::duration_cast(t2 - t1).count() / ntries; @@ -131,7 +140,9 @@ int main(int argc, char** argv) // for (int itry=0; itry(t4 - t3).count() / ntries; // } diff --git a/root/test/test_fieldresp.cxx b/root/test/test_fieldresp.cxx index 65e93102e..f0d133790 100644 --- a/root/test/test_fieldresp.cxx +++ b/root/test/test_fieldresp.cxx @@ -1,3 +1,5 @@ +#include "WireCellAux/DftTools.h" + #include "WireCellUtil/Testing.h" #include "WireCellUtil/Logging.h" @@ -34,13 +36,15 @@ int main(int argc, char* argv[]) /// WCT internals, normally user code does not need this { PluginManager& pm = PluginManager::instance(); + pm.add("WireCellAux"); + pm.add("WireCellSigProc"); auto ifrcfg = Factory::lookup("FieldResponse"); auto cfg = ifrcfg->default_configuration(); cfg["filename"] = frfname; ifrcfg->configure(cfg); } - + auto idft = Factory::lookup_tn("FftwDFT"); auto ifr = Factory::find("FieldResponse"); // Get full, "fine-grained" field responses defined at impact @@ -80,7 +84,7 @@ int main(int argc, char* argv[]) Response::ColdElec ce(14.0 * units::mV / units::fC, 2.0 * units::microsecond); auto ewave = ce.generate(tbins); Waveform::scale(ewave, 1.2 * 4096 / 2000.); - elec = Waveform::dft(ewave); + elec = Aux::fwd_r2c(idft, ewave); std::complex fine_period(fravg.period, 0); @@ -105,7 +109,8 @@ int main(int argc, char* argv[]) auto arr = Response::as_array(fravg.planes[ind]); // do FFT for response ... - Array::array_xxc c_data = Array::dft_rc(arr, 0); + // Array::array_xxc c_data = Array::dft_rc(arr, 0); + Array::array_xxc c_data = Aux::fwd(idft, arr.cast(), 1); int nrows = c_data.rows(); int ncols = c_data.cols(); @@ -115,7 +120,8 @@ int main(int argc, char* argv[]) } } - arr = Array::idft_cr(c_data, 0); + // arr = Array::idft_cr(c_data, 0); + arr = Aux::inv(idft, c_data, 1).real(); // figure out how to do fine ... shift (good ...) auto arr1 = arr.block(0, 0, nrows, 100); diff --git a/root/test/test_impactresponse.cxx b/root/test/test_impactresponse.cxx index 56e64faaa..e5cf76ee1 100644 --- a/root/test/test_impactresponse.cxx +++ b/root/test/test_impactresponse.cxx @@ -1,3 +1,9 @@ +#include "WireCellAux/DftTools.h" + +#include "WireCellIface/IConfigurable.h" +#include "WireCellIface/IFieldResponse.h" +#include "WireCellIface/IPlaneImpactResponse.h" + #include "WireCellUtil/PluginManager.h" #include "WireCellUtil/NamedFactory.h" #include "WireCellUtil/Logging.h" @@ -6,10 +12,6 @@ #include "WireCellUtil/Testing.h" #include "WireCellUtil/Response.h" -#include "WireCellIface/IConfigurable.h" -#include "WireCellIface/IFieldResponse.h" -#include "WireCellIface/IPlaneImpactResponse.h" - #include "MultiPdf.h" // local helper shared by a few tests #include "TH2F.h" #include "TLine.h" @@ -28,7 +30,8 @@ using namespace std; using spdlog::debug; using spdlog::error; -void plot_time(MultiPdf& mpdf, IPlaneImpactResponse::pointer pir, int iplane, Binning tbins, const std::string& name, +void plot_time(MultiPdf& mpdf, const IDFT::pointer& idft, + IPlaneImpactResponse::pointer pir, int iplane, Binning tbins, const std::string& name, const std::string& title) { // only show bins where we think the response is @@ -103,7 +106,8 @@ void plot_time(MultiPdf& mpdf, IPlaneImpactResponse::pointer pir, int iplane, Bi // continue; // } auto spec = ir->spectrum(); - auto wave = Waveform::idft(spec); + // auto wave = Waveform::idft(spec); + auto wave = Aux::inv_c2r(idft, spec); pitch += 0.001 * impact_dist; for (int ind = 0; ind < ntbins; ++ind) { const double time = tbins.center(ind); @@ -146,8 +150,10 @@ int main(int argc, const char* argv[]) Log::set_level("debug"); PluginManager& pm = PluginManager::instance(); + pm.add("WireCellAux"); pm.add("WireCellGen"); pm.add("WireCellSigProc"); + auto idft = Factory::lookup_tn("FftwDFT"); const int nticks = 9595; const double tick = 0.5 * units::us; @@ -234,10 +240,10 @@ int main(int argc, const char* argv[]) MultiPdf mpdf(out_basename.c_str()); for (int iplane = 0; iplane < 3; ++iplane) { auto pir = Factory::find_tn(pir_tns[iplane]); - plot_time(mpdf, pir, iplane, tbins, "fr", "Field Response"); + plot_time(mpdf, idft, pir, iplane, tbins, "fr", "Field Response"); auto pir_ele = Factory::find_tn(pir_ele_tns[iplane]); - plot_time(mpdf, pir_ele, iplane, tbins, "dr", "Detector Response"); + plot_time(mpdf, idft, pir_ele, iplane, tbins, "dr", "Detector Response"); } mpdf.close(); diff --git a/root/test/test_impactzipper.cxx b/root/test/test_impactzipper.cxx deleted file mode 100644 index 68355cbc0..000000000 --- a/root/test/test_impactzipper.cxx +++ /dev/null @@ -1,421 +0,0 @@ -#include "WireCellGen/ImpactZipper.h" -#include "WireCellGen/TrackDepos.h" -#include "WireCellGen/BinnedDiffusion.h" -#include "WireCellGen/TransportedDepo.h" -#include "WireCellGen/PlaneImpactResponse.h" -#include "WireCellUtil/ExecMon.h" -#include "WireCellUtil/Point.h" -#include "WireCellUtil/Binning.h" -#include "WireCellUtil/Testing.h" -#include "WireCellUtil/Response.h" - -#include "WireCellUtil/PluginManager.h" -#include "WireCellUtil/NamedFactory.h" -#include "WireCellIface/IRandom.h" -#include "WireCellIface/IConfigurable.h" -#include "WireCellIface/IFieldResponse.h" -#include "WireCellIface/IPlaneImpactResponse.h" - -#include "TCanvas.h" -#include "TFile.h" -#include "TLine.h" -#include "TStyle.h" -#include "TH2F.h" - -#include -#include - -using namespace WireCell; -using namespace std; - -int main(const int argc, char* argv[]) -{ - string track_types = "point"; - if (argc > 1) { - track_types = argv[1]; - } - cerr << "Using tracks type: \"" << track_types << "\"\n"; - - string response_file = "ub-10-half.json.bz2"; - if (argc > 2) { - response_file = argv[2]; - cerr << "Using Wire Cell field response file:\n" << response_file << endl; - } - else { - cerr << "No Wire Cell field response input file given, will try to use:\n" << response_file << endl; - } - - string out_basename = argv[0]; - if (argc > 3) { - out_basename = argv[3]; - } - - // here we do hard-wired configuration. User code should NEVER do - // this. - - PluginManager& pm = PluginManager::instance(); - pm.add("WireCellGen"); - pm.add("WireCellSigProc"); - { - auto rngcfg = Factory::lookup("Random"); - auto cfg = rngcfg->default_configuration(); - rngcfg->configure(cfg); - } - - const int nticks = 9595; - const double tick = 0.5 * units::us; - const double gain = 14.0 * units::mV / units::fC; - const double shaping = 2.0 * units::us; - - const double t0 = 0.0 * units::s; - const double readout_time = nticks * tick; - const double drift_speed = 1.0 * units::mm / units::us; // close, but not real - - const std::string er_tn = "ColdElecResponse", rc_tn = "RCResponse"; - - { // configure elecresponse - auto icfg = Factory::lookup_tn(er_tn); - auto cfg = icfg->default_configuration(); - cfg["gain"] = gain; - cfg["shaping"] = shaping; - cfg["nticks"] = nticks; - cerr << "Setting: " << cfg["nticks"].asInt() << " ticks\n"; - cfg["tick"] = tick; - cfg["start"] = t0; - icfg->configure(cfg); - } - { // configure rc response - auto icfg = Factory::lookup_tn(rc_tn); - auto cfg = icfg->default_configuration(); - cfg["nticks"] = nticks; - cfg["tick"] = tick; - cfg["start"] = t0; - icfg->configure(cfg); - } - { - auto icfg = Factory::lookup("FieldResponse"); - auto cfg = icfg->default_configuration(); - cfg["filename"] = response_file; - icfg->configure(cfg); - } - - std::vector pir_tns{"PlaneImpactResponse:U", "PlaneImpactResponse:V", "PlaneImpactResponse:W"}; - { // configure pirs - for (int iplane = 0; iplane < 3; ++iplane) { - auto icfg = Factory::lookup_tn(pir_tns[iplane]); - auto cfg = icfg->default_configuration(); - cfg["plane"] = iplane; - cfg["nticks"] = nticks; - cfg["tick"] = tick; - cfg["other_responses"][0] = er_tn; - cfg["other_responses"][1] = rc_tn; // double it so - cfg["other_responses"][2] = rc_tn; // we get RC^2 - icfg->configure(cfg); - } - } - - WireCell::ExecMon em(out_basename); - auto ifr = Factory::find_tn("FieldResponse"); - auto fr = ifr->field_response(); - - em("loaded response"); - - const char* uvw = "UVW"; - - // 1D garfield wires are all parallel - const double angle = 60 * units::degree; - const Vector upitch(0, -sin(angle), cos(angle)); - const Vector uwire(0, cos(angle), sin(angle)); - const Vector vpitch(0, sin(angle), cos(angle)); - const Vector vwire(0, cos(angle), -sin(angle)); - const Vector wpitch(0, 0, 1); - const Vector wwire(0, 1, 0); - - // FIXME: need to apply electronics response! - - // Origin where drift and diffusion meets field response. - Point field_origin(fr.origin, 0, 0); - cerr << "Field response origin: " << field_origin / units::mm << "mm\n"; - - // Describe the W collection plane - const int nwires = 2001; - const double wire_pitch = 3 * units::mm; - const int nregion_bins = 10; // fixme: this should come from the Response::Schema. - const double halfwireextent = wire_pitch * 0.5 * (nwires - 1); - cerr << "Max wire at pitch=" << halfwireextent << endl; - - std::vector uvw_pimpos{ - Pimpos(nwires, -halfwireextent, halfwireextent, uwire, upitch, field_origin, nregion_bins), - Pimpos(nwires, -halfwireextent, halfwireextent, vwire, vpitch, field_origin, nregion_bins), - Pimpos(nwires, -halfwireextent, halfwireextent, wwire, wpitch, field_origin, nregion_bins)}; - - // Digitization and time - Binning tbins(nticks, t0, t0 + readout_time); - - // Diffusion - const int ndiffision_sigma = 3.0; - bool fluctuate = false; // note, "point" negates this below - - // Generate some trivial tracks - const double stepsize = 0.003 * units::mm; - Gen::TrackDepos tracks(stepsize); - - // This is the number of ionized electrons for a MIP assumed by MB noise paper. - // note: with option "point" this is overridden below. - const double dqdx = 16000 * units::eplus / (3 * units::mm); - const double charge_per_depo = -(dqdx) *stepsize; - - const double event_time = t0 + 1 * units::ms; - const Point event_vertex(1.0 * units::m, 0 * units::m, 0 * units::mm); - - // mostly "prolonged" track in X direction - if (track_types.find("prolong") < track_types.size()) { - tracks.add_track(event_time, - Ray(event_vertex, event_vertex + Vector(1 * units::m, 0 * units::m, +10 * units::cm)), - charge_per_depo); - tracks.add_track(event_time, - Ray(event_vertex, event_vertex + Vector(1 * units::m, 0 * units::m, -10 * units::cm)), - charge_per_depo); - } - - // mostly "isochronous" track in Z direction, give spelling errors a break. :) - if (track_types.find("isoch") < track_types.size()) { - tracks.add_track(event_time, Ray(event_vertex, event_vertex + Vector(0, 0, 50 * units::mm)), charge_per_depo); - } - // "driftlike" track diagonal in space and drift time - if (track_types.find("driftlike") < track_types.size()) { - tracks.add_track(event_time, - Ray(event_vertex, event_vertex + Vector(60 * units::cm, 0 * units::m, 10.0 * units::mm)), - charge_per_depo); - } - - // make a + - if (track_types.find("plus") < track_types.size()) { - tracks.add_track(event_time, Ray(event_vertex, event_vertex + Vector(0, 0, +1 * units::m)), charge_per_depo); - tracks.add_track(event_time, Ray(event_vertex, event_vertex + Vector(0, 0, -1 * units::m)), charge_per_depo); - tracks.add_track(event_time, Ray(event_vertex, event_vertex + Vector(0, +1 * units::m, 0)), charge_per_depo); - tracks.add_track(event_time, Ray(event_vertex, event_vertex + Vector(0, -1 * units::m, 0)), charge_per_depo); - } - - // // make a . - if (track_types.find("point") < track_types.size()) { - fluctuate = false; - for (int i = 0; i < 6; i++) { - auto vt = event_vertex + Vector(0, 0, i * 0.06 * units::mm); - auto tt = event_time + i * 10.0 * units::us; - tracks.add_track(tt, Ray(vt, vt + Vector(0, 0, 0.1 * stepsize)), // force 1 point - -1.0 * units::eplus); - } - - /* tracks.add_track(event_time, */ - /* Ray(event_vertex, */ - /* event_vertex + Vector(0, 0, 0.1*stepsize)), // force 1 point */ - /* -1.0*units::eplus); */ - } - - em("made tracks"); - - // Get depos - auto depos = tracks.depos(); - - std::cerr << "got " << depos.size() << " depos from tracks\n"; - em("made depos"); - - TFile* rootfile = TFile::Open(Form("%s-uvw.root", out_basename.c_str()), "recreate"); - TCanvas* canvas = new TCanvas("c", "canvas", 1000, 1000); - gStyle->SetOptStat(0); - - std::string pdfname = argv[0]; - pdfname += ".pdf"; - canvas->Print((pdfname + "[").c_str(), "pdf"); - - IRandom::pointer rng = nullptr; - if (fluctuate) { - rng = Factory::lookup("Random"); - } - - for (int plane_id = 0; plane_id < 3; ++plane_id) { - em("start loop over planes"); - Pimpos& pimpos = uvw_pimpos[plane_id]; - - // add deposition to binned diffusion - Gen::BinnedDiffusion bindiff( - pimpos, tbins, ndiffision_sigma, rng, - Gen::BinnedDiffusion::ImpactDataCalculationStrategy::constant); // default is constant interpolation - em("made BinnedDiffusion"); - for (auto depo : depos) { - auto drifted = std::make_shared(depo, field_origin.x(), drift_speed); - - // In the real simulation these sigma are a function of - // drift time. Hard coded here with small values the - // resulting voltage peak due to "point" source should - // correspond to what is also shown on a per-impact - // "Detector Response" from util's test_impactresponse. - // Peak response of a delta function of current - // integrating over time to one electron charge would give - // 1eplus * 14mV/fC = 2.24 microvolt. - const double sigma_time = 1 * units::us; - const double sigma_pitch = 1.5 * units::mm; - - bool ok = bindiff.add(drifted, sigma_time, sigma_pitch); - if (!ok) { - std::cerr << "failed to add: t=" << drifted->time() / units::us << ", pt=" << drifted->pos() / units::mm - << std::endl; - } - Assert(ok); - - std::cerr << "depo:" - << " q=" << drifted->charge() / units::eplus << "ele" - << " time-T0=" << (drifted->time() - t0) / units::us << "us +/- " << sigma_time / units::us - << " us " - << " pt=" << drifted->pos() / units::mm << " mm\n"; - } - em("added track depositions"); - - auto ipir = Factory::find_tn(pir_tns[plane_id]); - - em("looked up " + pir_tns[plane_id]); - { - const Response::Schema::PlaneResponse* pr = fr.plane(plane_id); - const double pmax = 0.5 * ipir->pitch_range(); - const double pstep = std::abs(pr->paths[1].pitchpos - pr->paths[0].pitchpos); - const int npbins = 2.0 * pmax / pstep; - const int ntbins = pr->paths[0].current.size(); - - const double tmin = fr.tstart; - const double tmax = fr.tstart + fr.period * ntbins; - TH2F* hpir = new TH2F(Form("hfr%d", plane_id), Form("Field Response %c-plane", uvw[plane_id]), ntbins, tmin, - tmax, npbins, -pmax, pmax); - for (auto& path : pr->paths) { - const double cpitch = path.pitchpos; - for (size_t ic = 0; ic < path.current.size(); ++ic) { - const double ctime = fr.tstart + ic * fr.period; - const double charge = path.current[ic] * fr.period; - hpir->Fill(ctime, cpitch, -1 * charge / units::eplus); - } - } - hpir->SetZTitle("Induced charge [eles]"); - hpir->Write(); - - hpir->Draw("colz"); - if (track_types.find("point") < track_types.size()) { - hpir->GetXaxis()->SetRangeUser(70. * units::us, 100. * units::us); - hpir->GetYaxis()->SetRangeUser(-10. * units::mm, 10. * units::mm); - } - canvas->Update(); - // canvas->Print(Form("%s_%c_resp.png", out_basename.c_str(), uvw[plane_id])); - canvas->Print(pdfname.c_str(), "pdf"); - } - em("wrote and leaked response hist"); - - Gen::ImpactZipper zipper(ipir, bindiff); - em("made ImpactZipper"); - - // Set pitch range for plot y-axis - auto rbins = pimpos.region_binning(); - auto pmm = bindiff.pitch_range(ndiffision_sigma); - const int wbin0 = max(0, rbins.bin(pmm.first) - 40); - const int wbinf = min(rbins.nbins() - 1, rbins.bin(pmm.second) + 40); - const int nwbins = 1 + wbinf - wbin0; - - // Dead reckon - const int tbin0 = 3500, tbinf = 5500; - const int ntbins = tbinf - tbin0; - - std::map frame; - double tottot = 0.0; - for (int iwire = wbin0; iwire <= wbinf; ++iwire) { - auto wave = zipper.waveform(iwire); - auto tot = Waveform::sum(wave); - if (tot != 0.0) { - auto mm = std::minmax_element(wave.begin(), wave.end()); - cerr << "^ Wire " << iwire << " tot=" << tot / units::uV << " uV" - << " mm=[" << (*mm.first) / units::uV << "," << (*mm.second) / units::uV << "] uV " << endl; - } - - tottot += tot; - if (std::abs(iwire - 1000) <= 1) { // central wires for "point" - auto mm = std::minmax_element(wave.begin(), wave.end()); - std::cerr << "central wire: " << iwire << " mm=[" << (*mm.first) / units::microvolt << "," - << (*mm.second) / units::microvolt << "] uV\n"; - } - frame[iwire] = wave; - } - em("zipped through wires"); - cerr << "Tottot = " << tottot << endl; - Assert(tottot != 0.0); - - TH2F* hist = new TH2F(Form("h%d", plane_id), Form("Wire vs Tick %c-plane", uvw[plane_id]), ntbins, tbin0, - tbin0 + ntbins, nwbins, wbin0, wbin0 + nwbins); - hist->SetXTitle("tick"); - hist->SetYTitle("wire"); - hist->SetZTitle("Voltage [-#muV]"); - - std::cerr << nwbins << " wires: [" << wbin0 << "," << wbinf << "], " << ntbins << " ticks: [" << tbin0 << "," - << tbinf << "]\n"; - - em("created TH2F"); - for (auto wire : frame) { - const int iwire = wire.first; - Assert(rbins.inbounds(iwire)); - const Waveform::realseq_t& wave = wire.second; - // auto tot = Waveform::sum(wave); - // std::cerr << iwire << " tot=" << tot << std::endl; - for (int itick = tbin0; itick <= tbinf; ++itick) { - hist->Fill(itick + 0.1, iwire + 0.1, -1.0 * wave[itick] / units::microvolt); - } - } - - if (track_types.find("point") < track_types.size()) { - hist->GetXaxis()->SetRangeUser(3950, 4100); - hist->GetYaxis()->SetRangeUser(996, 1004); - } - if (track_types.find("isoch") < track_types.size()) { - hist->GetXaxis()->SetRangeUser(3900, 4000); - hist->GetYaxis()->SetRangeUser(995, 1020); - } - em("filled TH2F"); - hist->Write(); - em("wrote TH2F"); - hist->Draw("colz"); - canvas->SetRightMargin(0.15); - em("drew TH2F"); - std::vector lines; - auto trqs = tracks.tracks(); - for (size_t iline = 0; iline < trqs.size(); ++iline) { - auto trq = trqs[iline]; - const double time = get<0>(trq); - const Ray ray = get<1>(trq); - - // this need to subtract off the fr.origin is I think a bug, - // or at least a bookkeeping detail to ensconce somewhere. I - // think FR is taking the start of the path as the time - // origin. Something to check... - const int tick1 = tbins.bin(time + (ray.first.x() - fr.origin) / drift_speed); - const int tick2 = tbins.bin(time + (ray.second.x() - fr.origin) / drift_speed); - - const int wire1 = rbins.bin(pimpos.distance(ray.first)); - const int wire2 = rbins.bin(pimpos.distance(ray.second)); - - cerr << "digitrack: t=" << time << " ticks=[" << tick1 << "," << tick2 << "] wires=[" << wire1 << "," - << wire2 << "]\n"; - - const int fudge = 0; - TLine* line = new TLine(tick1 - fudge, wire1, tick2 - fudge, wire2); - line->Write(Form("l%c%d", uvw[plane_id], (int) iline)); - line->Draw(); - // canvas->Print(Form("%s_%c.png", out_basename.c_str(), uvw[plane_id])); - canvas->Print(pdfname.c_str(), "pdf"); - } - em("printed PNG canvases"); - em("end of PIR scope"); - - // canvas->Print("test_impactzipper.pdf","pdf"); - } - rootfile->Close(); - canvas->Print((pdfname + "]").c_str(), "pdf"); - em("done"); - - // cerr << em.summary() << endl; - return 0; -} diff --git a/root/test/test_interpolation.cxx b/root/test/test_interpolation.cxx index fad40692d..2a1e59edf 100644 --- a/root/test/test_interpolation.cxx +++ b/root/test/test_interpolation.cxx @@ -3,7 +3,6 @@ * Implementation in GaussianDiffusion for each charge depo */ #include "WireCellGen/GaussianDiffusion.h" -#include "WireCellGen/ImpactZipper.h" #include "WireCellGen/TrackDepos.h" #include "WireCellGen/BinnedDiffusion.h" #include "WireCellGen/TransportedDepo.h" diff --git a/root/test/test_misconfigure.cxx b/root/test/test_misconfigure.cxx index 1c7ea3af9..428dac9f7 100644 --- a/root/test/test_misconfigure.cxx +++ b/root/test/test_misconfigure.cxx @@ -1,4 +1,6 @@ +#include "WireCellAux/DftTools.h" #include "WireCellIface/IFrameFilter.h" +#include "WireCellIface/IDFT.h" #include "WireCellIface/IConfigurable.h" #include "WireCellIface/SimpleFrame.h" #include "WireCellIface/SimpleTrace.h" @@ -117,9 +119,13 @@ TH2F* plot_frame(MultiPdf& pdf, IFrame::pointer frame, std::string name, double int main(int argc, char* argv[]) { PluginManager& pm = PluginManager::instance(); + auto aux_pi = pm.add("WireCellAux"); + assert(aux_pi); pm.add("WireCellGen"); pm.add("WireCellRoot"); + auto idft = Factory::lookup_tn("FftwDFT"); + int nsamples = 50; double gain, shaping, tick; { @@ -140,9 +146,9 @@ int main(int argc, char* argv[]) auto resp = ce.generate(Binning(200, 0, 200 * tick)); auto resp2 = ce.generate(Binning(400, 0, 400 * tick)); auto resp3 = ce.generate(Binning(50, 0, 50 * tick)); - auto resp_spec = Waveform::dft(resp); - auto resp_spec2 = Waveform::dft(resp2); - auto resp_spec3 = Waveform::dft(resp3); + auto resp_spec = Aux::fwd_r2c(idft, resp); + auto resp_spec2 = Aux::fwd_r2c(idft, resp2); + auto resp_spec3 = Aux::fwd_r2c(idft, resp3); ITrace::vector q_traces; ITrace::vector out_traces; @@ -163,10 +169,10 @@ int main(int argc, char* argv[]) q_traces.push_back(std::make_shared(qchannel++, 0, q3)); q_traces.push_back(std::make_shared(qchannel++, 0, q4)); - auto e1 = linear_convolve(q1, resp); - auto e2 = linear_convolve(q2, resp); - auto e3 = linear_convolve(q3, resp); - auto e4 = linear_convolve(q4, resp); + auto e1 = Aux::convolve(idft, q1, resp); + auto e2 = Aux::convolve(idft, q2, resp); + auto e3 = Aux::convolve(idft, q3, resp); + auto e4 = Aux::convolve(idft, q4, resp); out_traces.push_back(std::make_shared(channel++, 0, e1)); out_traces.push_back(std::make_shared(channel++, 0, e2)); diff --git a/root/test/test_rcresponse.cxx b/root/test/test_rcresponse.cxx index fb6f63127..1e65c08a9 100644 --- a/root/test/test_rcresponse.cxx +++ b/root/test/test_rcresponse.cxx @@ -1,6 +1,11 @@ // Test RCResponse + #include "MultiPdf.h" // local helper shared by a few tests +#include "WireCellAux/DftTools.h" +#include "WireCellUtil/NamedFactory.h" +#include "WireCellUtil/PluginManager.h" + #include "WireCellUtil/Units.h" #include "WireCellUtil/Waveform.h" #include "WireCellGen/RCResponse.h" @@ -12,6 +17,9 @@ using namespace WireCell; int main(int argc, char* argv[]) { + PluginManager& pm = PluginManager::instance(); + pm.add("WireCellAux"); + auto idft = Factory::lookup_tn("FftwDFT"); Test::MultiPdf mpdf(argv[0]); const double tick = 0.5*units::us; @@ -28,7 +36,7 @@ int main(int argc, char* argv[]) const auto& wavep1 = rcr.waveform_samples(); // skip first which holds delta Waveform::realseq_t wave(wavep1.begin()+1, wavep1.end()); - auto spec = Waveform::dft(wave); + auto spec = Aux::fwd_r2c(idft, wave); auto mag = Waveform::magnitude(spec); TGraph* g = new TGraph(wave.size()); diff --git a/sig/inc/WireCellSig/Decon2DFilter.h b/sig/inc/WireCellSig/Decon2DFilter.h index 45d59a423..746f20491 100644 --- a/sig/inc/WireCellSig/Decon2DFilter.h +++ b/sig/inc/WireCellSig/Decon2DFilter.h @@ -6,6 +6,8 @@ #include "WireCellIface/IConfigurable.h" #include "WireCellIface/ITensorSetFilter.h" +#include "WireCellIface/IDFT.h" + #include "WireCellUtil/Logging.h" namespace WireCell { @@ -25,8 +27,9 @@ namespace WireCell { private: Log::logptr_t log; Configuration m_cfg; /// copy of configuration + IDFT::pointer m_dft; }; } // namespace Sig } // namespace WireCell -#endif // WIRECELLSIG_DECON2DFILTER \ No newline at end of file +#endif // WIRECELLSIG_DECON2DFILTER diff --git a/sig/inc/WireCellSig/Decon2DResponse.h b/sig/inc/WireCellSig/Decon2DResponse.h index 8e741f60d..f4cccce60 100644 --- a/sig/inc/WireCellSig/Decon2DResponse.h +++ b/sig/inc/WireCellSig/Decon2DResponse.h @@ -9,6 +9,8 @@ #include "WireCellIface/IAnodePlane.h" #include "WireCellIface/IChannelResponse.h" #include "WireCellIface/IFieldResponse.h" +#include "WireCellIface/IDFT.h" + #include "WireCellUtil/Logging.h" namespace WireCell { @@ -37,8 +39,10 @@ namespace WireCell { IChannelResponse::pointer m_cresp; IFieldResponse::pointer m_fresp; + + IDFT::pointer m_dft; }; } // namespace Sig } // namespace WireCell -#endif // WIRECELLSIG_DECON2DRESPONSE \ No newline at end of file +#endif // WIRECELLSIG_DECON2DRESPONSE diff --git a/sig/src/Decon2DFilter.cxx b/sig/src/Decon2DFilter.cxx index c3e23b833..73081d942 100644 --- a/sig/src/Decon2DFilter.cxx +++ b/sig/src/Decon2DFilter.cxx @@ -1,6 +1,15 @@ #include "WireCellSig/Decon2DFilter.h" #include "WireCellSig/Util.h" +#include "WireCellAux/SimpleTensorSet.h" +#include "WireCellAux/SimpleTensor.h" +#include "WireCellAux/Util.h" +#include "WireCellAux/TensUtil.h" +#include "WireCellAux/DftTools.h" + +#include "WireCellIface/ITensorSet.h" +#include "WireCellIface/IFilterWaveform.h" + #include "WireCellUtil/NamedFactory.h" #include "WireCellUtil/String.h" #include "WireCellUtil/Array.h" @@ -8,14 +17,6 @@ #include "WireCellUtil/FFTBestLength.h" #include "WireCellUtil/Exceptions.h" -#include "WireCellIface/ITensorSet.h" -#include "WireCellIface/IFilterWaveform.h" - -#include "WireCellAux/SimpleTensorSet.h" -#include "WireCellAux/SimpleTensor.h" -#include "WireCellAux/Util.h" -#include "WireCellAux/TensUtil.h" - WIRECELL_FACTORY(Decon2DFilter, WireCell::Sig::Decon2DFilter, WireCell::ITensorSetFilter, WireCell::IConfigurable) using namespace WireCell; @@ -28,11 +29,17 @@ Sig::Decon2DFilter::Decon2DFilter() Configuration Sig::Decon2DFilter::default_configuration() const { Configuration cfg; - + cfg["dft"] = "FftwDFT"; // type-name for the DFT to use return cfg; } -void Sig::Decon2DFilter::configure(const WireCell::Configuration &cfg) { m_cfg = cfg; } +void Sig::Decon2DFilter::configure(const WireCell::Configuration &cfg) +{ + std::string dft_tn = get(cfg, "dft", "FftwDFT"); + m_dft = Factory::find_tn(dft_tn); + + m_cfg = cfg; +} bool Sig::Decon2DFilter::operator()(const ITensorSet::pointer &in, ITensorSet::pointer &out) { @@ -111,7 +118,8 @@ bool Sig::Decon2DFilter::operator()(const ITensorSet::pointer &in, ITensorSet::p } // do the second round of inverse FFT on wire - Array::array_xxf tm_r_data = Array::idft_cr(c_data_afterfilter, 0); + Array::array_xxf tm_r_data = Aux::inv(m_dft, c_data_afterfilter, 1).real(); + Array::array_xxf r_data = tm_r_data.block(m_pad_nwires, 0, m_nwires, m_nticks); Sig::restore_baseline(r_data); @@ -162,4 +170,4 @@ bool Sig::Decon2DFilter::operator()(const ITensorSet::pointer &in, ITensorSet::p log->debug("Decon2DFilter: end"); return true; -} \ No newline at end of file +} diff --git a/sig/src/Decon2DResponse.cxx b/sig/src/Decon2DResponse.cxx index 7b3fd8f4a..5a997d4c4 100644 --- a/sig/src/Decon2DResponse.cxx +++ b/sig/src/Decon2DResponse.cxx @@ -1,5 +1,14 @@ #include "WireCellSig/Decon2DResponse.h" +#include "WireCellAux/SimpleTensorSet.h" +#include "WireCellAux/SimpleTensor.h" +#include "WireCellAux/Util.h" +#include "WireCellAux/TensUtil.h" +#include "WireCellAux/DftTools.h" + +#include "WireCellIface/ITensorSet.h" +#include "WireCellIface/IFilterWaveform.h" + #include "WireCellUtil/NamedFactory.h" #include "WireCellUtil/String.h" #include "WireCellUtil/Array.h" @@ -7,13 +16,6 @@ #include "WireCellUtil/FFTBestLength.h" #include "WireCellUtil/Exceptions.h" -#include "WireCellIface/ITensorSet.h" -#include "WireCellIface/IFilterWaveform.h" - -#include "WireCellAux/SimpleTensorSet.h" -#include "WireCellAux/SimpleTensor.h" -#include "WireCellAux/Util.h" -#include "WireCellAux/TensUtil.h" WIRECELL_FACTORY(Decon2DResponse, WireCell::Sig::Decon2DResponse, WireCell::ITensorSetFilter, WireCell::IConfigurable) @@ -27,7 +29,7 @@ Sig::Decon2DResponse::Decon2DResponse() Configuration Sig::Decon2DResponse::default_configuration() const { Configuration cfg; - + cfg["dft"] = "FftwDFT"; // type-name for the DFT to use return cfg; } @@ -56,6 +58,9 @@ void Sig::Decon2DResponse::configure(const WireCell::Configuration &cfg) if (!m_fresp) { THROW(ValueError() << errmsg{"Sig::Decon2DResponse::configure !m_fresp"}); } + + std::string dft_tn = get(cfg, "dft", "FftwDFT"); + m_dft = Factory::find_tn(dft_tn); } namespace { @@ -121,7 +126,7 @@ std::vector Sig::Decon2DResponse::init_overall_response(con Response::ColdElec ce(m_gain, m_shaping_time); auto ewave = ce.generate(tbins); Waveform::scale(ewave, m_inter_gain * m_ADC_mV * (-1)); - elec = Waveform::dft(ewave); + elec = Aux::fwd_r2c(m_dft, ewave); std::complex fine_period(fravg.period, 0); @@ -144,7 +149,9 @@ std::vector Sig::Decon2DResponse::init_overall_response(con auto arr = Response::as_array(fravg.planes[iplane], fine_nwires, fine_nticks); // do FFT for response ... - Array::array_xxc c_data = Array::dft_rc(arr, 0); + Array::array_xxc c_data = arr.cast(); + c_data = Aux::fwd(m_dft, c_data, 1); + int nrows = c_data.rows(); int ncols = c_data.cols(); @@ -154,7 +161,7 @@ std::vector Sig::Decon2DResponse::init_overall_response(con } } - arr = Array::idft_cr(c_data, 0); + arr = Aux::inv(m_dft, c_data, 1).real(); // figure out how to do fine ... shift (good ...) int fine_time_shift = m_fine_time_offset / fravg.period; @@ -262,7 +269,8 @@ bool Sig::Decon2DResponse::operator()(const ITensorSet::pointer &in, ITensorSet: log->debug("r_data: {} {}", r_data.rows(), r_data.cols()); // first round of FFT on time - auto c_data = Array::dft_rc(r_data, 0); + WireCell::Array::array_xxc c_data = r_data.cast(); + c_data = Aux::fwd(m_dft, c_data, 1); if (m_cresp) { log->debug("Decon2DResponse: applying ch-by-ch electronics response correction"); @@ -275,12 +283,12 @@ bool Sig::Decon2DResponse::operator()(const ITensorSet::pointer &in, ITensorSet: Response::ColdElec ce(m_gain, m_shaping_time); const auto ewave = ce.generate(tbins); - const WireCell::Waveform::compseq_t elec = Waveform::dft(ewave); + const WireCell::Waveform::compseq_t elec = Aux::fwd_r2c(m_dft, ewave); for (int irow = 0; irow != c_data.rows(); irow++) { Waveform::realseq_t tch_resp = m_cresp->channel_response(ch_arr[irow]); tch_resp.resize(m_fft_nticks, 0); - const WireCell::Waveform::compseq_t ch_elec = Waveform::dft(tch_resp); + const WireCell::Waveform::compseq_t ch_elec = Aux::fwd_r2c(m_dft, tch_resp); // FIXME figure this out // const int irow = och.wire + m_pad_nwires; @@ -298,7 +306,7 @@ bool Sig::Decon2DResponse::operator()(const ITensorSet::pointer &in, ITensorSet: log->trace("TRACE {}", __LINE__); // second round of FFT on wire - c_data = Array::dft_cc(c_data, 1); + c_data = Aux::fwd(m_dft, c_data, 0); // response part ... Array::array_xxf r_resp = Array::array_xxf::Zero(r_data.rows(), m_fft_nticks); @@ -310,9 +318,9 @@ bool Sig::Decon2DResponse::operator()(const ITensorSet::pointer &in, ITensorSet: log->trace("TRACE {}", __LINE__); // do first round FFT on the resposne on time - Array::array_xxc c_resp = Array::dft_rc(r_resp, 0); // do second round FFT on the response on wire - c_resp = Array::dft_cc(c_resp, 1); + Array::array_xxc c_resp = r_resp.cast(); + c_resp = Aux::fwd(m_dft, c_resp); // make ratio to the response and apply wire filter c_data = c_data / c_resp; @@ -337,10 +345,9 @@ bool Sig::Decon2DResponse::operator()(const ITensorSet::pointer &in, ITensorSet: log->trace("TRACE {}", __LINE__); // do the first round of inverse FFT on wire - c_data = Array::idft_cc(c_data, 1); - // do the second round of inverse FFT on time - r_data = Array::idft_cr(c_data, 0); + c_data = Aux::inv(m_dft, c_data); + r_data = c_data.real(); // do the shift in wire const int nrows = r_data.rows(); @@ -364,7 +371,8 @@ bool Sig::Decon2DResponse::operator()(const ITensorSet::pointer &in, ITensorSet: r_data.block(0, 0, nrows, time_shift) = arr2; r_data.block(0, time_shift, nrows, ncols - time_shift) = arr1; } - c_data = Array::dft_rc(r_data, 0); + c_data = Aux::fwd(m_dft, r_data.cast(), 1); + log->trace("TRACE {}", __LINE__); // Eigen to TensorSet @@ -398,4 +406,4 @@ bool Sig::Decon2DResponse::operator()(const ITensorSet::pointer &in, ITensorSet: log->debug("Decon2DResponse: end"); return true; -} \ No newline at end of file +} diff --git a/sigproc/inc/WireCellSigProc/L1SPFilter.h b/sigproc/inc/WireCellSigProc/L1SPFilter.h index bedaf7823..4ebe78751 100644 --- a/sigproc/inc/WireCellSigProc/L1SPFilter.h +++ b/sigproc/inc/WireCellSigProc/L1SPFilter.h @@ -8,9 +8,11 @@ #include "WireCellIface/IFrameFilter.h" #include "WireCellIface/IConfigurable.h" +#include "WireCellIface/IDFT.h" -#include "WireCellUtil/Interpolate.h" #include "WireCellIface/SimpleTrace.h" +#include "WireCellUtil/Interpolate.h" + namespace WireCell { namespace SigProc { @@ -38,6 +40,7 @@ namespace WireCell { private: Configuration m_cfg; + IDFT::pointer m_dft; double m_gain; double m_shaping; diff --git a/sigproc/inc/WireCellSigProc/Microboone.h b/sigproc/inc/WireCellSigProc/Microboone.h index 8e286870b..a15b1d067 100644 --- a/sigproc/inc/WireCellSigProc/Microboone.h +++ b/sigproc/inc/WireCellSigProc/Microboone.h @@ -4,14 +4,16 @@ #ifndef WIRECELLSIGPROC_MICROBOONE #define WIRECELLSIGPROC_MICROBOONE -#include "WireCellUtil/Waveform.h" -#include "WireCellUtil/Bits.h" +#include "WireCellSigProc/Diagnostics.h" + #include "WireCellIface/IChannelFilter.h" #include "WireCellIface/IConfigurable.h" #include "WireCellIface/IChannelNoiseDatabase.h" #include "WireCellIface/IAnodePlane.h" +#include "WireCellIface/IDFT.h" -#include "WireCellSigProc/Diagnostics.h" +#include "WireCellUtil/Waveform.h" +#include "WireCellUtil/Bits.h" namespace WireCell { namespace SigProc { @@ -26,14 +28,18 @@ namespace WireCell { bool NoisyFilterAlg(WireCell::Waveform::realseq_t& spec, float min_rms, float max_rms); std::vector > SignalProtection(WireCell::Waveform::realseq_t& sig, - const WireCell::Waveform::compseq_t& respec, int res_offset, + const WireCell::Waveform::compseq_t& respec, + const IDFT::pointer& dft, + int res_offset, int pad_f, int pad_b, float upper_decon_limit = 0.02, float decon_lf_cutoff = 0.08, float upper_adc_limit = 15, float protection_factor = 5.0, float min_adc_limit = 50); bool Subtract_WScaling(WireCell::IChannelFilter::channel_signals_t& chansig, const WireCell::Waveform::realseq_t& medians, const WireCell::Waveform::compseq_t& respec, int res_offset, - std::vector >& rois, float upper_decon_limit1 = 0.08, + std::vector >& rois, + const IDFT::pointer& dft, + float upper_decon_limit1 = 0.08, float roi_min_max_ratio = 0.8, float rms_threshold = 0.); // hold common config stuff @@ -54,6 +60,7 @@ namespace WireCell { std::string m_anode_tn, m_noisedb_tn; IAnodePlane::pointer m_anode; IChannelNoiseDatabase::pointer m_noisedb; + IDFT::pointer m_dft; }; /** Microboone style coherent noise subtraction. @@ -135,6 +142,7 @@ namespace WireCell { private: std::string m_anode_tn; IAnodePlane::pointer m_anode; + IDFT::pointer m_dft; double m_threshold; int m_window; int m_nbins; diff --git a/sigproc/inc/WireCellSigProc/OmniChannelNoiseDB.h b/sigproc/inc/WireCellSigProc/OmniChannelNoiseDB.h index 38154a367..bec3bbc4a 100644 --- a/sigproc/inc/WireCellSigProc/OmniChannelNoiseDB.h +++ b/sigproc/inc/WireCellSigProc/OmniChannelNoiseDB.h @@ -4,6 +4,7 @@ #include "WireCellIface/IChannelNoiseDatabase.h" #include "WireCellIface/IConfigurable.h" #include "WireCellIface/IAnodePlane.h" +#include "WireCellIface/IDFT.h" #include "WireCellIface/IFieldResponse.h" #include "WireCellIface/WirePlaneId.h" @@ -167,6 +168,7 @@ namespace WireCell { std::unordered_map m_response_cache; Log::logptr_t log; + IDFT::pointer m_dft; }; } // namespace SigProc diff --git a/sigproc/inc/WireCellSigProc/OmnibusSigProc.h b/sigproc/inc/WireCellSigProc/OmnibusSigProc.h index 4c597621b..23d77f48e 100644 --- a/sigproc/inc/WireCellSigProc/OmnibusSigProc.h +++ b/sigproc/inc/WireCellSigProc/OmnibusSigProc.h @@ -1,13 +1,17 @@ #ifndef WIRECELLSIGPROC_OMNIBUSSIGPROC #define WIRECELLSIGPROC_OMNIBUSSIGPROC +#include "WireCellAux/Logger.h" + #include "WireCellIface/IFrameFilter.h" #include "WireCellIface/IConfigurable.h" #include "WireCellIface/IAnodePlane.h" +#include "WireCellIface/IDFT.h" #include "WireCellIface/IWaveform.h" + #include "WireCellUtil/Waveform.h" #include "WireCellUtil/Array.h" -#include "WireCellAux/Logger.h" + #include @@ -61,9 +65,18 @@ namespace WireCell { void decon_2D_looseROI_debug_mode(int plane); - // save data into the out frame and collect the indices - void save_data(ITrace::vector& itraces, IFrame::trace_list_t& indices, int plane, - const std::vector& perwire_rmses, IFrame::trace_summary_t& threshold); + // for debugging, check current state of working data + void check_data(int plane, const std::string& loglabel); + + // Copy elements from m_r_data, mess with them, and store + // result into traces. Update indices. Fixme: best if we + // were to factor saving and munging! + void save_data(ITrace::vector& itraces, + IFrame::trace_list_t& indices, + int plane, + const std::vector& perwire_rmses, + IFrame::trace_summary_t& threshold, + const std::string& loglabel); // save ROI into the out frame (set use_roi_debug_mode=true) void save_roi(ITrace::vector& itraces, IFrame::trace_list_t& indices, int plane, @@ -229,7 +242,9 @@ namespace WireCell { bool m_sparse; size_t m_count{0}; + int m_verbose{0}; + IDFT::pointer m_dft; }; } // namespace SigProc } // namespace WireCell @@ -237,5 +252,5 @@ namespace WireCell { #endif // Local Variables: // mode: c++ -// c-basic-offset: 2 +// c-basic-offset: 4 // End: diff --git a/sigproc/inc/WireCellSigProc/Protodune.h b/sigproc/inc/WireCellSigProc/Protodune.h index fbb607e00..2fa443b00 100644 --- a/sigproc/inc/WireCellSigProc/Protodune.h +++ b/sigproc/inc/WireCellSigProc/Protodune.h @@ -5,14 +5,16 @@ #ifndef WIRECELLSIGPROC_PROTODUNE #define WIRECELLSIGPROC_PROTODUNE -#include "WireCellUtil/Waveform.h" -#include "WireCellUtil/Bits.h" +#include "WireCellSigProc/Diagnostics.h" + #include "WireCellIface/IChannelFilter.h" #include "WireCellIface/IConfigurable.h" #include "WireCellIface/IChannelNoiseDatabase.h" #include "WireCellIface/IAnodePlane.h" +#include "WireCellIface/IDFT.h" -#include "WireCellSigProc/Diagnostics.h" +#include "WireCellUtil/Waveform.h" +#include "WireCellUtil/Bits.h" namespace WireCell { namespace SigProc { @@ -20,39 +22,16 @@ namespace WireCell { bool LinearInterpSticky(WireCell::Waveform::realseq_t& signal, std::vector >& st_ranges, float stky_sig_like_val, float stky_sig_like_rms); - bool FftInterpSticky(WireCell::Waveform::realseq_t& signal, std::vector >& st_ranges); - bool FftShiftSticky(WireCell::Waveform::realseq_t& signal, double toffset, + bool FftInterpSticky(const IDFT::pointer& dft, + WireCell::Waveform::realseq_t& signal, std::vector >& st_ranges); + bool FftShiftSticky(const IDFT::pointer& dft, + WireCell::Waveform::realseq_t& signal, double toffset, std::vector >& st_ranges); - bool FftScaling(WireCell::Waveform::realseq_t& signal, int nsamples); - - // hold common config stuff - class ConfigFilterBase : public WireCell::IConfigurable { - public: - ConfigFilterBase(const std::string& anode = "AnodePlane", - const std::string& noisedb = "OmniChannelNoiseDB"); - virtual ~ConfigFilterBase(); + bool FftScaling(const IDFT::pointer& dft, + WireCell::Waveform::realseq_t& signal, int nsamples); - // IConfigurable configuration interface - virtual void configure(const WireCell::Configuration& config); - virtual WireCell::Configuration default_configuration() const; - - // FIXME: this method needs to die. - void set_channel_noisedb(WireCell::IChannelNoiseDatabase::pointer ndb) { m_noisedb = ndb; } - - protected: - std::string m_anode_tn, m_noisedb_tn; - IAnodePlane::pointer m_anode; - IChannelNoiseDatabase::pointer m_noisedb; - }; - - /** Microboone/ProtoDUNE style noise subtraction. + /** ProtoDUNE style noise subtraction. * - * Fixme: in principle, this class could be general purpose - * for other detectors. However, it uses the functions above - * which hard code microboone-isms. If those - * microboone/protodune-specific parameters can be pulled out to a - * higher layer then this class can become generic and move - * outside of this file. */ class StickyCodeMitig : public WireCell::IChannelFilter, public WireCell::IConfigurable { @@ -77,6 +56,7 @@ namespace WireCell { std::string m_anode_tn, m_noisedb_tn; IAnodePlane::pointer m_anode; IChannelNoiseDatabase::pointer m_noisedb; + IDFT::pointer m_dft; std::map > m_extra_stky; // ch to extra sticky codes float m_stky_sig_like_val; @@ -84,7 +64,7 @@ namespace WireCell { int m_stky_max_len; }; - class OneChannelNoise : public WireCell::IChannelFilter, public ConfigFilterBase { + class OneChannelNoise : public WireCell::IChannelFilter, public WireCell::IConfigurable { public: OneChannelNoise(const std::string& anode_tn = "AnodePlane", const std::string& noisedb = "OmniChannelNoiseDB"); @@ -102,8 +82,12 @@ namespace WireCell { WireCell::Configuration default_configuration() const; private: + std::string m_anode_tn, m_noisedb_tn; Diagnostics::Partial m_check_partial; // at least need to expose them to configuration std::map m_resmp; // ch => orignal smp input + IAnodePlane::pointer m_anode; + IChannelNoiseDatabase::pointer m_noisedb; + IDFT::pointer m_dft; }; // A relative gain correction based on David Adam's pulse area calibration diff --git a/sigproc/inc/WireCellSigProc/SimpleChannelNoiseDB.h b/sigproc/inc/WireCellSigProc/SimpleChannelNoiseDB.h index ae75db9e3..5efef1834 100644 --- a/sigproc/inc/WireCellSigProc/SimpleChannelNoiseDB.h +++ b/sigproc/inc/WireCellSigProc/SimpleChannelNoiseDB.h @@ -2,6 +2,8 @@ #define WIRECELLSIGPROC_SIMPLECHANNELNOISEDB #include "WireCellIface/IChannelNoiseDatabase.h" +#include "WireCellIface/IConfigurable.h" +#include "WireCellIface/IDFT.h" #include "WireCellUtil/Waveform.h" #include "WireCellUtil/Units.h" @@ -14,7 +16,8 @@ namespace WireCell { namespace SigProc { - class SimpleChannelNoiseDB : public WireCell::IChannelNoiseDatabase { + class SimpleChannelNoiseDB : public WireCell::IConfigurable, + public WireCell::IChannelNoiseDatabase { public: /// Create a simple channel noise DB for digitized waveforms /// with the given size and number of samples. Default is for @@ -22,6 +25,10 @@ namespace WireCell { SimpleChannelNoiseDB(double tick = 0.5 * units::us, int nsamples = 9600); virtual ~SimpleChannelNoiseDB(); + /// IConfigurable + virtual void configure(const WireCell::Configuration& config); + virtual WireCell::Configuration default_configuration() const; + // IChannelNoiseDatabase virtual double sample_time() const { return m_tick; } @@ -143,6 +150,8 @@ namespace WireCell { std::vector m_channel_groups; channel_group_t m_bad_channels; + + IDFT::pointer m_dft; }; } // namespace SigProc diff --git a/sigproc/src/L1SPFilter.cxx b/sigproc/src/L1SPFilter.cxx index a021d7c2a..0b0a42343 100644 --- a/sigproc/src/L1SPFilter.cxx +++ b/sigproc/src/L1SPFilter.cxx @@ -1,13 +1,16 @@ #include "WireCellSigProc/L1SPFilter.h" +#include "WireCellAux/DftTools.h" +#include "WireCellAux/FrameTools.h" + #include "WireCellIface/SimpleFrame.h" #include "WireCellIface/IFieldResponse.h" -#include "WireCellUtil/NamedFactory.h" -#include "WireCellAux/FrameTools.h" - #include "WireCellRess/LassoModel.h" #include "WireCellRess/ElasticNetModel.h" + +#include "WireCellUtil/NamedFactory.h" + #include #include @@ -54,7 +57,7 @@ void L1SPFilter::init_resp() Response::ColdElec ce(m_gain, m_shaping); auto ewave = ce.generate(tbins); Waveform::scale(ewave, m_postgain * m_ADC_mV * (-1)); // ADC to electron ... - elec = Waveform::dft(ewave); + elec = Aux::fwd_r2c(m_dft, ewave); std::complex fine_period(fravg.period, 0); @@ -62,8 +65,8 @@ void L1SPFilter::init_resp() WireCell::Waveform::realseq_t resp_V = fravg.planes[1].paths[0].current; WireCell::Waveform::realseq_t resp_W = fravg.planes[2].paths[0].current; - auto spectrum_V = WireCell::Waveform::dft(resp_V); - auto spectrum_W = WireCell::Waveform::dft(resp_W); + auto spectrum_V = Aux::fwd_r2c(m_dft, resp_V); + auto spectrum_W = Aux::fwd_r2c(m_dft, resp_W); WireCell::Waveform::scale(spectrum_V, elec); WireCell::Waveform::scale(spectrum_W, elec); @@ -72,8 +75,8 @@ void L1SPFilter::init_resp() WireCell::Waveform::scale(spectrum_W, fine_period); // Now this response is ADC for 1 electron . - resp_V = WireCell::Waveform::idft(spectrum_V); - resp_W = WireCell::Waveform::idft(spectrum_W); + resp_V = Aux::inv_c2r(m_dft, spectrum_V); + resp_W = Aux::inv_c2r(m_dft, spectrum_W); // convolute with V and Y average responses ... double intrinsic_time_offset = fravg.origin / fravg.speed; @@ -153,6 +156,8 @@ WireCell::Configuration L1SPFilter::default_configuration() const cfg["fine_time_offset"] = m_fine_time_offset; cfg["coarse_time_offset"] = m_coarse_time_offset; + cfg["dft"] = "FftwDFT"; // type-name for the DFT to use + return cfg; } @@ -167,6 +172,9 @@ void L1SPFilter::configure(const WireCell::Configuration& cfg) m_fine_time_offset = get(cfg, "fine_time_offset", m_fine_time_offset); m_coarse_time_offset = get(cfg, "coarse_time_offset", m_coarse_time_offset); + + std::string dft_tn = get(cfg, "dft", "FftwDFT"); + m_dft = Factory::find_tn(dft_tn); } bool L1SPFilter::operator()(const input_pointer& in, output_pointer& out) @@ -197,19 +205,6 @@ bool L1SPFilter::operator()(const input_pointer& in, output_pointer& out) // l1_col_scale << " " << l1_ind_scale << std::endl; init_resp(); - // std::cout << (*lin_V)(0*units::us) << " " << (*lin_W)(0*units::us) << std::endl; - // std::cout << (*lin_V)(1*units::us) << " " << (*lin_W)(1*units::us) << std::endl; - // for (size_t i=0; i!=resp_V.size(); i++){ - // std::cout << (i*fravg.period - intrinsic_time_offset - m_coarse_time_offset + m_fine_time_offset)/units::us << " - // " << resp_V.at(i) << " " << resp_W.at(i) << " " << ewave.at(i) << std::endl; - //} - // std::complex fine_period(fravg.period,0); - // int fine_nticks = Response::as_array(fravg.planes[0]).cols(); - // Waveform::realseq_t ftbins(fine_nticks); - // for (int i=0;i!=fine_nticks;i++){ - // ftbins.at(i) = i * fravg.period; - //} - auto adctraces = Aux::tagged_traces(in, adctag); auto sigtraces = Aux::tagged_traces(in, sigtag); diff --git a/sigproc/src/Microboone.cxx b/sigproc/src/Microboone.cxx index 52c45add5..7fbf8ada8 100644 --- a/sigproc/src/Microboone.cxx +++ b/sigproc/src/Microboone.cxx @@ -5,6 +5,8 @@ #include "WireCellSigProc/Microboone.h" #include "WireCellSigProc/Derivations.h" +#include "WireCellAux/DftTools.h" + #include "WireCellUtil/NamedFactory.h" #include @@ -48,7 +50,9 @@ double filter_low_loose(double freq) { return 1 - exp(-pow(freq / 0.005, 2)); } bool Microboone::Subtract_WScaling(WireCell::IChannelFilter::channel_signals_t& chansig, const WireCell::Waveform::realseq_t& medians, const WireCell::Waveform::compseq_t& respec, int res_offset, - std::vector >& rois, float decon_limit1, float roi_min_max_ratio, + std::vector >& rois, + const IDFT::pointer& dft, + float decon_limit1, float roi_min_max_ratio, float rms_threshold) { double ave_coef = 0; @@ -134,7 +138,7 @@ bool Microboone::Subtract_WScaling(WireCell::IChannelFilter::channel_signals_t& } // do the deconvolution with a very loose low-frequency filter - WireCell::Waveform::compseq_t signal_roi_freq = WireCell::Waveform::dft(signal_roi); + WireCell::Waveform::compseq_t signal_roi_freq = Aux::fwd_r2c(dft, signal_roi); WireCell::Waveform::shrink(signal_roi_freq, respec); for (size_t i = 0; i != signal_roi_freq.size(); i++) { double freq; @@ -148,7 +152,7 @@ bool Microboone::Subtract_WScaling(WireCell::IChannelFilter::channel_signals_t& std::complex factor = filter_time(freq) * filter_low_loose(freq); signal_roi_freq.at(i) = signal_roi_freq.at(i) * factor; } - WireCell::Waveform::realseq_t signal_roi_decon = WireCell::Waveform::idft(signal_roi_freq); + WireCell::Waveform::realseq_t signal_roi_decon = Aux::inv_c2r(dft, signal_roi_freq); if (rms_threshold) { std::pair temp = Derivations::CalcRMS(signal_roi_decon); @@ -267,7 +271,9 @@ bool Microboone::Subtract_WScaling(WireCell::IChannelFilter::channel_signals_t& } std::vector > Microboone::SignalProtection(WireCell::Waveform::realseq_t& medians, - const WireCell::Waveform::compseq_t& respec, int res_offset, + const WireCell::Waveform::compseq_t& respec, + const IDFT::pointer& dft, + int res_offset, int pad_f, int pad_b, float upper_decon_limit, float decon_lf_cutoff, float upper_adc_limit, float protection_factor, float min_adc_limit) @@ -342,7 +348,7 @@ std::vector > Microboone::SignalProtection(WireCell::Waveform:: if (respec.size() > 0 && (respec.at(0).real() != 1 || respec.at(0).imag() != 0) && res_offset != 0) { // std::cout << nbin << std::endl; - WireCell::Waveform::compseq_t medians_freq = WireCell::Waveform::dft(medians); + WireCell::Waveform::compseq_t medians_freq = Aux::fwd_r2c(dft, medians); WireCell::Waveform::shrink(medians_freq, respec); for (size_t i = 0; i != medians_freq.size(); i++) { @@ -357,7 +363,7 @@ std::vector > Microboone::SignalProtection(WireCell::Waveform:: std::complex factor = filter_time(freq) * filter_low(freq, decon_lf_cutoff); medians_freq.at(i) = medians_freq.at(i) * factor; } - WireCell::Waveform::realseq_t medians_decon = WireCell::Waveform::idft(medians_freq); + WireCell::Waveform::realseq_t medians_decon = Aux::inv_c2r(dft, medians_freq); temp = Derivations::CalcRMS(medians_decon); mean = temp.first; @@ -394,58 +400,6 @@ std::vector > Microboone::SignalProtection(WireCell::Waveform:: } } } - - // // second-level decon ... - // medians_freq = WireCell::Waveform::dft(medians); - // WireCell::Waveform::realseq_t respec_time = WireCell::Waveform::idft(respec); - // for (size_t i=0;i!=respec_time.size();i++){ - // if (respec_time.at(i)<0) respec_time.at(i) = 0; - // } - // WireCell::Waveform::compseq_t respec_freq = WireCell::Waveform::dft(respec_time); - // WireCell::Waveform::shrink(medians_freq,respec_freq); - // for (size_t i=0;i!=medians_freq.size();i++){ - // double freq; - // // assuming 2 MHz digitization - // if (i factor = filter_time(freq)*filter_low(freq, decon_lf_cutoff); - // medians_freq.at(i) = medians_freq.at(i) * factor; - // } - // medians_decon = WireCell::Waveform::idft(medians_freq); - - // temp = Derivations::CalcRMS(medians_decon); - // mean = temp.first; - // rms = temp.second; - - // // if (protection_factor*rms > upper_decon_limit){ - // limit = protection_factor*rms; - // // }else{ - // // limit = upper_decon_limit; - // // } - - // for (int j=0;j!=nbin;j++) { - // float content = medians_decon.at(j); - // if ((content-mean)>limit){ - // int time_bin = j + res_offset; - // if (time_bin >= nbin) time_bin -= nbin; - // // medians.at(time_bin) = 0; - // signalsBool.at(time_bin) = true; - // // add the front and back padding - // for (int k=0;k!=pad_b;k++){ - // int bin = time_bin+k+1; - // if (bin > nbin-1) bin = nbin-1; - // signalsBool.at(bin) = true; - // } - // for (int k=0;k!=pad_f;k++){ - // int bin = time_bin-k-1; - // if (bin <0) { bin = 0; } - // signalsBool.at(bin) = true; - // } - // } - // } } // { @@ -483,75 +437,6 @@ std::vector > Microboone::SignalProtection(WireCell::Waveform:: } } - // // use ROI to get a new waveform - // WireCell::Waveform::realseq_t medians_roi(nbin,0); - // for (auto roi: rois){ - // const int bin0 = std::max(roi.front()-1, 0); - // const int binf = std::min(roi.back()+1, nbin-1); - // const double m0 = medians[bin0]; - // const double mf = medians[binf]; - // const double roi_run = binf - bin0; - // const double roi_rise = mf - m0; - // for (auto bin : roi) { - // const double m = m0 + (bin - bin0)/roi_run*roi_rise; - // medians_roi.at(bin) = medians.at(bin) - m; - // } - // } - // // do the deconvolution with a very loose low-frequency filter - // WireCell::Waveform::compseq_t medians_roi_freq = WireCell::Waveform::dft(medians_roi); - // WireCell::Waveform::shrink(medians_roi_freq,respec); - // for (size_t i=0;i!=medians_roi_freq.size();i++){ - // double freq; - // // assuming 2 MHz digitization - // if (i factor = filter_time(freq)*filter_low_loose(freq); - // medians_roi_freq.at(i) = medians_roi_freq.at(i) * factor; - // } - // WireCell::Waveform::realseq_t medians_roi_decon = WireCell::Waveform::idft(medians_roi_freq); - - // // judge if a roi is good or not ... - // //shift things back properly - // for (auto roi: rois){ - // const int bin0 = std::max(roi.front()-1, 0); - // const int binf = std::min(roi.back()+1, nbin-1); - // flag_replace[roi.front()] = false; - - // double max_val = 0; - // double min_val = 0; - // // double max_adc_val=0; - // // double min_adc_val=0; - - // for (int i=bin0; i<=binf; i++){ - // int time_bin = i-res_offset; - // if (time_bin <0) time_bin += nbin; - // if (time_bin >=nbin) time_bin -= nbin; - - // if (i==bin0){ - // max_val = medians_roi_decon.at(time_bin); - // min_val = medians_roi_decon.at(time_bin); - // // max_adc_val = medians.at(i); - // // min_adc_val = medians.at(i); - // }else{ - // if (medians_roi_decon.at(time_bin) > max_val) max_val = medians_roi_decon.at(time_bin); - // if (medians_roi_decon.at(time_bin) < min_val) min_val = medians_roi_decon.at(time_bin); - // // if (medians.at(i) > max_adc_val) max_adc_val = medians.at(i); - // // if (medians.at(i) < min_adc_val) min_adc_val = medians.at(i); - // } - // } - - // //std::cout << "Xin: " << upper_decon_limit1 << std::endl; - // // if ( max_val > upper_decon_limit1) - // // if ( max_val > 0.04 && fabs(min_val) < 0.6*max_val) - // //if (max_val > 0.06 && fabs(min_val) < 0.6*max_val) - // if (max_val > 0.06) - // flag_replace[roi.front()] = true; - // } - // } - // Replace medians for above regions with interpolation on values // just outside each region. for (auto roi : rois) { @@ -885,6 +770,9 @@ void Microboone::ConfigFilterBase::configure(const WireCell::Configuration& cfg) m_anode = Factory::find_tn(m_anode_tn); m_noisedb_tn = get(cfg, "noisedb", m_noisedb_tn); m_noisedb = Factory::find_tn(m_noisedb_tn); + std::string dft_tn = get(cfg, "dft", "FftwDFT"); + m_dft = Factory::find_tn(dft_tn); + // std::cerr << "ConfigFilterBase: \n" << cfg << "\n"; } WireCell::Configuration Microboone::ConfigFilterBase::default_configuration() const @@ -892,6 +780,7 @@ WireCell::Configuration Microboone::ConfigFilterBase::default_configuration() co Configuration cfg; cfg["anode"] = m_anode_tn; cfg["noisedb"] = m_noisedb_tn; + cfg["dft"] = "FftwDFT"; // type-name for the DFT to use return cfg; } @@ -945,7 +834,8 @@ WireCell::Waveform::ChannelMaskMap Microboone::CoherentNoiseSub::apply(channel_s // do the signal protection and adaptive baseline std::vector > rois = - Microboone::SignalProtection(medians, respec, res_offset, pad_f, pad_b, decon_limit, decon_lf_cutoff, adc_limit, + Microboone::SignalProtection(medians, respec, m_dft, + res_offset, pad_f, pad_b, decon_limit, decon_lf_cutoff, adc_limit, protection_factor, min_adc_limit); // if (achannel == 3840){ @@ -959,7 +849,9 @@ WireCell::Waveform::ChannelMaskMap Microboone::CoherentNoiseSub::apply(channel_s // << medians.at(101) << std::endl; // calculate the scaling coefficient and subtract - Microboone::Subtract_WScaling(chansig, medians, respec, res_offset, rois, decon_limit1, roi_min_max_ratio, + Microboone::Subtract_WScaling(chansig, medians, respec, res_offset, rois, + m_dft, + decon_limit1, roi_min_max_ratio, m_rms_threshold); // WireCell::IChannelFilter::signal_t& signal = chansig.begin()->second; @@ -1045,7 +937,7 @@ WireCell::Waveform::ChannelMaskMap Microboone::OneChannelNoise::apply(int ch, si } } - auto spectrum = WireCell::Waveform::dft(signal); + auto spectrum = Aux::fwd_r2c(m_dft, signal); // std::cerr << "OneChannelNoise: "<(cfg, "dft", "FftwDFT"); + m_dft = Factory::find_tn(dft_tn); // std::cerr << "OneChannelStatus: \n" << cfg << "\n"; } WireCell::Configuration Microboone::OneChannelStatus::default_configuration() const @@ -1349,6 +1243,7 @@ WireCell::Configuration Microboone::OneChannelStatus::default_configuration() co cfg["Nbins"] = m_nbins; cfg["Cut"] = m_cut; cfg["anode"] = m_anode_tn; + cfg["dft"] = "FftwDFT"; // type-name for the DFT to use return cfg; } @@ -1414,7 +1309,7 @@ bool Microboone::OneChannelStatus::ID_lf_noisy(signal_t& sig) const // temp_sig.at(i)=i; // } // do FFT - Waveform::compseq_t sig_freq = Waveform::dft(temp_sig); + Waveform::compseq_t sig_freq = Aux::fwd_r2c(m_dft, temp_sig); for (int i = 0; i != m_nbins; i++) { content += abs(sig_freq.at(i + 1)); } diff --git a/sigproc/src/OmniChannelNoiseDB.cxx b/sigproc/src/OmniChannelNoiseDB.cxx index 2435d3a4d..e17dfa33c 100644 --- a/sigproc/src/OmniChannelNoiseDB.cxx +++ b/sigproc/src/OmniChannelNoiseDB.cxx @@ -1,4 +1,5 @@ #include "WireCellSigProc/OmniChannelNoiseDB.h" +#include "WireCellAux/DftTools.h" #include "WireCellUtil/Response.h" #include "WireCellUtil/NamedFactory.h" @@ -63,6 +64,7 @@ WireCell::Configuration OmniChannelNoiseDB::default_configuration() const /// These must be provided cfg["groups"] = Json::arrayValue; cfg["channel_info"] = Json::arrayValue; + cfg["dft"] = "FftwDFT"; // type-name for the DFT to use return cfg; } @@ -181,7 +183,8 @@ OmniChannelNoiseDB::shared_filter_t OmniChannelNoiseDB::parse_rcrc(Json::Value j // auto signal = rcres.generate(WireCell::Binning(m_nsamples, 0, m_nsamples*m_tick)); auto signal = rcres.generate(WireCell::Waveform::Domain(0, m_nsamples * m_tick), m_nsamples); - Waveform::compseq_t spectrum = Waveform::dft(signal); + Waveform::compseq_t spectrum = Aux::fwd_r2c(m_dft, signal); + // get the square of it because there are two RC filters Waveform::compseq_t spectrum2 = spectrum; // Waveform::scale(spectrum2,spectrum); @@ -255,8 +258,8 @@ OmniChannelNoiseDB::shared_filter_t OmniChannelNoiseDB::get_reconfig(double from auto to_sig = to_ce.generate(WireCell::Waveform::Domain(0, m_nsamples * m_tick), m_nsamples); auto from_sig = from_ce.generate(WireCell::Waveform::Domain(0, m_nsamples * m_tick), m_nsamples); - auto to_filt = Waveform::dft(to_sig); - auto from_filt = Waveform::dft(from_sig); + auto to_filt = Aux::fwd_r2c(m_dft, to_sig); + auto from_filt = Aux::fwd_r2c(m_dft, from_sig); // auto from_filt_sum = Waveform::sum(from_filt); // auto to_filt_sum = Waveform::sum(to_filt); @@ -316,7 +319,7 @@ OmniChannelNoiseDB::shared_filter_t OmniChannelNoiseDB::parse_response(Json::Val waveform[ind] += current[ind]; } } - auto spectrum = WireCell::Waveform::dft(waveform); + auto spectrum = Aux::fwd_r2c(m_dft, waveform); auto ret = std::make_shared(spectrum); m_response_cache[wpid.ident()] = ret; return ret; @@ -338,7 +341,7 @@ OmniChannelNoiseDB::shared_filter_t OmniChannelNoiseDB::parse_response(Json::Val waveform[ind] = jwave[ind].asFloat(); } - auto spectrum = WireCell::Waveform::dft(waveform); + auto spectrum = Aux::fwd_r2c(m_dft, waveform); auto ret = std::make_shared(spectrum); m_waveform_cache[id] = ret; return ret; @@ -579,6 +582,9 @@ void OmniChannelNoiseDB::configure(const WireCell::Configuration& cfg) std::string fr_tn = get(cfg, "field_response", "FieldResponse"); m_fr = Factory::find_tn(fr_tn); + std::string dft_tn = get(cfg, "dft", "FftwDFT"); + m_dft = Factory::find_tn(dft_tn); + // WARNING: this assumes channel numbers count from 0 with no gaps! // int nchans = m_anode->channels().size(); // std::cerr << "noise database with " << nchans << " channels\n"; diff --git a/sigproc/src/OmnibusSigProc.cxx b/sigproc/src/OmnibusSigProc.cxx index ef1b5f490..9e76e9d07 100644 --- a/sigproc/src/OmnibusSigProc.cxx +++ b/sigproc/src/OmnibusSigProc.cxx @@ -1,10 +1,9 @@ +#include "ROI_formation.h" +#include "ROI_refinement.h" + #include "WireCellSigProc/OmnibusSigProc.h" -#include "WireCellUtil/NamedFactory.h" -#include "WireCellUtil/Exceptions.h" -#include "WireCellUtil/String.h" -#include "WireCellUtil/FFTBestLength.h" -#include "WireCellUtil/Waveform.h" +#include "WireCellAux/DftTools.h" #include "WireCellIface/SimpleFrame.h" #include "WireCellIface/SimpleTrace.h" @@ -13,8 +12,11 @@ #include "WireCellIface/IFilterWaveform.h" #include "WireCellIface/IChannelResponse.h" -#include "ROI_formation.h" -#include "ROI_refinement.h" +#include "WireCellUtil/NamedFactory.h" +#include "WireCellUtil/Exceptions.h" +#include "WireCellUtil/String.h" +#include "WireCellUtil/FFTBestLength.h" +#include "WireCellUtil/Waveform.h" #include "WireCellUtil/NamedFactory.h" @@ -121,19 +123,23 @@ void OmnibusSigProc::configure(const WireCell::Configuration& config) m_coarse_time_offset = get(config, "ctoffset", m_coarse_time_offset); m_anode_tn = get(config, "anode", m_anode_tn); + std::string dft_tn = get(config, "dft", "FftwDFT"); + m_dft = Factory::find_tn(dft_tn); + m_verbose = get(config, "verbose", 0); + // m_nticks = get(config,"nticks",m_nticks); if (!config["nticks"].isNull()) { - log->warn("no setting \"nticks\", ignoring value {}", config["nticks"].asInt()); + log->warn("config: no setting \"nticks\", ignoring value {}", config["nticks"].asInt()); } // m_period = get(config,"period",m_period); if (!config["period"].isNull()) { - log->warn("no setting \"period\", ignoring value {}", config["period"].asDouble()); + log->warn("config: no setting \"period\", ignoring value {}", config["period"].asDouble()); } m_fft_flag = get(config, "fft_flag", m_fft_flag); if (m_fft_flag) { m_fft_flag = 0; - log->warn("fft_flag option is broken, will use native array sizes"); + log->warn("config: fft_flag option is broken, will use native array sizes"); } m_elecresponse_tn = get(config, "elecresponse", m_elecresponse_tn); m_gain = get(config, "gain", m_gain); @@ -213,7 +219,7 @@ void OmnibusSigProc::configure(const WireCell::Configuration& config) // but we have plane-major order so make a temporary collection. IChannel::vector plane_channels[3]; std::stringstream ss; - ss << "internal channel map for tags: gauss:\"" << m_gauss_tag << "\", wiener:\"" << m_wiener_tag + ss << "config: internal channel map for tags: gauss:\"" << m_gauss_tag << "\", wiener:\"" << m_wiener_tag << "\", frame:\"" << m_frame_tag << "\"\n"; // fixme: this loop is now available as Aux::plane_channels() @@ -262,6 +268,8 @@ WireCell::Configuration OmnibusSigProc::default_configuration() const { Configuration cfg; cfg["anode"] = m_anode_tn; + cfg["dft"] = "FftwDFT"; // type-name for the DFT to use + cfg["verbose"] = 0; // larger is more more logging cfg["ftoffset"] = m_fine_time_offset; cfg["ctoffset"] = m_coarse_time_offset; // cfg["nticks"] = m_nticks; @@ -357,7 +365,8 @@ void OmnibusSigProc::load_data(const input_pointer& in, int plane) auto const& charges = trace->charge(); const int ntbins = std::min((int) charges.size(), m_nticks); for (int qind = 0; qind < ntbins; ++qind) { - m_r_data[plane](och.wire + m_pad_nwires[plane], tbin + qind) = charges[qind]; + const float q = charges[qind]; + m_r_data[plane](och.wire + m_pad_nwires[plane], tbin + qind) = q; } // ensure dead channels are indeed dead ... @@ -374,19 +383,42 @@ void OmnibusSigProc::load_data(const input_pointer& in, int plane) } } } - log->debug("plane index: {} input data identifies {} bad regions", plane, nbad); + log->debug("call={} load plane index: {}, ntraces={}, input bad regions: {}", + m_count, plane, traces->size(), nbad); + check_data(plane, "load data"); } // used in sparsifying below. Could use C++17 lambdas.... static bool ispositive(float x) { return x > 0.0; } static bool isZero(float x) { return x == 0.0; } -void OmnibusSigProc::save_data(ITrace::vector& itraces, IFrame::trace_list_t& indices, int plane, - const std::vector& perwire_rmses, IFrame::trace_summary_t& threshold) +void OmnibusSigProc::check_data(int iplane, const std::string& loglabel) +{ + if (!m_verbose) { return; } + + std::stringstream ss; + auto& arr = m_r_data[iplane]; + + log->debug("data: plane={}, sum={}, mean={}, min={}, max={} \"{}\"", + iplane, + arr.sum(), arr.mean(), arr.minCoeff(), arr.maxCoeff(), + loglabel); +} + +void OmnibusSigProc::save_data( + ITrace::vector& itraces, + IFrame::trace_list_t& indices, + int plane, + const std::vector& perwire_rmses, + IFrame::trace_summary_t& threshold, + const std::string& loglabel) { + check_data(plane, loglabel + " before save"); + // reuse this temporary vector to hold charge for a channel. ITrace::ChargeSequence charge(m_nticks, 0.0); + double qloss = 0.0; double qtot = 0.0; for (auto och : m_channel_range[plane]) { // ordered by osp channel @@ -396,10 +428,18 @@ void OmnibusSigProc::save_data(ITrace::vector& itraces, IFrame::trace_list_t& in const float q = m_r_data[plane](och.wire, itick); // charge.at(itick) = q > 0.0 ? q : 0.0; // charge.at(itick) = q ; - if (m_use_roi_debug_mode) + if (m_use_roi_debug_mode) { charge.at(itick) = q; // debug mode: save all decons - else - charge.at(itick) = q > 0.0 ? q : 0.0; // default mode: only save positive + } + else { // nominal: threshold at zero. + if (q > 0.0) { + charge.at(itick) = q; + } + else { + charge.at(itick) = 0.0; + qloss += q; + } + } } { auto& bad = m_cmm["bad"]; @@ -407,6 +447,7 @@ void OmnibusSigProc::save_data(ITrace::vector& itraces, IFrame::trace_list_t& in if (badit != bad.end()) { for (auto bad : badit->second) { for (int itick = bad.first; itick < bad.second; ++itick) { + qloss += charge.at(itick); charge.at(itick) = 0.0; } } @@ -454,15 +495,17 @@ void OmnibusSigProc::save_data(ITrace::vector& itraces, IFrame::trace_list_t& in // debug if (indices.empty()) { - log->debug("save_data plane index: {} empty", plane); + log->debug("call={} {} save plane index: {} empty", + m_count, loglabel, plane); } else { - const int nadded = indices.back() - indices.front() + 1; - log->debug("save_data plane index: {}, Qtot={} " - "added {} traces to total {} indices:[{},{}]", - plane, qtot, - nadded, indices.size(), indices.front(), indices.back()); + log->debug("call={} save plane index: {}, Qtot={} Qloss={}, " + "{} indices spanning [{},{}] \"{}\"", + m_count, plane, qtot, qloss, + indices.size(), indices.front(), indices.back(), + loglabel); } + check_data(plane, loglabel + " after save"); } // save ROI into the out frame @@ -711,14 +754,14 @@ void OmnibusSigProc::init_overall_response(IFrame::pointer frame) int tbinmin = *mme.first; int tbinmax = *mme.second; m_nticks = tbinmax - tbinmin; - log->debug("OmnibusSigProc: nticks={} tbinmin={} tbinmax={}", m_nticks, tbinmin, tbinmax); + log->debug("call={} init nticks={} tbinmin={} tbinmax={}", m_count, m_nticks, tbinmin, tbinmax); if (m_fft_flag == 0) { m_fft_nticks = m_nticks; } else { m_fft_nticks = fft_best_length(m_nticks); - log->debug("OmnibusSigProc: enlarge window from {} to {}", m_nticks, m_fft_nticks); + log->debug("call={} init enlarge window from {} to {}", m_count, m_nticks, m_fft_nticks); } // @@ -740,7 +783,8 @@ void OmnibusSigProc::init_overall_response(IFrame::pointer frame) } else { m_fft_nwires[i] = fft_best_length(m_nwires[i] + fravg.planes[0].paths.size() - 1, 1); - log->debug("OmnibusSigProc: enlarge wire number in plane {} from {} to {}", i, m_nwires[i], + log->debug("call={} init enlarge wire number in plane {} from {} to {}", + m_count, i, m_nwires[i], m_fft_nwires[i]); } m_pad_nwires[i] = (m_fft_nwires[i] - m_nwires[i]) / 2; @@ -756,7 +800,7 @@ void OmnibusSigProc::init_overall_response(IFrame::pointer frame) // auto ewave = ce.generate(tbins); auto ewave = (*m_elecresponse).waveform_samples(tbins); Waveform::scale(ewave, m_inter_gain * m_ADC_mV * (-1)); - elec = Waveform::dft(ewave); + elec = Aux::fwd_r2c(m_dft, ewave); std::complex fine_period(fravg.period, 0); @@ -782,18 +826,27 @@ void OmnibusSigProc::init_overall_response(IFrame::pointer frame) for (int iplane = 0; iplane < 3; ++iplane) { auto arr = Response::as_array(fravg.planes[iplane], fine_nwires, fine_nticks); + + int nrows = 0; + int ncols = 0; + // do FFT for response ... - Array::array_xxc c_data = Array::dft_rc(arr, 0); - int nrows = c_data.rows(); - int ncols = c_data.cols(); + { + Array::array_xxc c_data = arr.cast(); + c_data = Aux::fwd(m_dft, c_data, 1); + + nrows = c_data.rows(); + ncols = c_data.cols(); - for (int irow = 0; irow < nrows; ++irow) { - for (int icol = 0; icol < ncols; ++icol) { - c_data(irow, icol) = c_data(irow, icol) * elec.at(icol) * fine_period; + for (int irow = 0; irow < nrows; ++irow) { + for (int icol = 0; icol < ncols; ++icol) { + c_data(irow, icol) = c_data(irow, icol) * elec.at(icol) * fine_period; + } } - } - arr = Array::idft_cr(c_data, 0); + c_data = Aux::inv(m_dft, c_data, 1); + arr = c_data.real(); + } // figure out how to do fine ... shift (good ...) int fine_time_shift = m_fine_time_offset / fravg.period; @@ -883,39 +936,28 @@ void OmnibusSigProc::decon_2D_init(int plane) { // data part ... // first round of FFT on time - m_c_data[plane] = Array::dft_rc(m_r_data[plane], 0); + m_c_data[plane] = Aux::fwd(m_dft, m_r_data[plane].cast(), 1); // now apply the ch-by-ch response ... if (!m_per_chan_resp.empty()) { - log->debug("OmnibusSigProc: applying ch-by-ch electronics response correction"); + log->debug("call={} applying ch-by-ch electronics response correction", m_count); auto cr = Factory::find_tn(m_per_chan_resp); auto cr_bins = cr->channel_response_binning(); if (cr_bins.binsize() != m_period) { - log->critical("OmnibusSigProc::decon_2D_init: channel response size mismatch"); + log->critical("call={} decon_2D_init: channel response size mismatch", m_count); THROW(ValueError() << errmsg{"OmnibusSigProc::decon_2D_init: channel response size mismatch"}); } - // starndard electronics response ... - // WireCell::Binning tbins(m_nticks, 0-m_period/2., m_nticks*m_period-m_period/2.); - // Response::ColdElec ce(m_gain, m_shaping_time); - // temporary hack ... - // float scaling = 1./(1e-9*0.5/1.13312); - // WireCell::Binning tbins(m_nticks, (-5-0.5)*m_period, (m_nticks-5-0.5)*m_period-m_period); - // Response::ColdElec ce(m_gain*scaling, m_shaping_time); - //// this is moved into wirecell.sigproc.main production of - //// microboone-channel-responses-v1.json.bz2 WireCell::Binning tbins(m_fft_nticks, cr_bins.min(), cr_bins.min() + m_fft_nticks * m_period); - // Response::ColdElec ce(m_gain, m_shaping_time); - // const auto ewave = ce.generate(tbins); auto ewave = (*m_elecresponse).waveform_samples(tbins); - const WireCell::Waveform::compseq_t elec = Waveform::dft(ewave); + const WireCell::Waveform::compseq_t elec = Aux::fwd_r2c(m_dft, ewave); for (auto och : m_channel_range[plane]) { // const auto& ch_resp = cr->channel_response(och.ident); Waveform::realseq_t tch_resp = cr->channel_response(och.ident); tch_resp.resize(m_fft_nticks, 0); - const WireCell::Waveform::compseq_t ch_elec = Waveform::dft(tch_resp); + const WireCell::Waveform::compseq_t ch_elec = Aux::fwd_r2c(m_dft, tch_resp); const int irow = och.wire + m_pad_nwires[plane]; for (int icol = 0; icol != m_c_data[plane].cols(); icol++) { @@ -931,7 +973,7 @@ void OmnibusSigProc::decon_2D_init(int plane) } // second round of FFT on wire - m_c_data[plane] = Array::dft_cc(m_c_data[plane], 1); + m_c_data[plane] = Aux::fwd(m_dft, m_c_data[plane], 0); // response part ... Array::array_xxf r_resp = Array::array_xxf::Zero(m_r_data[plane].rows(), m_fft_nticks); @@ -942,9 +984,9 @@ void OmnibusSigProc::decon_2D_init(int plane) } // do first round FFT on the resposne on time - Array::array_xxc c_resp = Array::dft_rc(r_resp, 0); // do second round FFT on the response on wire - c_resp = Array::dft_cc(c_resp, 1); + Array::array_xxc c_resp = r_resp.cast(); + c_resp = Aux::fwd(m_dft, c_resp); // make ratio to the response and apply wire filter m_c_data[plane] = m_c_data[plane] / c_resp; @@ -968,10 +1010,8 @@ void OmnibusSigProc::decon_2D_init(int plane) } // do the first round of inverse FFT on wire - m_c_data[plane] = Array::idft_cc(m_c_data[plane], 1); - // do the second round of inverse FFT on time - m_r_data[plane] = Array::idft_cr(m_c_data[plane], 0); + m_r_data[plane] = Aux::inv(m_dft, m_c_data[plane]).real(); // do the shift in wire const int nrows = m_r_data[plane].rows(); @@ -995,7 +1035,8 @@ void OmnibusSigProc::decon_2D_init(int plane) m_r_data[plane].block(0, 0, nrows, time_shift) = arr2; m_r_data[plane].block(0, time_shift, nrows, ncols - time_shift) = arr1; } - m_c_data[plane] = Array::dft_rc(m_r_data[plane], 0); + m_c_data[plane] = Aux::fwd(m_dft, m_r_data[plane].cast(), 1); + } void OmnibusSigProc::decon_2D_ROI_refine(int plane) @@ -1016,7 +1057,8 @@ void OmnibusSigProc::decon_2D_ROI_refine(int plane) } // do the second round of inverse FFT on wire - Array::array_xxf tm_r_data = Array::idft_cr(c_data_afterfilter, 0); + Array::array_xxf tm_r_data = Aux::inv(m_dft, c_data_afterfilter, 1).real(); + m_r_data[plane] = tm_r_data.block(m_pad_nwires[plane], 0, m_nwires[plane], m_nticks); restore_baseline(m_r_data[plane]); } @@ -1057,7 +1099,8 @@ void OmnibusSigProc::decon_2D_tightROI(int plane) } // do the second round of inverse FFT on wire - Array::array_xxf tm_r_data = Array::idft_cr(c_data_afterfilter, 0); + Array::array_xxf tm_r_data = Aux::inv(m_dft, c_data_afterfilter, 1).real(); + m_r_data[plane] = tm_r_data.block(m_pad_nwires[plane], 0, m_nwires[plane], m_nticks); restore_baseline(m_r_data[plane]); } @@ -1099,7 +1142,8 @@ void OmnibusSigProc::decon_2D_tighterROI(int plane) } // do the second round of inverse FFT on wire - Array::array_xxf tm_r_data = Array::idft_cr(c_data_afterfilter, 0); + Array::array_xxf tm_r_data = Aux::inv(m_dft, c_data_afterfilter, 1).real(); + m_r_data[plane] = tm_r_data.block(m_pad_nwires[plane], 0, m_nwires[plane], m_nticks); restore_baseline(m_r_data[plane]); } @@ -1176,7 +1220,8 @@ void OmnibusSigProc::decon_2D_looseROI(int plane) } // do the second round of inverse FFT on wire - Array::array_xxf tm_r_data = Array::idft_cr(c_data_afterfilter, 0); + Array::array_xxf tm_r_data = Aux::inv(m_dft, c_data_afterfilter, 1).real(); + m_r_data[plane] = tm_r_data.block(m_pad_nwires[plane], 0, m_nwires[plane], m_nticks); restore_baseline(m_r_data[plane]); } @@ -1221,7 +1266,8 @@ void OmnibusSigProc::decon_2D_looseROI_debug_mode(int plane) } // do the second round of inverse FFT on wire - Array::array_xxf tm_r_data = Array::idft_cr(c_data_afterfilter, 0); + Array::array_xxf tm_r_data = Aux::inv(m_dft, c_data_afterfilter, 1).real(); + m_r_data[plane] = tm_r_data.block(m_pad_nwires[plane], 0, m_nwires[plane], m_nticks); restore_baseline(m_r_data[plane]); } @@ -1282,7 +1328,7 @@ void OmnibusSigProc::decon_2D_hits(int plane) } // do the second round of inverse FFT on wire - Array::array_xxf tm_r_data = Array::idft_cr(c_data_afterfilter, 0); + Array::array_xxf tm_r_data = Aux::inv(m_dft, c_data_afterfilter, 1).real(); m_r_data[plane] = tm_r_data.block(m_pad_nwires[plane], 0, m_nwires[plane], m_nticks); if (plane == 2) { restore_baseline(m_r_data[plane]); @@ -1315,7 +1361,7 @@ void OmnibusSigProc::decon_2D_charge(int plane) } // do the second round of inverse FFT on wire - Array::array_xxf tm_r_data = Array::idft_cr(c_data_afterfilter, 0); + Array::array_xxf tm_r_data = Aux::inv(m_dft, c_data_afterfilter, 1).real(); m_r_data[plane] = tm_r_data.block(m_pad_nwires[plane], 0, m_nwires[plane], m_nticks); if (plane == 2) { restore_baseline(m_r_data[plane]); @@ -1357,7 +1403,7 @@ bool OmnibusSigProc::operator()(const input_pointer& in, output_pointer& out) ITrace::vector* itraces = new ITrace::vector; // will become shared_ptr. IFrame::trace_summary_t thresholds; - IFrame::trace_list_t wiener_traces, gauss_traces, perframe_traces[3]; + IFrame::trace_list_t wiener_traces, gauss_traces; // here are some trace lists for debug mode IFrame::trace_list_t tight_lf_traces, loose_lf_traces, cleanup_roi_traces, break_roi_loop1_traces, break_roi_loop2_traces, shrink_roi_traces, extend_roi_traces; @@ -1388,6 +1434,7 @@ bool OmnibusSigProc::operator()(const input_pointer& in, output_pointer& out) load_data(in, iplane); // load into a large matrix // initial decon ... decon_2D_init(iplane); // decon in large matrix + check_data(iplane, "after 2D init"); // Form tight ROIs if (iplane != 2) { // induction wire planes @@ -1403,17 +1450,20 @@ bool OmnibusSigProc::operator()(const input_pointer& in, output_pointer& out) decon_2D_tightROI(iplane); roi_form.find_ROI_by_decon_itself(iplane, m_r_data[iplane]); } + check_data(iplane, "after 2D tight ROI"); // [wgu] save decon result after tight LF std::vector dummy; - if (m_use_roi_debug_mode and m_use_roi_refinement) save_data(*itraces, tight_lf_traces, iplane, perwire_rmses, dummy); + if (m_use_roi_debug_mode and m_use_roi_refinement) { + save_data(*itraces, tight_lf_traces, iplane, perwire_rmses, dummy, "tight_lf"); + } // Form loose ROIs if (iplane != 2) { // [wgu] save decon result after loose LF if (m_use_roi_debug_mode) { decon_2D_looseROI_debug_mode(iplane); - save_data(*itraces, loose_lf_traces, iplane, perwire_rmses, dummy); + save_data(*itraces, loose_lf_traces, iplane, perwire_rmses, dummy, "loose_lf"); } if (m_use_roi_refinement) { @@ -1425,7 +1475,11 @@ bool OmnibusSigProc::operator()(const input_pointer& in, output_pointer& out) // [wgu] collection plane does not need loose LF // but save something to be consistent - if (m_use_roi_debug_mode and iplane == 2) save_data(*itraces, loose_lf_traces, iplane, perwire_rmses, dummy); + if (m_use_roi_debug_mode and iplane == 2) { + save_data(*itraces, loose_lf_traces, iplane, perwire_rmses, dummy, "loose_lf"); + } + + check_data(iplane, "after 2D ROI refine"); // Refine ROIs if (m_use_roi_refinement) roi_refine.load_data(iplane, m_r_data[iplane], roi_form); @@ -1436,81 +1490,103 @@ bool OmnibusSigProc::operator()(const input_pointer& in, output_pointer& out) } if (m_use_roi_refinement) { - for (int iplane = 0; iplane != 3; ++iplane) { - auto it = std::find(m_process_planes.begin(), m_process_planes.end(), iplane); - if (it == m_process_planes.end()) continue; + for (int iplane = 0; iplane != 3; ++iplane) { + auto it = std::find(m_process_planes.begin(), m_process_planes.end(), iplane); + if (it == m_process_planes.end()) continue; - // roi_refine.refine_data(iplane, roi_form); + // roi_refine.refine_data(iplane, roi_form); - roi_refine.CleanUpROIs(iplane); - roi_refine.generate_merge_ROIs(iplane); + roi_refine.CleanUpROIs(iplane); + roi_refine.generate_merge_ROIs(iplane); - if (m_use_roi_debug_mode) { - save_roi(*itraces, cleanup_roi_traces, iplane, roi_refine.get_rois_by_plane(iplane)); - } + if (m_use_roi_debug_mode) { + save_roi(*itraces, cleanup_roi_traces, iplane, roi_refine.get_rois_by_plane(iplane)); + } - if (m_use_multi_plane_protection) { - roi_refine.MultiPlaneProtection(iplane, m_anode, m_roi_ch_ch_ident, roi_form, 1000, m_anode->ident() % 2); - save_mproi(*itraces, mp3_roi_traces, iplane, roi_refine.get_mp3_rois()); - roi_refine.MultiPlaneROI(iplane, m_anode, m_roi_ch_ch_ident, roi_form, 1000, m_anode->ident() % 2); - save_mproi(*itraces, mp2_roi_traces, iplane, roi_refine.get_mp2_rois()); + if (m_use_multi_plane_protection) { + roi_refine.MultiPlaneProtection(iplane, m_anode, m_roi_ch_ch_ident, roi_form, 1000, m_anode->ident() % 2); + save_mproi(*itraces, mp3_roi_traces, iplane, roi_refine.get_mp3_rois()); + roi_refine.MultiPlaneROI(iplane, m_anode, m_roi_ch_ch_ident, roi_form, 1000, m_anode->ident() % 2); + save_mproi(*itraces, mp2_roi_traces, iplane, roi_refine.get_mp2_rois()); + } } - } - for (int iplane = 0; iplane != 3; ++iplane) { - auto it = std::find(m_process_planes.begin(), m_process_planes.end(), iplane); - if (it == m_process_planes.end()) continue; + for (int iplane = 0; iplane != 3; ++iplane) { + auto it = std::find(m_process_planes.begin(), m_process_planes.end(), iplane); + if (it == m_process_planes.end()) continue; - const std::vector& perwire_rmses = *perplane_thresholds[iplane]; + const std::vector& perwire_rmses = *perplane_thresholds[iplane]; + + for (int qx = 0; qx != m_r_break_roi_loop; qx++) { + roi_refine.BreakROIs(iplane, roi_form); + roi_refine.CheckROIs(iplane, roi_form); + roi_refine.CleanUpROIs(iplane); + if (m_use_roi_debug_mode) { + if (qx == 0) { + save_roi(*itraces, break_roi_loop1_traces, iplane, roi_refine.get_rois_by_plane(iplane)); + } + if (qx == 1) { + save_roi(*itraces, break_roi_loop2_traces, iplane, roi_refine.get_rois_by_plane(iplane)); + } + } + } - for (int qx = 0; qx != m_r_break_roi_loop; qx++) { - roi_refine.BreakROIs(iplane, roi_form); + roi_refine.ShrinkROIs(iplane, roi_form); + check_data(iplane, "after roi refine shrink"); roi_refine.CheckROIs(iplane, roi_form); + check_data(iplane, "after roi refine check"); roi_refine.CleanUpROIs(iplane); if (m_use_roi_debug_mode) { - if (qx == 0) save_roi(*itraces, break_roi_loop1_traces, iplane, roi_refine.get_rois_by_plane(iplane)); - if (qx == 1) save_roi(*itraces, break_roi_loop2_traces, iplane, roi_refine.get_rois_by_plane(iplane)); + save_roi(*itraces, shrink_roi_traces, iplane, roi_refine.get_rois_by_plane(iplane)); } - } - roi_refine.ShrinkROIs(iplane, roi_form); - roi_refine.CheckROIs(iplane, roi_form); - roi_refine.CleanUpROIs(iplane); - if (m_use_roi_debug_mode) { - save_roi(*itraces, shrink_roi_traces, iplane, roi_refine.get_rois_by_plane(iplane)); - } + if (iplane == 2) { + roi_refine.CleanUpCollectionROIs(); + } + else { + roi_refine.CleanUpInductionROIs(iplane); + } + check_data(iplane, "after roi refine cleanup"); - if (iplane == 2) { - roi_refine.CleanUpCollectionROIs(); - } - else { - roi_refine.CleanUpInductionROIs(iplane); - } - roi_refine.ExtendROIs(iplane); + roi_refine.ExtendROIs(iplane); + check_data(iplane, "after roi refine extend"); - if (m_use_roi_debug_mode) { - save_ext_roi(*itraces, extend_roi_traces, iplane, roi_refine.get_rois_by_plane(iplane)); - } + if (m_use_roi_debug_mode) { + save_ext_roi(*itraces, extend_roi_traces, iplane, roi_refine.get_rois_by_plane(iplane)); + } - // merge results ... - decon_2D_hits(iplane); - roi_refine.apply_roi(iplane, m_r_data[iplane]); - // roi_form.apply_roi(iplane, m_r_data[plane],1); - save_data(*itraces, perframe_traces[iplane], iplane, perwire_rmses, thresholds); - wiener_traces.insert(wiener_traces.end(), perframe_traces[iplane].begin(), perframe_traces[iplane].end()); + // merge results ... + decon_2D_hits(iplane); + check_data(iplane, "after decon 2D hits"); + roi_refine.apply_roi(iplane, m_r_data[iplane]); + check_data(iplane, "after roi refine apply"); + // roi_form.apply_roi(iplane, m_r_data[plane],1); + { + // We only use an intermediate index list here to give + // some clarity to log msg about range added + IFrame::trace_list_t perframe; + save_data(*itraces, perframe, iplane, perwire_rmses, thresholds, "wiener"); + wiener_traces.insert(wiener_traces.end(), perframe.begin(), perframe.end()); + } - decon_2D_charge(iplane); - std::vector dummy_thresholds; - if (m_use_roi_debug_mode) { - save_data(*itraces, decon_charge_traces, iplane, perwire_rmses, thresholds); - } - roi_refine.apply_roi(iplane, m_r_data[iplane]); - // roi_form.apply_roi(iplane, m_r_data[plane],1); - save_data(*itraces, gauss_traces, iplane, perwire_rmses, dummy_thresholds); + decon_2D_charge(iplane); + std::vector dummy_thresholds; + if (m_use_roi_debug_mode) { + save_data(*itraces, decon_charge_traces, iplane, perwire_rmses, thresholds, "decon"); + } + roi_refine.apply_roi(iplane, m_r_data[iplane]); + // roi_form.apply_roi(iplane, m_r_data[plane],1); + { + // We only use an intermediate index list here to give + // some clarity to log msg about range added + IFrame::trace_list_t perframe; + save_data(*itraces, perframe, iplane, perwire_rmses, dummy_thresholds, "gauss"); + gauss_traces.insert(gauss_traces.end(), perframe.begin(), perframe.end()); + } - m_c_data[iplane].resize(0, 0); // clear memory - m_r_data[iplane].resize(0, 0); // clear memory - } + m_c_data[iplane].resize(0, 0); // clear memory + m_r_data[iplane].resize(0, 0); // clear memory + } } SimpleFrame* sframe = new SimpleFrame(in->ident(), in->time(), ITrace::shared_vector(itraces), in->tick(), m_cmm); @@ -1571,5 +1647,5 @@ bool OmnibusSigProc::operator()(const input_pointer& in, output_pointer& out) // Local Variables: // mode: c++ -// c-basic-offset: 2 +// c-basic-offset: 4 // End: diff --git a/sigproc/src/Protodune.cxx b/sigproc/src/Protodune.cxx index 7e69a8e2a..bb3593c18 100644 --- a/sigproc/src/Protodune.cxx +++ b/sigproc/src/Protodune.cxx @@ -13,6 +13,8 @@ #include "WireCellSigProc/Protodune.h" #include "WireCellSigProc/Derivations.h" +#include "WireCellAux/DftTools.h" + #include "WireCellUtil/NamedFactory.h" #include @@ -166,25 +168,6 @@ int LedgeIdentify1(WireCell::Waveform::realseq_t& signal, double baseline, int L } } - // // // find the sharp start edge - // if(ledge == 1&&StartOfLastLedgeCandidate>30){ - // // int edge = 0; - // // int i = StartOfLastLedgeCandidate/UNIT-1; - // // if(averaged.at(i)>averaged.at(i-1)&&averaged.at(i-1)>averaged.at(i-2)){ // find a edge - // // edge = 1; - // // } - // // if(edge == 0) ledge = 0; // if no edge, this is not ledge - // // if((averaged.at(i)-averaged.at(i-2)<10*UNIT)&&(averaged.at(i)-averaged.at(i-3)<10*UNIT)) // slope cut - // // ledge = 0; - // // if(averaged.at(StartOfLastLedgeCandidate/UNIT)-baseline*UNIT>150*UNIT) ledge = 0; // ledge is close to the - // baseline - - // // if(signal.at(tempLedgeEnd) - baseline > 100) ledge=0; // [wgu] ledge end is close to the baseline - // if(averaged.at(tempLedgeEnd/UNIT)-baseline*UNIT>5.*UNIT) ledge = 0; - // // cout << "averaged.at(StartOfLastLedgeCandidate/UNIT) - baseline*UNIT = " << - // averaged.at(StartOfLastLedgeCandidate/UNIT)-baseline*UNIT << std::endl; - // } - if (ledge == 1) { // ledge is close to the baseline if (averaged.at(tempLedgeEnd / UNIT) - baseline * UNIT > 5. * UNIT) ledge = 0; @@ -273,14 +256,6 @@ bool LedgeIdentify(WireCell::Waveform::realseq_t& signal /*TH1F* h2*/, double ba } // find the sharp start edge if (ledge && LedgeStart > 30) { - // int edge = 0; - // int i = LedgeStart/UNIT-1; - // if(averaged.at(i)>averaged.at(i-1)&&averaged.at(i-1)>averaged.at(i-2)){ // find a edge - // edge = 1; - // } - // if(edge == 0) ledge = false; // if no edge, this is not ledge - // if((averaged.at(i)-averaged.at(i-2)<10*UNIT)&&(averaged.at(i)-averaged.at(i-3)<10*UNIT)) // slope cut - // ledge = false; if (averaged.at(LedgeStart / UNIT) - baseline * UNIT > 150 * UNIT) ledge = false; // ledge is close to the baseline } @@ -288,9 +263,6 @@ bool LedgeIdentify(WireCell::Waveform::realseq_t& signal /*TH1F* h2*/, double ba if (ledge && LedgeStart > 20) { double height = 0; if (LedgeStart < 5750) { // calculate the height of edge - // double tempHeight = h2 ->GetBinContent(LedgeStart+1+200) + h2 ->GetBinContent(LedgeStart+1+220) + h2 - // ->GetBinContent(LedgeStart+1+180) + h2 ->GetBinContent(LedgeStart+1+240); height = h2 - // ->GetBinContent(LedgeStart+1) - tempHeight/4; double tempHeight = signal.at(LedgeStart + 200) + signal.at(LedgeStart + 220) + signal.at(LedgeStart + 180) + signal.at(LedgeStart + 240); height = signal.at(LedgeStart) - tempHeight / 4; @@ -302,11 +274,6 @@ bool LedgeIdentify(WireCell::Waveform::realseq_t& signal /*TH1F* h2*/, double ba if (height < 0) height = 80; // norminal value if (height > 30 && LedgeStart < 5900) { // test the decay with a relatively large height double height50 = 0, height100 = 0; - // height50 = h2 ->GetBinContent(LedgeStart+51); - // height100 = h2 ->GetBinContent(LedgeStart+101); - // double height50Pre = h2 ->GetBinContent(LedgeStart+1)- height*(1-exp(-50/100.)); // minimum 100 ticks - // decay time double height100Pre = h2 ->GetBinContent(LedgeStart+1) - height*(1-exp(-100./100)); // - // minimum 100 ticks decay time height50 = signal.at(LedgeStart + 50); height100 = signal.at(LedgeStart + 100); @@ -350,44 +317,9 @@ bool LedgeIdentify(WireCell::Waveform::realseq_t& signal /*TH1F* h2*/, double ba } if (LedgeEnd == 0) LedgeEnd = 6000; } - // done, release the memory - // vector(averaged).swap(averaged); // is it necessary? return ledge; } -// adapted from WCP -// int judgePlateau(int channel, TH1F* h2,double baseline, double & PlateauStart, double & PlateauStartEnd){ -// int continueN = 0; -// int threshold = 200; -// int maximumF = 50; -// int maxBin = h2->GetMaximumBin(); -// for(int i=maxBin+10;i<5880&&iGetBinContent(j+1); -// if(binCh2->GetMaximum()-500) { -// plateau = 0; -// break; -// } -// if(binC>max) max = binC; -// if(binCGetBinContent(k+1) >& st_ranges) { const int nsiglen = signal.size(); @@ -538,7 +472,7 @@ bool Protodune::FftShiftSticky(WireCell::Waveform::realseq_t& signal, double tof } // dft shift for "even" - auto tran_even = WireCell::Waveform::dft(signal_even); + auto tran_even = Aux::fwd_r2c(dft, signal_even); double f0 = 1. / nsublen; const double PI = std::atan(1.0) * 4; for (size_t i = 0; i < tran_even.size(); i++) { @@ -550,12 +484,10 @@ bool Protodune::FftShiftSticky(WireCell::Waveform::realseq_t& signal, double tof tran_even.at(i) = z * std::exp(z1); } // inverse FFT - auto signal_even_fc = WireCell::Waveform::idft(tran_even); - // float scale = 1./tran_even.size(); - // WireCell::Waveform::scale(signal_even_fc, 1./nsublen); + auto signal_even_fc = Aux::inv_c2r(dft, tran_even); // similar to "odd" - auto tran_odd = WireCell::Waveform::dft(signal_odd); + auto tran_odd = Aux::fwd_r2c(dft, signal_odd); f0 = 1. / nsublen2; for (size_t i = 0; i < tran_odd.size(); i++) { double fi = i * f0; @@ -566,7 +498,8 @@ bool Protodune::FftShiftSticky(WireCell::Waveform::realseq_t& signal, double tof tran_odd.at(i) = z * std::exp(z1); } // - auto signal_odd_fc = WireCell::Waveform::idft(tran_odd); + auto signal_odd_fc = Aux::inv_c2r(dft, tran_odd); + // float scale = 1./tran_odd.size(); // WireCell::Waveform::scale(signal_odd_fc, 1./nsublen2); @@ -593,10 +526,11 @@ bool Protodune::FftShiftSticky(WireCell::Waveform::realseq_t& signal, double tof return true; } -bool Protodune::FftScaling(WireCell::Waveform::realseq_t& signal, int nsamples) +bool Protodune::FftScaling(const IDFT::pointer& dft, + WireCell::Waveform::realseq_t& signal, int nsamples) { const int nsiglen = signal.size(); - auto tran = WireCell::Waveform::dft(signal); + auto tran = Aux::fwd_r2c(dft, signal); tran.resize(nsamples); if (nsiglen % 2 == 0) { // ref test_zero_padding.cxx std::rotate(tran.begin() + nsiglen / 2, tran.begin() + nsiglen, tran.end()); @@ -605,7 +539,8 @@ bool Protodune::FftScaling(WireCell::Waveform::realseq_t& signal, int nsamples) std::rotate(tran.begin() + (nsiglen + 1) / 2, tran.begin() + nsiglen, tran.end()); } // inverse FFT - auto signal_fc = WireCell::Waveform::idft(tran); + auto signal_fc = Aux::inv_c2r(dft, tran); + WireCell::Waveform::scale(signal_fc, nsamples / nsiglen); signal = signal_fc; @@ -616,31 +551,6 @@ bool Protodune::FftScaling(WireCell::Waveform::realseq_t& signal, int nsamples) * Classes */ -/* - * Configuration base class used for a couple filters - */ -Protodune::ConfigFilterBase::ConfigFilterBase(const std::string& anode, const std::string& noisedb) - : m_anode_tn(anode) - , m_noisedb_tn(noisedb) -{ -} -Protodune::ConfigFilterBase::~ConfigFilterBase() {} -void Protodune::ConfigFilterBase::configure(const WireCell::Configuration& cfg) -{ - m_anode_tn = get(cfg, "anode", m_anode_tn); - m_anode = Factory::find_tn(m_anode_tn); - m_noisedb_tn = get(cfg, "noisedb", m_noisedb_tn); - m_noisedb = Factory::find_tn(m_noisedb_tn); - // std::cerr << "ConfigFilterBase: \n" << cfg << "\n"; -} -WireCell::Configuration Protodune::ConfigFilterBase::default_configuration() const -{ - Configuration cfg; - cfg["anode"] = m_anode_tn; - cfg["noisedb"] = m_noisedb_tn; - return cfg; -} - Protodune::StickyCodeMitig::StickyCodeMitig(const std::string& anode, const std::string& noisedb, float stky_sig_like_val, float stky_sig_like_rms, int stky_max_len) : m_anode_tn(anode) @@ -664,6 +574,9 @@ void Protodune::StickyCodeMitig::configure(const WireCell::Configuration& cfg) m_noisedb_tn = get(cfg, "noisedb", m_noisedb_tn); m_noisedb = Factory::find_tn(m_noisedb_tn); + std::string dft_tn = get(cfg, "dft", "FftwDFT"); + m_dft = Factory::find_tn(dft_tn); + m_extra_stky.clear(); auto jext = cfg["extra_stky"]; if (!jext.isNull()) { @@ -701,6 +614,7 @@ WireCell::Configuration Protodune::StickyCodeMitig::default_configuration() cons cfg["stky_sig_like_val"] = m_stky_sig_like_val; cfg["stky_sig_like_rms"] = m_stky_sig_like_rms; cfg["stky_max_len"] = m_stky_max_len; + cfg["dft"] = "FftwDFT"; // type-name for the DFT to use return cfg; } @@ -745,11 +659,8 @@ WireCell::Waveform::ChannelMaskMap Protodune::StickyCodeMitig::apply(int ch, sig } // std::cerr << "[wgu] ch: " << ch << " long_stkylen: " << long_stkylen << std::endl; - // auto signal_lc = signal; // copy, need to keep original signal LinearInterpSticky(signal, sticky_rng_list, m_stky_sig_like_val, m_stky_sig_like_rms); - FftInterpSticky(signal, sticky_rng_list); - // FftShiftSticky(signal_lc, 0.5, st_ranges); // alternative approach, shift by 0.5 tick - // signal = signal_lc; + FftInterpSticky(m_dft, signal, sticky_rng_list); // Now calculate the baseline ... std::pair temp = WireCell::Waveform::mean_rms(signal); @@ -794,8 +705,10 @@ WireCell::Waveform::ChannelMaskMap Protodune::StickyCodeMitig::apply(channel_sig return WireCell::Waveform::ChannelMaskMap(); } + Protodune::OneChannelNoise::OneChannelNoise(const std::string& anode, const std::string& noisedb) - : ConfigFilterBase(anode, noisedb) + : m_anode_tn(anode) + , m_noisedb_tn(noisedb) , m_check_partial() // fixme, here too. , m_resmp() { @@ -806,13 +719,12 @@ void Protodune::OneChannelNoise::configure(const WireCell::Configuration& cfg) { m_anode_tn = get(cfg, "anode", m_anode_tn); m_anode = Factory::find_tn(m_anode_tn); - if (!m_anode) { - THROW(KeyError() << errmsg{"failed to get IAnodePlane: " + m_anode_tn}); - } - m_noisedb_tn = get(cfg, "noisedb", m_noisedb_tn); m_noisedb = Factory::find_tn(m_noisedb_tn); + std::string dft_tn = get(cfg, "dft", "FftwDFT"); + m_dft = Factory::find_tn(dft_tn); + m_resmp.clear(); auto jext = cfg["resmp"]; if (!jext.isNull()) { @@ -830,6 +742,7 @@ WireCell::Configuration Protodune::OneChannelNoise::default_configuration() cons Configuration cfg; cfg["anode"] = m_anode_tn; cfg["noisedb"] = m_noisedb_tn; + cfg["dft"] = "FftwDFT"; // type-name for the DFT to use return cfg; } @@ -846,19 +759,12 @@ WireCell::Waveform::ChannelMaskMap Protodune::OneChannelNoise::apply(int ch, sig int smpin = m_resmp.at(ch); int smpout = signal.size(); signal.resize(smpin); - FftScaling(signal, smpout); + FftScaling(m_dft, signal, smpout); // std::cerr << "[wgu] ch: " << ch << " smpin: " << smpin << " smpout: " << smpout << std::endl; } - // if( (ch>=2128 && ch<=2175) // W plane - // || (ch>=1520 && ch<=1559) // V plane - // || (ch>=440 && ch<=479) // U plane - // ){ - // signal.resize(5996); - // FftScaling(signal, 6000); - // } // correct rc undershoot - auto spectrum = WireCell::Waveform::dft(signal); + auto spectrum = Aux::fwd_r2c(m_dft, signal); bool is_partial = m_check_partial(spectrum); // Xin's "IS_RC()" if (!is_partial) { @@ -876,12 +782,6 @@ WireCell::Waveform::ChannelMaskMap Protodune::OneChannelNoise::apply(int ch, sig Microboone::RawAdapativeBaselineAlg(mag); // subtract "linear" background in spectrum auto const& spec = m_noisedb->noise(ch); - // std::cout << "[wgu] " << spec.at(10).real() << std::endl; - // std::cout << "[wgu] " << spec.at(148).real() << std::endl; - // std::cout << "[wgu] " << spec.at(149).real() << std::endl; - // std::cout << "[wgu] " << spec.at(160).real() << std::endl; - // std::cout << "[wgu] " << spec.at(161).real() << std::endl; - // WireCell::Waveform::scale(spectrum, spec); // spec -> freqBins; std::vector > freqBins; @@ -916,12 +816,6 @@ WireCell::Waveform::ChannelMaskMap Protodune::OneChannelNoise::apply(int ch, sig int nslice = iend - istart; // std::cout << "hibin: " << iend << " lobin: " << istart << std::endl; - // } - - // for(int i=0; i<57; i++){ // 150 - 3000th freq bin - // int nslice = 50; - // int istart = 150 + nslice*i; - // int iend = istart + nslice; // std::cerr << istart << " " << iend << std::endl; WireCell::Waveform::realseq_t mag_slice(nslice); // slice of magnitude spectrum std::copy(mag.begin() + istart, mag.begin() + iend, mag_slice.begin()); @@ -931,9 +825,7 @@ WireCell::Waveform::ChannelMaskMap Protodune::OneChannelNoise::apply(int ch, sig if (istart > 1050) { // if(i>17){ cut = stat.first + 3 * stat.second; } - // if(stat.second>1300){ - // cut = stat.first + stat.second; - // } + for (int j = istart; j < iend; j++) { float content = mag.at(j); if (content > cut) { @@ -946,28 +838,6 @@ WireCell::Waveform::ChannelMaskMap Protodune::OneChannelNoise::apply(int ch, sig n_harmonic++; } } - - // for(int j=0; j2000 && content>5.*stat.second){ - // int tbin = istart + j; - // spectrum.at(tbin).real(0); - // spectrum.at(tbin).imag(0); - // spectrum.at(6000+1-tbin).real(0); // FIXME: assuming 6000 ticks - // spectrum.at(6000+1-tbin).imag(0); - // // std::cerr << "[wgu] chan: " << ch << " , freq tick: " << tbin << " , amp: " << content << - // std::endl; - // } - // } - // else if(content>250 && content>10.*stat.second){ - // spectrum.at(j).real(0); - // spectrum.at(j).imag(0); - // spectrum.at(6000+1-j).real(0); // FIXME: assuming 6000 ticks - // spectrum.at(6000+1-j).imag(0); - // } - // } } } @@ -981,7 +851,7 @@ WireCell::Waveform::ChannelMaskMap Protodune::OneChannelNoise::apply(int ch, sig // remove the DC component spectrum.front() = 0; - signal = WireCell::Waveform::idft(spectrum); + signal = Aux::inv_c2r(m_dft, spectrum); // Now calculate the baseline ... std::pair temp = WireCell::Waveform::mean_rms(signal); diff --git a/sigproc/src/SimpleChannelNoiseDB.cxx b/sigproc/src/SimpleChannelNoiseDB.cxx index 97c4240db..cf4b3cdea 100644 --- a/sigproc/src/SimpleChannelNoiseDB.cxx +++ b/sigproc/src/SimpleChannelNoiseDB.cxx @@ -1,4 +1,5 @@ #include "WireCellSigProc/SimpleChannelNoiseDB.h" +#include "WireCellAux/DftTools.h" #include "WireCellUtil/Response.h" #include "WireCellUtil/Binning.h" @@ -31,6 +32,20 @@ SimpleChannelNoiseDB::SimpleChannelNoiseDB(double tick, int nsamples) } SimpleChannelNoiseDB::~SimpleChannelNoiseDB() {} +void SimpleChannelNoiseDB::configure(const WireCell::Configuration& cfg) +{ + std::string dft_tn = get(cfg, "dft", "FftwDFT"); + m_dft = Factory::find_tn(dft_tn); +} + +WireCell::Configuration SimpleChannelNoiseDB::default_configuration() const +{ + Configuration cfg; + cfg["dft"] = "FftwDFT"; // type-name for the DFT to use + return cfg; +} + + double SimpleChannelNoiseDB::nominal_baseline(int channel) const { const int ind = chind(channel); @@ -240,7 +255,7 @@ void SimpleChannelNoiseDB::set_rcrc_constant(const std::vector& channels, d // auto signal = rcres.generate(WireCell::Binning(m_nsamples, 0, m_nsamples*m_tick)); auto signal = rcres.generate(WireCell::Waveform::Domain(0, m_nsamples * m_tick), m_nsamples); - Waveform::compseq_t spectrum = Waveform::dft(signal); + Waveform::compseq_t spectrum = Aux::fwd_r2c(m_dft, signal); // std::cout << rcrc << " " << m_tick << " " << m_nsamples << " " << signal.front() << " " << signal.at(1) << " " << // signal.at(2) << std::endl; @@ -295,8 +310,9 @@ void SimpleChannelNoiseDB::set_gains_shapings(const std::vector& channels, auto to_sig = to_ce.generate(WireCell::Waveform::Domain(0, m_nsamples * m_tick), m_nsamples); auto from_sig = from_ce.generate(WireCell::Waveform::Domain(0, m_nsamples * m_tick), m_nsamples); - auto to_filt = Waveform::dft(to_sig); - auto from_filt = Waveform::dft(from_sig); + auto to_filt = Aux::fwd_r2c(m_dft, to_sig); + + auto from_filt = Aux::fwd_r2c(m_dft, from_sig); // auto from_filt_sum = Waveform::sum(from_filt); // auto to_filt_sum = Waveform::sum(to_filt); diff --git a/sigproc/test/test_partial.cxx b/sigproc/test/test_partial.cxx index 3756d297b..3fc539077 100644 --- a/sigproc/test/test_partial.cxx +++ b/sigproc/test/test_partial.cxx @@ -1,5 +1,9 @@ #include "WireCellSigProc/Diagnostics.h" -#include "WireCellUtil/Waveform.h" + +#include "WireCellAux/DftTools.h" +#include "WireCellUtil/NamedFactory.h" +#include "WireCellUtil/PluginManager.h" + #include "WireCellUtil/Testing.h" #include @@ -16,7 +20,11 @@ using namespace WireCell::SigProc; int main(int argc, char* argv[]) { - auto spectrum = Waveform::dft(horig); + PluginManager& pm = PluginManager::instance(); + pm.add("WireCellAux"); + auto idft = Factory::lookup_tn("FftwDFT"); + + auto spectrum = Aux::fwd_r2c(idft, horig); Diagnostics::Partial m_check_partial; bool is_partial = m_check_partial(spectrum); Assert(is_partial); diff --git a/sigproc/test/test_simple_channel_noisedb.cxx b/sigproc/test/test_simple_channel_noisedb.cxx index 59c9a9880..c03c5b852 100644 --- a/sigproc/test/test_simple_channel_noisedb.cxx +++ b/sigproc/test/test_simple_channel_noisedb.cxx @@ -1,6 +1,9 @@ #include "WireCellUtil/Testing.h" #include "WireCellSigProc/SimpleChannelNoiseDB.h" + +#include "WireCellUtil/PluginManager.h" +#include "WireCellUtil/NamedFactory.h" #include "WireCellUtil/Units.h" #include @@ -11,10 +14,15 @@ using namespace WireCell::SigProc; int main() { + PluginManager& pm = PluginManager::instance(); + pm.add("WireCellAux"); + Factory::lookup_tn("FftwDFT"); + const int nsamples = 5432; const double tick = 1.0 * units::ms; SimpleChannelNoiseDB cndb(tick, nsamples); + cndb.configure(cndb.default_configuration()); Assert(cndb.sample_time() == tick); Assert(cndb.nominal_baseline(0) == 0.0); diff --git a/sigproc/test/test_zero_padding.cxx b/sigproc/test/test_zero_padding.cxx index c8d82d775..2182cbc27 100644 --- a/sigproc/test/test_zero_padding.cxx +++ b/sigproc/test/test_zero_padding.cxx @@ -1,6 +1,10 @@ // Example for FFT resampling with zero-padding tricks #include "WireCellUtil/Waveform.h" +#include "WireCellAux/DftTools.h" +#include "WireCellUtil/NamedFactory.h" +#include "WireCellUtil/PluginManager.h" + #include // for FFT @@ -14,10 +18,14 @@ using namespace WireCell; int main() { + PluginManager& pm = PluginManager::instance(); + pm.add("WireCellAux"); + auto idft = Factory::lookup_tn("FftwDFT"); + std::vector a = {1, 2, 3, 2, 1}; // can be sampled to 10 ticks: 1 , 1.35279 , 2 , 2.69443 , 3 , 2.69443 , 2 , 1.35279 , 1 , 0.905573 - auto tran = WireCell::Waveform::dft(a); + auto tran = Aux::fwd_r2c(idft, a); std::cout << " tran = " << std::endl; std::cout << tran.size() << std::endl; @@ -48,7 +56,7 @@ int main() std::cout << std::endl; // inverse FFT - auto b = WireCell::Waveform::idft(tran); + auto b = Aux::inv_c2r(idft, tran); float scale = tran.size() / inSmps; // std::cout << " b = " << std::endl; diff --git a/util/inc/WireCellUtil/Array.h b/util/inc/WireCellUtil/Array.h index 00504775c..1bcd879f5 100644 --- a/util/inc/WireCellUtil/Array.h +++ b/util/inc/WireCellUtil/Array.h @@ -54,65 +54,6 @@ namespace WireCell { /// A complex, 2D array typedef Eigen::ArrayXXcf array_xxc; - /** Perform full, 2D discrete Fourier transform on a real 2D - array. - - The full 2D DFT first performs a 1D DFT (real->complex) on - each individual row and then a 1D DFT (complex->complex) - on each resulting column. - - const_shared_array_xxf arr = ...; - const_shared_array_xxc spec = dft(*arr); - - // ... - - const_shared_array_xxf arr2 = idft(*spec); - */ - array_xxc dft(const array_xxf& arr); - array_xxf idft(const array_xxc& arr); - - /** Partial, 1D DFT and inverse DFT along one dimension of an - * array. Each row is transformed if dim=0, each column if - * dim=1. The transfer is either real->complex (rc), - * complex->complex(cc) or complex->real(cr). - * - * The full 2D DFT should be used unless an intermediate - * filter is required as it will avoid producing some - * temporaries. - * - * Conceptually: - * - * auto xxc = dft(xxf); - * - * is equivalent to - * - * auto tmp = dft_rc(xxf, 0); - * auto xxc = dft_cc(tmp, 1); - * - * and: - * - * auto xxf = idft(xxc) - * - * is equivalent to: - * - * auto tmp = idft_cc(xxc, 1); - * auto xxf = idft_rc(tmp, 0); - */ - array_xxc dft_rc(const array_xxf& arr, int dim = 0); - array_xxc dft_cc(const array_xxc& arr, int dim = 1); - array_xxc idft_cc(const array_xxc& arr, int dim = 1); - array_xxf idft_cr(const array_xxc& arr, int dim = 0); - - /** Perform 2D deconvolution. - - This will perform a 2D forward DFT, do an - element-by-element multiplication of that - periodicity/frequency space matrix by the filter and then - perform an 2D inverse DFT. - - */ - array_xxf deconv(const array_xxf& arr, const array_xxc& filter); - /** downsample a 2D array along one axis by k * simple average of all numbers in a bin * e.g: MxN -> Mxfloor(N/k) diff --git a/util/inc/WireCellUtil/CLI11.hpp b/util/inc/WireCellUtil/CLI11.hpp new file mode 100644 index 000000000..dcb57c6c6 --- /dev/null +++ b/util/inc/WireCellUtil/CLI11.hpp @@ -0,0 +1,9066 @@ +// CLI11: Version 2.1.2 +// Originally designed by Henry Schreiner +// https://github.com/CLIUtils/CLI11 +// +// This is a standalone header file generated by MakeSingleHeader.py in CLI11/scripts +// from: v2.1.2 +// +// CLI11 2.1.2 Copyright (c) 2017-2021 University of Cincinnati, developed by Henry +// Schreiner under NSF AWARD 1414736. All rights reserved. +// +// Redistribution and use in source and binary forms of CLI11, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// 3. Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +// ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once + +// Standard combined includes: +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#define CLI11_VERSION_MAJOR 2 +#define CLI11_VERSION_MINOR 1 +#define CLI11_VERSION_PATCH 2 +#define CLI11_VERSION "2.1.2" + + + + +// The following version macro is very similar to the one in pybind11 +#if !(defined(_MSC_VER) && __cplusplus == 199711L) && !defined(__INTEL_COMPILER) +#if __cplusplus >= 201402L +#define CLI11_CPP14 +#if __cplusplus >= 201703L +#define CLI11_CPP17 +#if __cplusplus > 201703L +#define CLI11_CPP20 +#endif +#endif +#endif +#elif defined(_MSC_VER) && __cplusplus == 199711L +// MSVC sets _MSVC_LANG rather than __cplusplus (supposedly until the standard is fully implemented) +// Unless you use the /Zc:__cplusplus flag on Visual Studio 2017 15.7 Preview 3 or newer +#if _MSVC_LANG >= 201402L +#define CLI11_CPP14 +#if _MSVC_LANG > 201402L && _MSC_VER >= 1910 +#define CLI11_CPP17 +#if __MSVC_LANG > 201703L && _MSC_VER >= 1910 +#define CLI11_CPP20 +#endif +#endif +#endif +#endif + +#if defined(CLI11_CPP14) +#define CLI11_DEPRECATED(reason) [[deprecated(reason)]] +#elif defined(_MSC_VER) +#define CLI11_DEPRECATED(reason) __declspec(deprecated(reason)) +#else +#define CLI11_DEPRECATED(reason) __attribute__((deprecated(reason))) +#endif + + + + +// C standard library +// Only needed for existence checking +#if defined CLI11_CPP17 && defined __has_include && !defined CLI11_HAS_FILESYSTEM +#if __has_include() +// Filesystem cannot be used if targeting macOS < 10.15 +#if defined __MAC_OS_X_VERSION_MIN_REQUIRED && __MAC_OS_X_VERSION_MIN_REQUIRED < 101500 +#define CLI11_HAS_FILESYSTEM 0 +#else +#include +#if defined __cpp_lib_filesystem && __cpp_lib_filesystem >= 201703 +#if defined _GLIBCXX_RELEASE && _GLIBCXX_RELEASE >= 9 +#define CLI11_HAS_FILESYSTEM 1 +#elif defined(__GLIBCXX__) +// if we are using gcc and Version <9 default to no filesystem +#define CLI11_HAS_FILESYSTEM 0 +#else +#define CLI11_HAS_FILESYSTEM 1 +#endif +#else +#define CLI11_HAS_FILESYSTEM 0 +#endif +#endif +#endif +#endif + +#if defined CLI11_HAS_FILESYSTEM && CLI11_HAS_FILESYSTEM > 0 +#include // NOLINT(build/include) +#else +#include +#include +#endif + + + +namespace CLI { + + +/// Include the items in this namespace to get free conversion of enums to/from streams. +/// (This is available inside CLI as well, so CLI11 will use this without a using statement). +namespace enums { + +/// output streaming for enumerations +template ::value>::type> +std::ostream &operator<<(std::ostream &in, const T &item) { + // make sure this is out of the detail namespace otherwise it won't be found when needed + return in << static_cast::type>(item); +} + +} // namespace enums + +/// Export to CLI namespace +using enums::operator<<; + +namespace detail { +/// a constant defining an expected max vector size defined to be a big number that could be multiplied by 4 and not +/// produce overflow for some expected uses +constexpr int expected_max_vector_size{1 << 29}; +// Based on http://stackoverflow.com/questions/236129/split-a-string-in-c +/// Split a string by a delim +inline std::vector split(const std::string &s, char delim) { + std::vector elems; + // Check to see if empty string, give consistent result + if(s.empty()) { + elems.emplace_back(); + } else { + std::stringstream ss; + ss.str(s); + std::string item; + while(std::getline(ss, item, delim)) { + elems.push_back(item); + } + } + return elems; +} + +/// Simple function to join a string +template std::string join(const T &v, std::string delim = ",") { + std::ostringstream s; + auto beg = std::begin(v); + auto end = std::end(v); + if(beg != end) + s << *beg++; + while(beg != end) { + s << delim << *beg++; + } + return s.str(); +} + +/// Simple function to join a string from processed elements +template ::value>::type> +std::string join(const T &v, Callable func, std::string delim = ",") { + std::ostringstream s; + auto beg = std::begin(v); + auto end = std::end(v); + auto loc = s.tellp(); + while(beg != end) { + auto nloc = s.tellp(); + if(nloc > loc) { + s << delim; + loc = nloc; + } + s << func(*beg++); + } + return s.str(); +} + +/// Join a string in reverse order +template std::string rjoin(const T &v, std::string delim = ",") { + std::ostringstream s; + for(std::size_t start = 0; start < v.size(); start++) { + if(start > 0) + s << delim; + s << v[v.size() - start - 1]; + } + return s.str(); +} + +// Based roughly on http://stackoverflow.com/questions/25829143/c-trim-whitespace-from-a-string + +/// Trim whitespace from left of string +inline std::string <rim(std::string &str) { + auto it = std::find_if(str.begin(), str.end(), [](char ch) { return !std::isspace(ch, std::locale()); }); + str.erase(str.begin(), it); + return str; +} + +/// Trim anything from left of string +inline std::string <rim(std::string &str, const std::string &filter) { + auto it = std::find_if(str.begin(), str.end(), [&filter](char ch) { return filter.find(ch) == std::string::npos; }); + str.erase(str.begin(), it); + return str; +} + +/// Trim whitespace from right of string +inline std::string &rtrim(std::string &str) { + auto it = std::find_if(str.rbegin(), str.rend(), [](char ch) { return !std::isspace(ch, std::locale()); }); + str.erase(it.base(), str.end()); + return str; +} + +/// Trim anything from right of string +inline std::string &rtrim(std::string &str, const std::string &filter) { + auto it = + std::find_if(str.rbegin(), str.rend(), [&filter](char ch) { return filter.find(ch) == std::string::npos; }); + str.erase(it.base(), str.end()); + return str; +} + +/// Trim whitespace from string +inline std::string &trim(std::string &str) { return ltrim(rtrim(str)); } + +/// Trim anything from string +inline std::string &trim(std::string &str, const std::string filter) { return ltrim(rtrim(str, filter), filter); } + +/// Make a copy of the string and then trim it +inline std::string trim_copy(const std::string &str) { + std::string s = str; + return trim(s); +} + +/// remove quotes at the front and back of a string either '"' or '\'' +inline std::string &remove_quotes(std::string &str) { + if(str.length() > 1 && (str.front() == '"' || str.front() == '\'')) { + if(str.front() == str.back()) { + str.pop_back(); + str.erase(str.begin(), str.begin() + 1); + } + } + return str; +} + +/// Add a leader to the beginning of all new lines (nothing is added +/// at the start of the first line). `"; "` would be for ini files +/// +/// Can't use Regex, or this would be a subs. +inline std::string fix_newlines(const std::string &leader, std::string input) { + std::string::size_type n = 0; + while(n != std::string::npos && n < input.size()) { + n = input.find('\n', n); + if(n != std::string::npos) { + input = input.substr(0, n + 1) + leader + input.substr(n + 1); + n += leader.size(); + } + } + return input; +} + +/// Make a copy of the string and then trim it, any filter string can be used (any char in string is filtered) +inline std::string trim_copy(const std::string &str, const std::string &filter) { + std::string s = str; + return trim(s, filter); +} +/// Print a two part "help" string +inline std::ostream &format_help(std::ostream &out, std::string name, const std::string &description, std::size_t wid) { + name = " " + name; + out << std::setw(static_cast(wid)) << std::left << name; + if(!description.empty()) { + if(name.length() >= wid) + out << "\n" << std::setw(static_cast(wid)) << ""; + for(const char c : description) { + out.put(c); + if(c == '\n') { + out << std::setw(static_cast(wid)) << ""; + } + } + } + out << "\n"; + return out; +} + +/// Print subcommand aliases +inline std::ostream &format_aliases(std::ostream &out, const std::vector &aliases, std::size_t wid) { + if(!aliases.empty()) { + out << std::setw(static_cast(wid)) << " aliases: "; + bool front = true; + for(const auto &alias : aliases) { + if(!front) { + out << ", "; + } else { + front = false; + } + out << detail::fix_newlines(" ", alias); + } + out << "\n"; + } + return out; +} + +/// Verify the first character of an option +/// - is a trigger character, ! has special meaning and new lines would just be annoying to deal with +template bool valid_first_char(T c) { return ((c != '-') && (c != '!') && (c != ' ') && c != '\n'); } + +/// Verify following characters of an option +template bool valid_later_char(T c) { + // = and : are value separators, { has special meaning for option defaults, + // and \n would just be annoying to deal with in many places allowing space here has too much potential for + // inadvertent entry errors and bugs + return ((c != '=') && (c != ':') && (c != '{') && (c != ' ') && c != '\n'); +} + +/// Verify an option/subcommand name +inline bool valid_name_string(const std::string &str) { + if(str.empty() || !valid_first_char(str[0])) { + return false; + } + auto e = str.end(); + for(auto c = str.begin() + 1; c != e; ++c) + if(!valid_later_char(*c)) + return false; + return true; +} + +/// Verify an app name +inline bool valid_alias_name_string(const std::string &str) { + static const std::string badChars(std::string("\n") + '\0'); + return (str.find_first_of(badChars) == std::string::npos); +} + +/// check if a string is a container segment separator (empty or "%%") +inline bool is_separator(const std::string &str) { + static const std::string sep("%%"); + return (str.empty() || str == sep); +} + +/// Verify that str consists of letters only +inline bool isalpha(const std::string &str) { + return std::all_of(str.begin(), str.end(), [](char c) { return std::isalpha(c, std::locale()); }); +} + +/// Return a lower case version of a string +inline std::string to_lower(std::string str) { + std::transform(std::begin(str), std::end(str), std::begin(str), [](const std::string::value_type &x) { + return std::tolower(x, std::locale()); + }); + return str; +} + +/// remove underscores from a string +inline std::string remove_underscore(std::string str) { + str.erase(std::remove(std::begin(str), std::end(str), '_'), std::end(str)); + return str; +} + +/// Find and replace a substring with another substring +inline std::string find_and_replace(std::string str, std::string from, std::string to) { + + std::size_t start_pos = 0; + + while((start_pos = str.find(from, start_pos)) != std::string::npos) { + str.replace(start_pos, from.length(), to); + start_pos += to.length(); + } + + return str; +} + +/// check if the flag definitions has possible false flags +inline bool has_default_flag_values(const std::string &flags) { + return (flags.find_first_of("{!") != std::string::npos); +} + +inline void remove_default_flag_values(std::string &flags) { + auto loc = flags.find_first_of('{', 2); + while(loc != std::string::npos) { + auto finish = flags.find_first_of("},", loc + 1); + if((finish != std::string::npos) && (flags[finish] == '}')) { + flags.erase(flags.begin() + static_cast(loc), + flags.begin() + static_cast(finish) + 1); + } + loc = flags.find_first_of('{', loc + 1); + } + flags.erase(std::remove(flags.begin(), flags.end(), '!'), flags.end()); +} + +/// Check if a string is a member of a list of strings and optionally ignore case or ignore underscores +inline std::ptrdiff_t find_member(std::string name, + const std::vector names, + bool ignore_case = false, + bool ignore_underscore = false) { + auto it = std::end(names); + if(ignore_case) { + if(ignore_underscore) { + name = detail::to_lower(detail::remove_underscore(name)); + it = std::find_if(std::begin(names), std::end(names), [&name](std::string local_name) { + return detail::to_lower(detail::remove_underscore(local_name)) == name; + }); + } else { + name = detail::to_lower(name); + it = std::find_if(std::begin(names), std::end(names), [&name](std::string local_name) { + return detail::to_lower(local_name) == name; + }); + } + + } else if(ignore_underscore) { + name = detail::remove_underscore(name); + it = std::find_if(std::begin(names), std::end(names), [&name](std::string local_name) { + return detail::remove_underscore(local_name) == name; + }); + } else { + it = std::find(std::begin(names), std::end(names), name); + } + + return (it != std::end(names)) ? (it - std::begin(names)) : (-1); +} + +/// Find a trigger string and call a modify callable function that takes the current string and starting position of the +/// trigger and returns the position in the string to search for the next trigger string +template inline std::string find_and_modify(std::string str, std::string trigger, Callable modify) { + std::size_t start_pos = 0; + while((start_pos = str.find(trigger, start_pos)) != std::string::npos) { + start_pos = modify(str, start_pos); + } + return str; +} + +/// Split a string '"one two" "three"' into 'one two', 'three' +/// Quote characters can be ` ' or " +inline std::vector split_up(std::string str, char delimiter = '\0') { + + const std::string delims("\'\"`"); + auto find_ws = [delimiter](char ch) { + return (delimiter == '\0') ? (std::isspace(ch, std::locale()) != 0) : (ch == delimiter); + }; + trim(str); + + std::vector output; + bool embeddedQuote = false; + char keyChar = ' '; + while(!str.empty()) { + if(delims.find_first_of(str[0]) != std::string::npos) { + keyChar = str[0]; + auto end = str.find_first_of(keyChar, 1); + while((end != std::string::npos) && (str[end - 1] == '\\')) { // deal with escaped quotes + end = str.find_first_of(keyChar, end + 1); + embeddedQuote = true; + } + if(end != std::string::npos) { + output.push_back(str.substr(1, end - 1)); + if(end + 2 < str.size()) { + str = str.substr(end + 2); + } else { + str.clear(); + } + + } else { + output.push_back(str.substr(1)); + str = ""; + } + } else { + auto it = std::find_if(std::begin(str), std::end(str), find_ws); + if(it != std::end(str)) { + std::string value = std::string(str.begin(), it); + output.push_back(value); + str = std::string(it + 1, str.end()); + } else { + output.push_back(str); + str = ""; + } + } + // transform any embedded quotes into the regular character + if(embeddedQuote) { + output.back() = find_and_replace(output.back(), std::string("\\") + keyChar, std::string(1, keyChar)); + embeddedQuote = false; + } + trim(str); + } + return output; +} + +/// This function detects an equal or colon followed by an escaped quote after an argument +/// then modifies the string to replace the equality with a space. This is needed +/// to allow the split up function to work properly and is intended to be used with the find_and_modify function +/// the return value is the offset+1 which is required by the find_and_modify function. +inline std::size_t escape_detect(std::string &str, std::size_t offset) { + auto next = str[offset + 1]; + if((next == '\"') || (next == '\'') || (next == '`')) { + auto astart = str.find_last_of("-/ \"\'`", offset - 1); + if(astart != std::string::npos) { + if(str[astart] == ((str[offset] == '=') ? '-' : '/')) + str[offset] = ' '; // interpret this as a space so the split_up works properly + } + } + return offset + 1; +} + +/// Add quotes if the string contains spaces +inline std::string &add_quotes_if_needed(std::string &str) { + if((str.front() != '"' && str.front() != '\'') || str.front() != str.back()) { + char quote = str.find('"') < str.find('\'') ? '\'' : '"'; + if(str.find(' ') != std::string::npos) { + str.insert(0, 1, quote); + str.append(1, quote); + } + } + return str; +} + +} // namespace detail + + + + +// Use one of these on all error classes. +// These are temporary and are undef'd at the end of this file. +#define CLI11_ERROR_DEF(parent, name) \ + protected: \ + name(std::string ename, std::string msg, int exit_code) : parent(std::move(ename), std::move(msg), exit_code) {} \ + name(std::string ename, std::string msg, ExitCodes exit_code) \ + : parent(std::move(ename), std::move(msg), exit_code) {} \ + \ + public: \ + name(std::string msg, ExitCodes exit_code) : parent(#name, std::move(msg), exit_code) {} \ + name(std::string msg, int exit_code) : parent(#name, std::move(msg), exit_code) {} + +// This is added after the one above if a class is used directly and builds its own message +#define CLI11_ERROR_SIMPLE(name) \ + explicit name(std::string msg) : name(#name, msg, ExitCodes::name) {} + +/// These codes are part of every error in CLI. They can be obtained from e using e.exit_code or as a quick shortcut, +/// int values from e.get_error_code(). +enum class ExitCodes { + Success = 0, + IncorrectConstruction = 100, + BadNameString, + OptionAlreadyAdded, + FileError, + ConversionError, + ValidationError, + RequiredError, + RequiresError, + ExcludesError, + ExtrasError, + ConfigError, + InvalidError, + HorribleError, + OptionNotFound, + ArgumentMismatch, + BaseClass = 127 +}; + +// Error definitions + +/// @defgroup error_group Errors +/// @brief Errors thrown by CLI11 +/// +/// These are the errors that can be thrown. Some of them, like CLI::Success, are not really errors. +/// @{ + +/// All errors derive from this one +class Error : public std::runtime_error { + int actual_exit_code; + std::string error_name{"Error"}; + + public: + int get_exit_code() const { return actual_exit_code; } + + std::string get_name() const { return error_name; } + + Error(std::string name, std::string msg, int exit_code = static_cast(ExitCodes::BaseClass)) + : runtime_error(msg), actual_exit_code(exit_code), error_name(std::move(name)) {} + + Error(std::string name, std::string msg, ExitCodes exit_code) : Error(name, msg, static_cast(exit_code)) {} +}; + +// Note: Using Error::Error constructors does not work on GCC 4.7 + +/// Construction errors (not in parsing) +class ConstructionError : public Error { + CLI11_ERROR_DEF(Error, ConstructionError) +}; + +/// Thrown when an option is set to conflicting values (non-vector and multi args, for example) +class IncorrectConstruction : public ConstructionError { + CLI11_ERROR_DEF(ConstructionError, IncorrectConstruction) + CLI11_ERROR_SIMPLE(IncorrectConstruction) + static IncorrectConstruction PositionalFlag(std::string name) { + return IncorrectConstruction(name + ": Flags cannot be positional"); + } + static IncorrectConstruction Set0Opt(std::string name) { + return IncorrectConstruction(name + ": Cannot set 0 expected, use a flag instead"); + } + static IncorrectConstruction SetFlag(std::string name) { + return IncorrectConstruction(name + ": Cannot set an expected number for flags"); + } + static IncorrectConstruction ChangeNotVector(std::string name) { + return IncorrectConstruction(name + ": You can only change the expected arguments for vectors"); + } + static IncorrectConstruction AfterMultiOpt(std::string name) { + return IncorrectConstruction( + name + ": You can't change expected arguments after you've changed the multi option policy!"); + } + static IncorrectConstruction MissingOption(std::string name) { + return IncorrectConstruction("Option " + name + " is not defined"); + } + static IncorrectConstruction MultiOptionPolicy(std::string name) { + return IncorrectConstruction(name + ": multi_option_policy only works for flags and exact value options"); + } +}; + +/// Thrown on construction of a bad name +class BadNameString : public ConstructionError { + CLI11_ERROR_DEF(ConstructionError, BadNameString) + CLI11_ERROR_SIMPLE(BadNameString) + static BadNameString OneCharName(std::string name) { return BadNameString("Invalid one char name: " + name); } + static BadNameString BadLongName(std::string name) { return BadNameString("Bad long name: " + name); } + static BadNameString DashesOnly(std::string name) { + return BadNameString("Must have a name, not just dashes: " + name); + } + static BadNameString MultiPositionalNames(std::string name) { + return BadNameString("Only one positional name allowed, remove: " + name); + } +}; + +/// Thrown when an option already exists +class OptionAlreadyAdded : public ConstructionError { + CLI11_ERROR_DEF(ConstructionError, OptionAlreadyAdded) + explicit OptionAlreadyAdded(std::string name) + : OptionAlreadyAdded(name + " is already added", ExitCodes::OptionAlreadyAdded) {} + static OptionAlreadyAdded Requires(std::string name, std::string other) { + return OptionAlreadyAdded(name + " requires " + other, ExitCodes::OptionAlreadyAdded); + } + static OptionAlreadyAdded Excludes(std::string name, std::string other) { + return OptionAlreadyAdded(name + " excludes " + other, ExitCodes::OptionAlreadyAdded); + } +}; + +// Parsing errors + +/// Anything that can error in Parse +class ParseError : public Error { + CLI11_ERROR_DEF(Error, ParseError) +}; + +// Not really "errors" + +/// This is a successful completion on parsing, supposed to exit +class Success : public ParseError { + CLI11_ERROR_DEF(ParseError, Success) + Success() : Success("Successfully completed, should be caught and quit", ExitCodes::Success) {} +}; + +/// -h or --help on command line +class CallForHelp : public Success { + CLI11_ERROR_DEF(Success, CallForHelp) + CallForHelp() : CallForHelp("This should be caught in your main function, see examples", ExitCodes::Success) {} +}; + +/// Usually something like --help-all on command line +class CallForAllHelp : public Success { + CLI11_ERROR_DEF(Success, CallForAllHelp) + CallForAllHelp() + : CallForAllHelp("This should be caught in your main function, see examples", ExitCodes::Success) {} +}; + +/// -v or --version on command line +class CallForVersion : public Success { + CLI11_ERROR_DEF(Success, CallForVersion) + CallForVersion() + : CallForVersion("This should be caught in your main function, see examples", ExitCodes::Success) {} +}; + +/// Does not output a diagnostic in CLI11_PARSE, but allows main() to return with a specific error code. +class RuntimeError : public ParseError { + CLI11_ERROR_DEF(ParseError, RuntimeError) + explicit RuntimeError(int exit_code = 1) : RuntimeError("Runtime error", exit_code) {} +}; + +/// Thrown when parsing an INI file and it is missing +class FileError : public ParseError { + CLI11_ERROR_DEF(ParseError, FileError) + CLI11_ERROR_SIMPLE(FileError) + static FileError Missing(std::string name) { return FileError(name + " was not readable (missing?)"); } +}; + +/// Thrown when conversion call back fails, such as when an int fails to coerce to a string +class ConversionError : public ParseError { + CLI11_ERROR_DEF(ParseError, ConversionError) + CLI11_ERROR_SIMPLE(ConversionError) + ConversionError(std::string member, std::string name) + : ConversionError("The value " + member + " is not an allowed value for " + name) {} + ConversionError(std::string name, std::vector results) + : ConversionError("Could not convert: " + name + " = " + detail::join(results)) {} + static ConversionError TooManyInputsFlag(std::string name) { + return ConversionError(name + ": too many inputs for a flag"); + } + static ConversionError TrueFalse(std::string name) { + return ConversionError(name + ": Should be true/false or a number"); + } +}; + +/// Thrown when validation of results fails +class ValidationError : public ParseError { + CLI11_ERROR_DEF(ParseError, ValidationError) + CLI11_ERROR_SIMPLE(ValidationError) + explicit ValidationError(std::string name, std::string msg) : ValidationError(name + ": " + msg) {} +}; + +/// Thrown when a required option is missing +class RequiredError : public ParseError { + CLI11_ERROR_DEF(ParseError, RequiredError) + explicit RequiredError(std::string name) : RequiredError(name + " is required", ExitCodes::RequiredError) {} + static RequiredError Subcommand(std::size_t min_subcom) { + if(min_subcom == 1) { + return RequiredError("A subcommand"); + } + return RequiredError("Requires at least " + std::to_string(min_subcom) + " subcommands", + ExitCodes::RequiredError); + } + static RequiredError + Option(std::size_t min_option, std::size_t max_option, std::size_t used, const std::string &option_list) { + if((min_option == 1) && (max_option == 1) && (used == 0)) + return RequiredError("Exactly 1 option from [" + option_list + "]"); + if((min_option == 1) && (max_option == 1) && (used > 1)) { + return RequiredError("Exactly 1 option from [" + option_list + "] is required and " + std::to_string(used) + + " were given", + ExitCodes::RequiredError); + } + if((min_option == 1) && (used == 0)) + return RequiredError("At least 1 option from [" + option_list + "]"); + if(used < min_option) { + return RequiredError("Requires at least " + std::to_string(min_option) + " options used and only " + + std::to_string(used) + "were given from [" + option_list + "]", + ExitCodes::RequiredError); + } + if(max_option == 1) + return RequiredError("Requires at most 1 options be given from [" + option_list + "]", + ExitCodes::RequiredError); + + return RequiredError("Requires at most " + std::to_string(max_option) + " options be used and " + + std::to_string(used) + "were given from [" + option_list + "]", + ExitCodes::RequiredError); + } +}; + +/// Thrown when the wrong number of arguments has been received +class ArgumentMismatch : public ParseError { + CLI11_ERROR_DEF(ParseError, ArgumentMismatch) + CLI11_ERROR_SIMPLE(ArgumentMismatch) + ArgumentMismatch(std::string name, int expected, std::size_t received) + : ArgumentMismatch(expected > 0 ? ("Expected exactly " + std::to_string(expected) + " arguments to " + name + + ", got " + std::to_string(received)) + : ("Expected at least " + std::to_string(-expected) + " arguments to " + name + + ", got " + std::to_string(received)), + ExitCodes::ArgumentMismatch) {} + + static ArgumentMismatch AtLeast(std::string name, int num, std::size_t received) { + return ArgumentMismatch(name + ": At least " + std::to_string(num) + " required but received " + + std::to_string(received)); + } + static ArgumentMismatch AtMost(std::string name, int num, std::size_t received) { + return ArgumentMismatch(name + ": At Most " + std::to_string(num) + " required but received " + + std::to_string(received)); + } + static ArgumentMismatch TypedAtLeast(std::string name, int num, std::string type) { + return ArgumentMismatch(name + ": " + std::to_string(num) + " required " + type + " missing"); + } + static ArgumentMismatch FlagOverride(std::string name) { + return ArgumentMismatch(name + " was given a disallowed flag override"); + } +}; + +/// Thrown when a requires option is missing +class RequiresError : public ParseError { + CLI11_ERROR_DEF(ParseError, RequiresError) + RequiresError(std::string curname, std::string subname) + : RequiresError(curname + " requires " + subname, ExitCodes::RequiresError) {} +}; + +/// Thrown when an excludes option is present +class ExcludesError : public ParseError { + CLI11_ERROR_DEF(ParseError, ExcludesError) + ExcludesError(std::string curname, std::string subname) + : ExcludesError(curname + " excludes " + subname, ExitCodes::ExcludesError) {} +}; + +/// Thrown when too many positionals or options are found +class ExtrasError : public ParseError { + CLI11_ERROR_DEF(ParseError, ExtrasError) + explicit ExtrasError(std::vector args) + : ExtrasError((args.size() > 1 ? "The following arguments were not expected: " + : "The following argument was not expected: ") + + detail::rjoin(args, " "), + ExitCodes::ExtrasError) {} + ExtrasError(const std::string &name, std::vector args) + : ExtrasError(name, + (args.size() > 1 ? "The following arguments were not expected: " + : "The following argument was not expected: ") + + detail::rjoin(args, " "), + ExitCodes::ExtrasError) {} +}; + +/// Thrown when extra values are found in an INI file +class ConfigError : public ParseError { + CLI11_ERROR_DEF(ParseError, ConfigError) + CLI11_ERROR_SIMPLE(ConfigError) + static ConfigError Extras(std::string item) { return ConfigError("INI was not able to parse " + item); } + static ConfigError NotConfigurable(std::string item) { + return ConfigError(item + ": This option is not allowed in a configuration file"); + } +}; + +/// Thrown when validation fails before parsing +class InvalidError : public ParseError { + CLI11_ERROR_DEF(ParseError, InvalidError) + explicit InvalidError(std::string name) + : InvalidError(name + ": Too many positional arguments with unlimited expected args", ExitCodes::InvalidError) { + } +}; + +/// This is just a safety check to verify selection and parsing match - you should not ever see it +/// Strings are directly added to this error, but again, it should never be seen. +class HorribleError : public ParseError { + CLI11_ERROR_DEF(ParseError, HorribleError) + CLI11_ERROR_SIMPLE(HorribleError) +}; + +// After parsing + +/// Thrown when counting a non-existent option +class OptionNotFound : public Error { + CLI11_ERROR_DEF(Error, OptionNotFound) + explicit OptionNotFound(std::string name) : OptionNotFound(name + " not found", ExitCodes::OptionNotFound) {} +}; + +#undef CLI11_ERROR_DEF +#undef CLI11_ERROR_SIMPLE + +/// @} + + + + +// Type tools + +// Utilities for type enabling +namespace detail { +// Based generally on https://rmf.io/cxx11/almost-static-if +/// Simple empty scoped class +enum class enabler {}; + +/// An instance to use in EnableIf +constexpr enabler dummy = {}; +} // namespace detail + +/// A copy of enable_if_t from C++14, compatible with C++11. +/// +/// We could check to see if C++14 is being used, but it does not hurt to redefine this +/// (even Google does this: https://github.com/google/skia/blob/main/include/private/SkTLogic.h) +/// It is not in the std namespace anyway, so no harm done. +template using enable_if_t = typename std::enable_if::type; + +/// A copy of std::void_t from C++17 (helper for C++11 and C++14) +template struct make_void { using type = void; }; + +/// A copy of std::void_t from C++17 - same reasoning as enable_if_t, it does not hurt to redefine +template using void_t = typename make_void::type; + +/// A copy of std::conditional_t from C++14 - same reasoning as enable_if_t, it does not hurt to redefine +template using conditional_t = typename std::conditional::type; + +/// Check to see if something is bool (fail check by default) +template struct is_bool : std::false_type {}; + +/// Check to see if something is bool (true if actually a bool) +template <> struct is_bool : std::true_type {}; + +/// Check to see if something is a shared pointer +template struct is_shared_ptr : std::false_type {}; + +/// Check to see if something is a shared pointer (True if really a shared pointer) +template struct is_shared_ptr> : std::true_type {}; + +/// Check to see if something is a shared pointer (True if really a shared pointer) +template struct is_shared_ptr> : std::true_type {}; + +/// Check to see if something is copyable pointer +template struct is_copyable_ptr { + static bool const value = is_shared_ptr::value || std::is_pointer::value; +}; + +/// This can be specialized to override the type deduction for IsMember. +template struct IsMemberType { using type = T; }; + +/// The main custom type needed here is const char * should be a string. +template <> struct IsMemberType { using type = std::string; }; + +namespace detail { + +// These are utilities for IsMember and other transforming objects + +/// Handy helper to access the element_type generically. This is not part of is_copyable_ptr because it requires that +/// pointer_traits be valid. + +/// not a pointer +template struct element_type { using type = T; }; + +template struct element_type::value>::type> { + using type = typename std::pointer_traits::element_type; +}; + +/// Combination of the element type and value type - remove pointer (including smart pointers) and get the value_type of +/// the container +template struct element_value_type { using type = typename element_type::type::value_type; }; + +/// Adaptor for set-like structure: This just wraps a normal container in a few utilities that do almost nothing. +template struct pair_adaptor : std::false_type { + using value_type = typename T::value_type; + using first_type = typename std::remove_const::type; + using second_type = typename std::remove_const::type; + + /// Get the first value (really just the underlying value) + template static auto first(Q &&pair_value) -> decltype(std::forward(pair_value)) { + return std::forward(pair_value); + } + /// Get the second value (really just the underlying value) + template static auto second(Q &&pair_value) -> decltype(std::forward(pair_value)) { + return std::forward(pair_value); + } +}; + +/// Adaptor for map-like structure (true version, must have key_type and mapped_type). +/// This wraps a mapped container in a few utilities access it in a general way. +template +struct pair_adaptor< + T, + conditional_t, void>> + : std::true_type { + using value_type = typename T::value_type; + using first_type = typename std::remove_const::type; + using second_type = typename std::remove_const::type; + + /// Get the first value (really just the underlying value) + template static auto first(Q &&pair_value) -> decltype(std::get<0>(std::forward(pair_value))) { + return std::get<0>(std::forward(pair_value)); + } + /// Get the second value (really just the underlying value) + template static auto second(Q &&pair_value) -> decltype(std::get<1>(std::forward(pair_value))) { + return std::get<1>(std::forward(pair_value)); + } +}; + +// Warning is suppressed due to "bug" in gcc<5.0 and gcc 7.0 with c++17 enabled that generates a Wnarrowing warning +// in the unevaluated context even if the function that was using this wasn't used. The standard says narrowing in +// brace initialization shouldn't be allowed but for backwards compatibility gcc allows it in some contexts. It is a +// little fuzzy what happens in template constructs and I think that was something GCC took a little while to work out. +// But regardless some versions of gcc generate a warning when they shouldn't from the following code so that should be +// suppressed +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wnarrowing" +#endif +// check for constructibility from a specific type and copy assignable used in the parse detection +template class is_direct_constructible { + template + static auto test(int, std::true_type) -> decltype( +// NVCC warns about narrowing conversions here +#ifdef __CUDACC__ +#pragma diag_suppress 2361 +#endif + TT { std::declval() } +#ifdef __CUDACC__ +#pragma diag_default 2361 +#endif + , + std::is_move_assignable()); + + template static auto test(int, std::false_type) -> std::false_type; + + template static auto test(...) -> std::false_type; + + public: + static constexpr bool value = decltype(test(0, typename std::is_constructible::type()))::value; +}; +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +// Check for output streamability +// Based on https://stackoverflow.com/questions/22758291/how-can-i-detect-if-a-type-can-be-streamed-to-an-stdostream + +template class is_ostreamable { + template + static auto test(int) -> decltype(std::declval() << std::declval(), std::true_type()); + + template static auto test(...) -> std::false_type; + + public: + static constexpr bool value = decltype(test(0))::value; +}; + +/// Check for input streamability +template class is_istreamable { + template + static auto test(int) -> decltype(std::declval() >> std::declval(), std::true_type()); + + template static auto test(...) -> std::false_type; + + public: + static constexpr bool value = decltype(test(0))::value; +}; + +/// Check for complex +template class is_complex { + template + static auto test(int) -> decltype(std::declval().real(), std::declval().imag(), std::true_type()); + + template static auto test(...) -> std::false_type; + + public: + static constexpr bool value = decltype(test(0))::value; +}; + +/// Templated operation to get a value from a stream +template ::value, detail::enabler> = detail::dummy> +bool from_stream(const std::string &istring, T &obj) { + std::istringstream is; + is.str(istring); + is >> obj; + return !is.fail() && !is.rdbuf()->in_avail(); +} + +template ::value, detail::enabler> = detail::dummy> +bool from_stream(const std::string & /*istring*/, T & /*obj*/) { + return false; +} + +// check to see if an object is a mutable container (fail by default) +template struct is_mutable_container : std::false_type {}; + +/// type trait to test if a type is a mutable container meaning it has a value_type, it has an iterator, a clear, and +/// end methods and an insert function. And for our purposes we exclude std::string and types that can be constructed +/// from a std::string +template +struct is_mutable_container< + T, + conditional_t().end()), + decltype(std::declval().clear()), + decltype(std::declval().insert(std::declval().end())>(), + std::declval()))>, + void>> + : public conditional_t::value, std::false_type, std::true_type> {}; + +// check to see if an object is a mutable container (fail by default) +template struct is_readable_container : std::false_type {}; + +/// type trait to test if a type is a container meaning it has a value_type, it has an iterator, a clear, and an end +/// methods and an insert function. And for our purposes we exclude std::string and types that can be constructed from +/// a std::string +template +struct is_readable_container< + T, + conditional_t().end()), decltype(std::declval().begin())>, void>> + : public std::true_type {}; + +// check to see if an object is a wrapper (fail by default) +template struct is_wrapper : std::false_type {}; + +// check if an object is a wrapper (it has a value_type defined) +template +struct is_wrapper, void>> : public std::true_type {}; + +// Check for tuple like types, as in classes with a tuple_size type trait +template class is_tuple_like { + template + // static auto test(int) + // -> decltype(std::conditional<(std::tuple_size::value > 0), std::true_type, std::false_type>::type()); + static auto test(int) -> decltype(std::tuple_size::type>::value, std::true_type{}); + template static auto test(...) -> std::false_type; + + public: + static constexpr bool value = decltype(test(0))::value; +}; + +/// Convert an object to a string (directly forward if this can become a string) +template ::value, detail::enabler> = detail::dummy> +auto to_string(T &&value) -> decltype(std::forward(value)) { + return std::forward(value); +} + +/// Construct a string from the object +template ::value && !std::is_convertible::value, + detail::enabler> = detail::dummy> +std::string to_string(const T &value) { + return std::string(value); +} + +/// Convert an object to a string (streaming must be supported for that type) +template ::value && !std::is_constructible::value && + is_ostreamable::value, + detail::enabler> = detail::dummy> +std::string to_string(T &&value) { + std::stringstream stream; + stream << value; + return stream.str(); +} + +/// If conversion is not supported, return an empty string (streaming is not supported for that type) +template ::value && !is_ostreamable::value && + !is_readable_container::type>::value, + detail::enabler> = detail::dummy> +std::string to_string(T &&) { + return std::string{}; +} + +/// convert a readable container to a string +template ::value && !is_ostreamable::value && + is_readable_container::value, + detail::enabler> = detail::dummy> +std::string to_string(T &&variable) { + std::vector defaults; + auto cval = variable.begin(); + auto end = variable.end(); + while(cval != end) { + defaults.emplace_back(CLI::detail::to_string(*cval)); + ++cval; + } + return std::string("[" + detail::join(defaults) + "]"); +} + +/// special template overload +template ::value, detail::enabler> = detail::dummy> +auto checked_to_string(T &&value) -> decltype(to_string(std::forward(value))) { + return to_string(std::forward(value)); +} + +/// special template overload +template ::value, detail::enabler> = detail::dummy> +std::string checked_to_string(T &&) { + return std::string{}; +} +/// get a string as a convertible value for arithmetic types +template ::value, detail::enabler> = detail::dummy> +std::string value_string(const T &value) { + return std::to_string(value); +} +/// get a string as a convertible value for enumerations +template ::value, detail::enabler> = detail::dummy> +std::string value_string(const T &value) { + return std::to_string(static_cast::type>(value)); +} +/// for other types just use the regular to_string function +template ::value && !std::is_arithmetic::value, detail::enabler> = detail::dummy> +auto value_string(const T &value) -> decltype(to_string(value)) { + return to_string(value); +} + +/// template to get the underlying value type if it exists or use a default +template struct wrapped_type { using type = def; }; + +/// Type size for regular object types that do not look like a tuple +template struct wrapped_type::value>::type> { + using type = typename T::value_type; +}; + +/// This will only trigger for actual void type +template struct type_count_base { static const int value{0}; }; + +/// Type size for regular object types that do not look like a tuple +template +struct type_count_base::value && !is_mutable_container::value && + !std::is_void::value>::type> { + static constexpr int value{1}; +}; + +/// the base tuple size +template +struct type_count_base::value && !is_mutable_container::value>::type> { + static constexpr int value{std::tuple_size::value}; +}; + +/// Type count base for containers is the type_count_base of the individual element +template struct type_count_base::value>::type> { + static constexpr int value{type_count_base::value}; +}; + +/// Set of overloads to get the type size of an object + +/// forward declare the subtype_count structure +template struct subtype_count; + +/// forward declare the subtype_count_min structure +template struct subtype_count_min; + +/// This will only trigger for actual void type +template struct type_count { static const int value{0}; }; + +/// Type size for regular object types that do not look like a tuple +template +struct type_count::value && !is_tuple_like::value && !is_complex::value && + !std::is_void::value>::type> { + static constexpr int value{1}; +}; + +/// Type size for complex since it sometimes looks like a wrapper +template struct type_count::value>::type> { + static constexpr int value{2}; +}; + +/// Type size of types that are wrappers,except complex and tuples(which can also be wrappers sometimes) +template struct type_count::value>::type> { + static constexpr int value{subtype_count::value}; +}; + +/// Type size of types that are wrappers,except containers complex and tuples(which can also be wrappers sometimes) +template +struct type_count::value && !is_complex::value && !is_tuple_like::value && + !is_mutable_container::value>::type> { + static constexpr int value{type_count::value}; +}; + +/// 0 if the index > tuple size +template +constexpr typename std::enable_if::value, int>::type tuple_type_size() { + return 0; +} + +/// Recursively generate the tuple type name +template + constexpr typename std::enable_if < I::value, int>::type tuple_type_size() { + return subtype_count::type>::value + tuple_type_size(); +} + +/// Get the type size of the sum of type sizes for all the individual tuple types +template struct type_count::value>::type> { + static constexpr int value{tuple_type_size()}; +}; + +/// definition of subtype count +template struct subtype_count { + static constexpr int value{is_mutable_container::value ? expected_max_vector_size : type_count::value}; +}; + +/// This will only trigger for actual void type +template struct type_count_min { static const int value{0}; }; + +/// Type size for regular object types that do not look like a tuple +template +struct type_count_min< + T, + typename std::enable_if::value && !is_tuple_like::value && !is_wrapper::value && + !is_complex::value && !std::is_void::value>::type> { + static constexpr int value{type_count::value}; +}; + +/// Type size for complex since it sometimes looks like a wrapper +template struct type_count_min::value>::type> { + static constexpr int value{1}; +}; + +/// Type size min of types that are wrappers,except complex and tuples(which can also be wrappers sometimes) +template +struct type_count_min< + T, + typename std::enable_if::value && !is_complex::value && !is_tuple_like::value>::type> { + static constexpr int value{subtype_count_min::value}; +}; + +/// 0 if the index > tuple size +template +constexpr typename std::enable_if::value, int>::type tuple_type_size_min() { + return 0; +} + +/// Recursively generate the tuple type name +template + constexpr typename std::enable_if < I::value, int>::type tuple_type_size_min() { + return subtype_count_min::type>::value + tuple_type_size_min(); +} + +/// Get the type size of the sum of type sizes for all the individual tuple types +template struct type_count_min::value>::type> { + static constexpr int value{tuple_type_size_min()}; +}; + +/// definition of subtype count +template struct subtype_count_min { + static constexpr int value{is_mutable_container::value + ? ((type_count::value < expected_max_vector_size) ? type_count::value : 0) + : type_count_min::value}; +}; + +/// This will only trigger for actual void type +template struct expected_count { static const int value{0}; }; + +/// For most types the number of expected items is 1 +template +struct expected_count::value && !is_wrapper::value && + !std::is_void::value>::type> { + static constexpr int value{1}; +}; +/// number of expected items in a vector +template struct expected_count::value>::type> { + static constexpr int value{expected_max_vector_size}; +}; + +/// number of expected items in a vector +template +struct expected_count::value && is_wrapper::value>::type> { + static constexpr int value{expected_count::value}; +}; + +// Enumeration of the different supported categorizations of objects +enum class object_category : int { + char_value = 1, + integral_value = 2, + unsigned_integral = 4, + enumeration = 6, + boolean_value = 8, + floating_point = 10, + number_constructible = 12, + double_constructible = 14, + integer_constructible = 16, + // string like types + string_assignable = 23, + string_constructible = 24, + other = 45, + // special wrapper or container types + wrapper_value = 50, + complex_number = 60, + tuple_value = 70, + container_value = 80, + +}; + +/// Set of overloads to classify an object according to type + +/// some type that is not otherwise recognized +template struct classify_object { + static constexpr object_category value{object_category::other}; +}; + +/// Signed integers +template +struct classify_object< + T, + typename std::enable_if::value && !std::is_same::value && std::is_signed::value && + !is_bool::value && !std::is_enum::value>::type> { + static constexpr object_category value{object_category::integral_value}; +}; + +/// Unsigned integers +template +struct classify_object::value && std::is_unsigned::value && + !std::is_same::value && !is_bool::value>::type> { + static constexpr object_category value{object_category::unsigned_integral}; +}; + +/// single character values +template +struct classify_object::value && !std::is_enum::value>::type> { + static constexpr object_category value{object_category::char_value}; +}; + +/// Boolean values +template struct classify_object::value>::type> { + static constexpr object_category value{object_category::boolean_value}; +}; + +/// Floats +template struct classify_object::value>::type> { + static constexpr object_category value{object_category::floating_point}; +}; + +/// String and similar direct assignment +template +struct classify_object::value && !std::is_integral::value && + std::is_assignable::value>::type> { + static constexpr object_category value{object_category::string_assignable}; +}; + +/// String and similar constructible and copy assignment +template +struct classify_object< + T, + typename std::enable_if::value && !std::is_integral::value && + !std::is_assignable::value && (type_count::value == 1) && + std::is_constructible::value>::type> { + static constexpr object_category value{object_category::string_constructible}; +}; + +/// Enumerations +template struct classify_object::value>::type> { + static constexpr object_category value{object_category::enumeration}; +}; + +template struct classify_object::value>::type> { + static constexpr object_category value{object_category::complex_number}; +}; + +/// Handy helper to contain a bunch of checks that rule out many common types (integers, string like, floating point, +/// vectors, and enumerations +template struct uncommon_type { + using type = typename std::conditional::value && !std::is_integral::value && + !std::is_assignable::value && + !std::is_constructible::value && !is_complex::value && + !is_mutable_container::value && !std::is_enum::value, + std::true_type, + std::false_type>::type; + static constexpr bool value = type::value; +}; + +/// wrapper type +template +struct classify_object::value && is_wrapper::value && + !is_tuple_like::value && uncommon_type::value)>::type> { + static constexpr object_category value{object_category::wrapper_value}; +}; + +/// Assignable from double or int +template +struct classify_object::value && type_count::value == 1 && + !is_wrapper::value && is_direct_constructible::value && + is_direct_constructible::value>::type> { + static constexpr object_category value{object_category::number_constructible}; +}; + +/// Assignable from int +template +struct classify_object::value && type_count::value == 1 && + !is_wrapper::value && !is_direct_constructible::value && + is_direct_constructible::value>::type> { + static constexpr object_category value{object_category::integer_constructible}; +}; + +/// Assignable from double +template +struct classify_object::value && type_count::value == 1 && + !is_wrapper::value && is_direct_constructible::value && + !is_direct_constructible::value>::type> { + static constexpr object_category value{object_category::double_constructible}; +}; + +/// Tuple type +template +struct classify_object< + T, + typename std::enable_if::value && + ((type_count::value >= 2 && !is_wrapper::value) || + (uncommon_type::value && !is_direct_constructible::value && + !is_direct_constructible::value))>::type> { + static constexpr object_category value{object_category::tuple_value}; + // the condition on this class requires it be like a tuple, but on some compilers (like Xcode) tuples can be + // constructed from just the first element so tuples of can be constructed from a string, which + // could lead to issues so there are two variants of the condition, the first isolates things with a type size >=2 + // mainly to get tuples on Xcode with the exception of wrappers, the second is the main one and just separating out + // those cases that are caught by other object classifications +}; + +/// container type +template struct classify_object::value>::type> { + static constexpr object_category value{object_category::container_value}; +}; + +// Type name print + +/// Was going to be based on +/// http://stackoverflow.com/questions/1055452/c-get-name-of-type-in-template +/// But this is cleaner and works better in this case + +template ::value == object_category::char_value, detail::enabler> = detail::dummy> +constexpr const char *type_name() { + return "CHAR"; +} + +template ::value == object_category::integral_value || + classify_object::value == object_category::integer_constructible, + detail::enabler> = detail::dummy> +constexpr const char *type_name() { + return "INT"; +} + +template ::value == object_category::unsigned_integral, detail::enabler> = detail::dummy> +constexpr const char *type_name() { + return "UINT"; +} + +template ::value == object_category::floating_point || + classify_object::value == object_category::number_constructible || + classify_object::value == object_category::double_constructible, + detail::enabler> = detail::dummy> +constexpr const char *type_name() { + return "FLOAT"; +} + +/// Print name for enumeration types +template ::value == object_category::enumeration, detail::enabler> = detail::dummy> +constexpr const char *type_name() { + return "ENUM"; +} + +/// Print name for enumeration types +template ::value == object_category::boolean_value, detail::enabler> = detail::dummy> +constexpr const char *type_name() { + return "BOOLEAN"; +} + +/// Print name for enumeration types +template ::value == object_category::complex_number, detail::enabler> = detail::dummy> +constexpr const char *type_name() { + return "COMPLEX"; +} + +/// Print for all other types +template ::value >= object_category::string_assignable && + classify_object::value <= object_category::other, + detail::enabler> = detail::dummy> +constexpr const char *type_name() { + return "TEXT"; +} +/// typename for tuple value +template ::value == object_category::tuple_value && type_count_base::value >= 2, + detail::enabler> = detail::dummy> +std::string type_name(); // forward declaration + +/// Generate type name for a wrapper or container value +template ::value == object_category::container_value || + classify_object::value == object_category::wrapper_value, + detail::enabler> = detail::dummy> +std::string type_name(); // forward declaration + +/// Print name for single element tuple types +template ::value == object_category::tuple_value && type_count_base::value == 1, + detail::enabler> = detail::dummy> +inline std::string type_name() { + return type_name::type>::type>(); +} + +/// Empty string if the index > tuple size +template +inline typename std::enable_if::value, std::string>::type tuple_name() { + return std::string{}; +} + +/// Recursively generate the tuple type name +template +inline typename std::enable_if<(I < type_count_base::value), std::string>::type tuple_name() { + std::string str = std::string(type_name::type>::type>()) + + ',' + tuple_name(); + if(str.back() == ',') + str.pop_back(); + return str; +} + +/// Print type name for tuples with 2 or more elements +template ::value == object_category::tuple_value && type_count_base::value >= 2, + detail::enabler>> +inline std::string type_name() { + auto tname = std::string(1, '[') + tuple_name(); + tname.push_back(']'); + return tname; +} + +/// get the type name for a type that has a value_type member +template ::value == object_category::container_value || + classify_object::value == object_category::wrapper_value, + detail::enabler>> +inline std::string type_name() { + return type_name(); +} + +// Lexical cast + +/// Convert to an unsigned integral +template ::value, detail::enabler> = detail::dummy> +bool integral_conversion(const std::string &input, T &output) noexcept { + if(input.empty()) { + return false; + } + char *val = nullptr; + std::uint64_t output_ll = std::strtoull(input.c_str(), &val, 0); + output = static_cast(output_ll); + return val == (input.c_str() + input.size()) && static_cast(output) == output_ll; +} + +/// Convert to a signed integral +template ::value, detail::enabler> = detail::dummy> +bool integral_conversion(const std::string &input, T &output) noexcept { + if(input.empty()) { + return false; + } + char *val = nullptr; + std::int64_t output_ll = std::strtoll(input.c_str(), &val, 0); + output = static_cast(output_ll); + return val == (input.c_str() + input.size()) && static_cast(output) == output_ll; +} + +/// Convert a flag into an integer value typically binary flags +inline std::int64_t to_flag_value(std::string val) { + static const std::string trueString("true"); + static const std::string falseString("false"); + if(val == trueString) { + return 1; + } + if(val == falseString) { + return -1; + } + val = detail::to_lower(val); + std::int64_t ret; + if(val.size() == 1) { + if(val[0] >= '1' && val[0] <= '9') { + return (static_cast(val[0]) - '0'); + } + switch(val[0]) { + case '0': + case 'f': + case 'n': + case '-': + ret = -1; + break; + case 't': + case 'y': + case '+': + ret = 1; + break; + default: + throw std::invalid_argument("unrecognized character"); + } + return ret; + } + if(val == trueString || val == "on" || val == "yes" || val == "enable") { + ret = 1; + } else if(val == falseString || val == "off" || val == "no" || val == "disable") { + ret = -1; + } else { + ret = std::stoll(val); + } + return ret; +} + +/// Integer conversion +template ::value == object_category::integral_value || + classify_object::value == object_category::unsigned_integral, + detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + return integral_conversion(input, output); +} + +/// char values +template ::value == object_category::char_value, detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + if(input.size() == 1) { + output = static_cast(input[0]); + return true; + } + return integral_conversion(input, output); +} + +/// Boolean values +template ::value == object_category::boolean_value, detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + try { + auto out = to_flag_value(input); + output = (out > 0); + return true; + } catch(const std::invalid_argument &) { + return false; + } catch(const std::out_of_range &) { + // if the number is out of the range of a 64 bit value then it is still a number and for this purpose is still + // valid all we care about the sign + output = (input[0] != '-'); + return true; + } +} + +/// Floats +template ::value == object_category::floating_point, detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + if(input.empty()) { + return false; + } + char *val = nullptr; + auto output_ld = std::strtold(input.c_str(), &val); + output = static_cast(output_ld); + return val == (input.c_str() + input.size()); +} + +/// complex +template ::value == object_category::complex_number, detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + using XC = typename wrapped_type::type; + XC x{0.0}, y{0.0}; + auto str1 = input; + bool worked = false; + auto nloc = str1.find_last_of("+-"); + if(nloc != std::string::npos && nloc > 0) { + worked = detail::lexical_cast(str1.substr(0, nloc), x); + str1 = str1.substr(nloc); + if(str1.back() == 'i' || str1.back() == 'j') + str1.pop_back(); + worked = worked && detail::lexical_cast(str1, y); + } else { + if(str1.back() == 'i' || str1.back() == 'j') { + str1.pop_back(); + worked = detail::lexical_cast(str1, y); + x = XC{0}; + } else { + worked = detail::lexical_cast(str1, x); + y = XC{0}; + } + } + if(worked) { + output = T{x, y}; + return worked; + } + return from_stream(input, output); +} + +/// String and similar direct assignment +template ::value == object_category::string_assignable, detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + output = input; + return true; +} + +/// String and similar constructible and copy assignment +template < + typename T, + enable_if_t::value == object_category::string_constructible, detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + output = T(input); + return true; +} + +/// Enumerations +template ::value == object_category::enumeration, detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + typename std::underlying_type::type val; + if(!integral_conversion(input, val)) { + return false; + } + output = static_cast(val); + return true; +} + +/// wrapper types +template ::value == object_category::wrapper_value && + std::is_assignable::value, + detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + typename T::value_type val; + if(lexical_cast(input, val)) { + output = val; + return true; + } + return from_stream(input, output); +} + +template ::value == object_category::wrapper_value && + !std::is_assignable::value && std::is_assignable::value, + detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + typename T::value_type val; + if(lexical_cast(input, val)) { + output = T{val}; + return true; + } + return from_stream(input, output); +} + +/// Assignable from double or int +template < + typename T, + enable_if_t::value == object_category::number_constructible, detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + int val; + if(integral_conversion(input, val)) { + output = T(val); + return true; + } else { + double dval; + if(lexical_cast(input, dval)) { + output = T{dval}; + return true; + } + } + return from_stream(input, output); +} + +/// Assignable from int +template < + typename T, + enable_if_t::value == object_category::integer_constructible, detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + int val; + if(integral_conversion(input, val)) { + output = T(val); + return true; + } + return from_stream(input, output); +} + +/// Assignable from double +template < + typename T, + enable_if_t::value == object_category::double_constructible, detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + double val; + if(lexical_cast(input, val)) { + output = T{val}; + return true; + } + return from_stream(input, output); +} + +/// Non-string convertible from an int +template ::value == object_category::other && std::is_assignable::value, + detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + int val; + if(integral_conversion(input, val)) { +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4800) +#endif + // with Atomic this could produce a warning due to the conversion but if atomic gets here it is an old style + // so will most likely still work + output = val; +#ifdef _MSC_VER +#pragma warning(pop) +#endif + return true; + } + // LCOV_EXCL_START + // This version of cast is only used for odd cases in an older compilers the fail over + // from_stream is tested elsewhere an not relevant for coverage here + return from_stream(input, output); + // LCOV_EXCL_STOP +} + +/// Non-string parsable by a stream +template ::value == object_category::other && !std::is_assignable::value, + detail::enabler> = detail::dummy> +bool lexical_cast(const std::string &input, T &output) { + static_assert(is_istreamable::value, + "option object type must have a lexical cast overload or streaming input operator(>>) defined, if it " + "is convertible from another type use the add_option(...) with XC being the known type"); + return from_stream(input, output); +} + +/// Assign a value through lexical cast operations +/// Strings can be empty so we need to do a little different +template ::value && + (classify_object::value == object_category::string_assignable || + classify_object::value == object_category::string_constructible), + detail::enabler> = detail::dummy> +bool lexical_assign(const std::string &input, AssignTo &output) { + return lexical_cast(input, output); +} + +/// Assign a value through lexical cast operations +template ::value && std::is_assignable::value && + classify_object::value != object_category::string_assignable && + classify_object::value != object_category::string_constructible, + detail::enabler> = detail::dummy> +bool lexical_assign(const std::string &input, AssignTo &output) { + if(input.empty()) { + output = AssignTo{}; + return true; + } + + return lexical_cast(input, output); +} + +/// Assign a value through lexical cast operations +template ::value && !std::is_assignable::value && + classify_object::value == object_category::wrapper_value, + detail::enabler> = detail::dummy> +bool lexical_assign(const std::string &input, AssignTo &output) { + if(input.empty()) { + typename AssignTo::value_type emptyVal{}; + output = emptyVal; + return true; + } + return lexical_cast(input, output); +} + +/// Assign a value through lexical cast operations for int compatible values +/// mainly for atomic operations on some compilers +template ::value && !std::is_assignable::value && + classify_object::value != object_category::wrapper_value && + std::is_assignable::value, + detail::enabler> = detail::dummy> +bool lexical_assign(const std::string &input, AssignTo &output) { + if(input.empty()) { + output = 0; + return true; + } + int val; + if(lexical_cast(input, val)) { + output = val; + return true; + } + return false; +} + +/// Assign a value converted from a string in lexical cast to the output value directly +template ::value && std::is_assignable::value, + detail::enabler> = detail::dummy> +bool lexical_assign(const std::string &input, AssignTo &output) { + ConvertTo val{}; + bool parse_result = (!input.empty()) ? lexical_cast(input, val) : true; + if(parse_result) { + output = val; + } + return parse_result; +} + +/// Assign a value from a lexical cast through constructing a value and move assigning it +template < + typename AssignTo, + typename ConvertTo, + enable_if_t::value && !std::is_assignable::value && + std::is_move_assignable::value, + detail::enabler> = detail::dummy> +bool lexical_assign(const std::string &input, AssignTo &output) { + ConvertTo val{}; + bool parse_result = input.empty() ? true : lexical_cast(input, val); + if(parse_result) { + output = AssignTo(val); // use () form of constructor to allow some implicit conversions + } + return parse_result; +} + +/// primary lexical conversion operation, 1 string to 1 type of some kind +template ::value <= object_category::other && + classify_object::value <= object_category::wrapper_value, + detail::enabler> = detail::dummy> +bool lexical_conversion(const std::vector &strings, AssignTo &output) { + return lexical_assign(strings[0], output); +} + +/// Lexical conversion if there is only one element but the conversion type is for two, then call a two element +/// constructor +template ::value <= 2) && expected_count::value == 1 && + is_tuple_like::value && type_count_base::value == 2, + detail::enabler> = detail::dummy> +bool lexical_conversion(const std::vector &strings, AssignTo &output) { + // the remove const is to handle pair types coming from a container + typename std::remove_const::type>::type v1; + typename std::tuple_element<1, ConvertTo>::type v2; + bool retval = lexical_assign(strings[0], v1); + if(strings.size() > 1) { + retval = retval && lexical_assign(strings[1], v2); + } + if(retval) { + output = AssignTo{v1, v2}; + } + return retval; +} + +/// Lexical conversion of a container types of single elements +template ::value && is_mutable_container::value && + type_count::value == 1, + detail::enabler> = detail::dummy> +bool lexical_conversion(const std::vector &strings, AssignTo &output) { + output.erase(output.begin(), output.end()); + for(const auto &elem : strings) { + typename AssignTo::value_type out; + bool retval = lexical_assign(elem, out); + if(!retval) { + return false; + } + output.insert(output.end(), std::move(out)); + } + return (!output.empty()); +} + +/// Lexical conversion for complex types +template ::value, detail::enabler> = detail::dummy> +bool lexical_conversion(const std::vector &strings, AssignTo &output) { + + if(strings.size() >= 2 && !strings[1].empty()) { + using XC2 = typename wrapped_type::type; + XC2 x{0.0}, y{0.0}; + auto str1 = strings[1]; + if(str1.back() == 'i' || str1.back() == 'j') { + str1.pop_back(); + } + auto worked = detail::lexical_cast(strings[0], x) && detail::lexical_cast(str1, y); + if(worked) { + output = ConvertTo{x, y}; + } + return worked; + } else { + return lexical_assign(strings[0], output); + } +} + +/// Conversion to a vector type using a particular single type as the conversion type +template ::value && (expected_count::value == 1) && + (type_count::value == 1), + detail::enabler> = detail::dummy> +bool lexical_conversion(const std::vector &strings, AssignTo &output) { + bool retval = true; + output.clear(); + output.reserve(strings.size()); + for(const auto &elem : strings) { + + output.emplace_back(); + retval = retval && lexical_assign(elem, output.back()); + } + return (!output.empty()) && retval; +} + +// forward declaration + +/// Lexical conversion of a container types with conversion type of two elements +template ::value && is_mutable_container::value && + type_count_base::value == 2, + detail::enabler> = detail::dummy> +bool lexical_conversion(std::vector strings, AssignTo &output); + +/// Lexical conversion of a vector types with type_size >2 forward declaration +template ::value && is_mutable_container::value && + type_count_base::value != 2 && + ((type_count::value > 2) || + (type_count::value > type_count_base::value)), + detail::enabler> = detail::dummy> +bool lexical_conversion(const std::vector &strings, AssignTo &output); + +/// Conversion for tuples +template ::value && is_tuple_like::value && + (type_count_base::value != type_count::value || + type_count::value > 2), + detail::enabler> = detail::dummy> +bool lexical_conversion(const std::vector &strings, AssignTo &output); // forward declaration + +/// Conversion for operations where the assigned type is some class but the conversion is a mutable container or large +/// tuple +template ::value && !is_mutable_container::value && + classify_object::value != object_category::wrapper_value && + (is_mutable_container::value || type_count::value > 2), + detail::enabler> = detail::dummy> +bool lexical_conversion(const std::vector &strings, AssignTo &output) { + + if(strings.size() > 1 || (!strings.empty() && !(strings.front().empty()))) { + ConvertTo val; + auto retval = lexical_conversion(strings, val); + output = AssignTo{val}; + return retval; + } + output = AssignTo{}; + return true; +} + +/// function template for converting tuples if the static Index is greater than the tuple size +template +inline typename std::enable_if<(I >= type_count_base::value), bool>::type +tuple_conversion(const std::vector &, AssignTo &) { + return true; +} + +/// Conversion of a tuple element where the type size ==1 and not a mutable container +template +inline typename std::enable_if::value && type_count::value == 1, bool>::type +tuple_type_conversion(std::vector &strings, AssignTo &output) { + auto retval = lexical_assign(strings[0], output); + strings.erase(strings.begin()); + return retval; +} + +/// Conversion of a tuple element where the type size !=1 but the size is fixed and not a mutable container +template +inline typename std::enable_if::value && (type_count::value > 1) && + type_count::value == type_count_min::value, + bool>::type +tuple_type_conversion(std::vector &strings, AssignTo &output) { + auto retval = lexical_conversion(strings, output); + strings.erase(strings.begin(), strings.begin() + type_count::value); + return retval; +} + +/// Conversion of a tuple element where the type is a mutable container or a type with different min and max type sizes +template +inline typename std::enable_if::value || + type_count::value != type_count_min::value, + bool>::type +tuple_type_conversion(std::vector &strings, AssignTo &output) { + + std::size_t index{subtype_count_min::value}; + const std::size_t mx_count{subtype_count::value}; + const std::size_t mx{(std::max)(mx_count, strings.size())}; + + while(index < mx) { + if(is_separator(strings[index])) { + break; + } + ++index; + } + bool retval = lexical_conversion( + std::vector(strings.begin(), strings.begin() + static_cast(index)), output); + strings.erase(strings.begin(), strings.begin() + static_cast(index) + 1); + return retval; +} + +/// Tuple conversion operation +template +inline typename std::enable_if<(I < type_count_base::value), bool>::type +tuple_conversion(std::vector strings, AssignTo &output) { + bool retval = true; + using ConvertToElement = typename std:: + conditional::value, typename std::tuple_element::type, ConvertTo>::type; + if(!strings.empty()) { + retval = retval && tuple_type_conversion::type, ConvertToElement>( + strings, std::get(output)); + } + retval = retval && tuple_conversion(std::move(strings), output); + return retval; +} + +/// Lexical conversion of a container types with tuple elements of size 2 +template ::value && is_mutable_container::value && + type_count_base::value == 2, + detail::enabler>> +bool lexical_conversion(std::vector strings, AssignTo &output) { + output.clear(); + while(!strings.empty()) { + + typename std::remove_const::type>::type v1; + typename std::tuple_element<1, typename ConvertTo::value_type>::type v2; + bool retval = tuple_type_conversion(strings, v1); + if(!strings.empty()) { + retval = retval && tuple_type_conversion(strings, v2); + } + if(retval) { + output.insert(output.end(), typename AssignTo::value_type{v1, v2}); + } else { + return false; + } + } + return (!output.empty()); +} + +/// lexical conversion of tuples with type count>2 or tuples of types of some element with a type size>=2 +template ::value && is_tuple_like::value && + (type_count_base::value != type_count::value || + type_count::value > 2), + detail::enabler>> +bool lexical_conversion(const std::vector &strings, AssignTo &output) { + static_assert( + !is_tuple_like::value || type_count_base::value == type_count_base::value, + "if the conversion type is defined as a tuple it must be the same size as the type you are converting to"); + return tuple_conversion(strings, output); +} + +/// Lexical conversion of a vector types for everything but tuples of two elements and types of size 1 +template ::value && is_mutable_container::value && + type_count_base::value != 2 && + ((type_count::value > 2) || + (type_count::value > type_count_base::value)), + detail::enabler>> +bool lexical_conversion(const std::vector &strings, AssignTo &output) { + bool retval = true; + output.clear(); + std::vector temp; + std::size_t ii{0}; + std::size_t icount{0}; + std::size_t xcm{type_count::value}; + auto ii_max = strings.size(); + while(ii < ii_max) { + temp.push_back(strings[ii]); + ++ii; + ++icount; + if(icount == xcm || is_separator(temp.back()) || ii == ii_max) { + if(static_cast(xcm) > type_count_min::value && is_separator(temp.back())) { + temp.pop_back(); + } + typename AssignTo::value_type temp_out; + retval = retval && + lexical_conversion(temp, temp_out); + temp.clear(); + if(!retval) { + return false; + } + output.insert(output.end(), std::move(temp_out)); + icount = 0; + } + } + return retval; +} + +/// conversion for wrapper types +template ::value == object_category::wrapper_value && + std::is_assignable::value, + detail::enabler> = detail::dummy> +bool lexical_conversion(const std::vector &strings, AssignTo &output) { + if(strings.empty() || strings.front().empty()) { + output = ConvertTo{}; + return true; + } + typename ConvertTo::value_type val; + if(lexical_conversion(strings, val)) { + output = ConvertTo{val}; + return true; + } + return false; +} + +/// conversion for wrapper types +template ::value == object_category::wrapper_value && + !std::is_assignable::value, + detail::enabler> = detail::dummy> +bool lexical_conversion(const std::vector &strings, AssignTo &output) { + using ConvertType = typename ConvertTo::value_type; + if(strings.empty() || strings.front().empty()) { + output = ConvertType{}; + return true; + } + ConvertType val; + if(lexical_conversion(strings, val)) { + output = val; + return true; + } + return false; +} + +/// Sum a vector of flag representations +/// The flag vector produces a series of strings in a vector, simple true is represented by a "1", simple false is +/// by +/// "-1" an if numbers are passed by some fashion they are captured as well so the function just checks for the most +/// common true and false strings then uses stoll to convert the rest for summing +template ::value, detail::enabler> = detail::dummy> +void sum_flag_vector(const std::vector &flags, T &output) { + std::int64_t count{0}; + for(auto &flag : flags) { + count += detail::to_flag_value(flag); + } + output = (count > 0) ? static_cast(count) : T{0}; +} + +/// Sum a vector of flag representations +/// The flag vector produces a series of strings in a vector, simple true is represented by a "1", simple false is +/// by +/// "-1" an if numbers are passed by some fashion they are captured as well so the function just checks for the most +/// common true and false strings then uses stoll to convert the rest for summing +template ::value, detail::enabler> = detail::dummy> +void sum_flag_vector(const std::vector &flags, T &output) { + std::int64_t count{0}; + for(auto &flag : flags) { + count += detail::to_flag_value(flag); + } + output = static_cast(count); +} + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4800) +#endif +// with Atomic this could produce a warning due to the conversion but if atomic gets here it is an old style so will +// most likely still work + +/// Sum a vector of flag representations +/// The flag vector produces a series of strings in a vector, simple true is represented by a "1", simple false is +/// by +/// "-1" an if numbers are passed by some fashion they are captured as well so the function just checks for the most +/// common true and false strings then uses stoll to convert the rest for summing +template ::value && !std::is_unsigned::value, detail::enabler> = detail::dummy> +void sum_flag_vector(const std::vector &flags, T &output) { + std::int64_t count{0}; + for(auto &flag : flags) { + count += detail::to_flag_value(flag); + } + std::string out = detail::to_string(count); + lexical_cast(out, output); +} + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +} // namespace detail + + + +namespace detail { + +// Returns false if not a short option. Otherwise, sets opt name and rest and returns true +inline bool split_short(const std::string ¤t, std::string &name, std::string &rest) { + if(current.size() > 1 && current[0] == '-' && valid_first_char(current[1])) { + name = current.substr(1, 1); + rest = current.substr(2); + return true; + } + return false; +} + +// Returns false if not a long option. Otherwise, sets opt name and other side of = and returns true +inline bool split_long(const std::string ¤t, std::string &name, std::string &value) { + if(current.size() > 2 && current.substr(0, 2) == "--" && valid_first_char(current[2])) { + auto loc = current.find_first_of('='); + if(loc != std::string::npos) { + name = current.substr(2, loc - 2); + value = current.substr(loc + 1); + } else { + name = current.substr(2); + value = ""; + } + return true; + } + return false; +} + +// Returns false if not a windows style option. Otherwise, sets opt name and value and returns true +inline bool split_windows_style(const std::string ¤t, std::string &name, std::string &value) { + if(current.size() > 1 && current[0] == '/' && valid_first_char(current[1])) { + auto loc = current.find_first_of(':'); + if(loc != std::string::npos) { + name = current.substr(1, loc - 1); + value = current.substr(loc + 1); + } else { + name = current.substr(1); + value = ""; + } + return true; + } + return false; +} + +// Splits a string into multiple long and short names +inline std::vector split_names(std::string current) { + std::vector output; + std::size_t val; + while((val = current.find(",")) != std::string::npos) { + output.push_back(trim_copy(current.substr(0, val))); + current = current.substr(val + 1); + } + output.push_back(trim_copy(current)); + return output; +} + +/// extract default flag values either {def} or starting with a ! +inline std::vector> get_default_flag_values(const std::string &str) { + std::vector flags = split_names(str); + flags.erase(std::remove_if(flags.begin(), + flags.end(), + [](const std::string &name) { + return ((name.empty()) || (!(((name.find_first_of('{') != std::string::npos) && + (name.back() == '}')) || + (name[0] == '!')))); + }), + flags.end()); + std::vector> output; + output.reserve(flags.size()); + for(auto &flag : flags) { + auto def_start = flag.find_first_of('{'); + std::string defval = "false"; + if((def_start != std::string::npos) && (flag.back() == '}')) { + defval = flag.substr(def_start + 1); + defval.pop_back(); + flag.erase(def_start, std::string::npos); + } + flag.erase(0, flag.find_first_not_of("-!")); + output.emplace_back(flag, defval); + } + return output; +} + +/// Get a vector of short names, one of long names, and a single name +inline std::tuple, std::vector, std::string> +get_names(const std::vector &input) { + + std::vector short_names; + std::vector long_names; + std::string pos_name; + + for(std::string name : input) { + if(name.length() == 0) { + continue; + } + if(name.length() > 1 && name[0] == '-' && name[1] != '-') { + if(name.length() == 2 && valid_first_char(name[1])) + short_names.emplace_back(1, name[1]); + else + throw BadNameString::OneCharName(name); + } else if(name.length() > 2 && name.substr(0, 2) == "--") { + name = name.substr(2); + if(valid_name_string(name)) + long_names.push_back(name); + else + throw BadNameString::BadLongName(name); + } else if(name == "-" || name == "--") { + throw BadNameString::DashesOnly(name); + } else { + if(pos_name.length() > 0) + throw BadNameString::MultiPositionalNames(name); + pos_name = name; + } + } + + return std::tuple, std::vector, std::string>( + short_names, long_names, pos_name); +} + +} // namespace detail + + + +class App; + +/// Holds values to load into Options +struct ConfigItem { + /// This is the list of parents + std::vector parents{}; + + /// This is the name + std::string name{}; + + /// Listing of inputs + std::vector inputs{}; + + /// The list of parents and name joined by "." + std::string fullname() const { + std::vector tmp = parents; + tmp.emplace_back(name); + return detail::join(tmp, "."); + } +}; + +/// This class provides a converter for configuration files. +class Config { + protected: + std::vector items{}; + + public: + /// Convert an app into a configuration + virtual std::string to_config(const App *, bool, bool, std::string) const = 0; + + /// Convert a configuration into an app + virtual std::vector from_config(std::istream &) const = 0; + + /// Get a flag value + virtual std::string to_flag(const ConfigItem &item) const { + if(item.inputs.size() == 1) { + return item.inputs.at(0); + } + throw ConversionError::TooManyInputsFlag(item.fullname()); + } + + /// Parse a config file, throw an error (ParseError:ConfigParseError or FileError) on failure + std::vector from_file(const std::string &name) { + std::ifstream input{name}; + if(!input.good()) + throw FileError::Missing(name); + + return from_config(input); + } + + /// Virtual destructor + virtual ~Config() = default; +}; + +/// This converter works with INI/TOML files; to write INI files use ConfigINI +class ConfigBase : public Config { + protected: + /// the character used for comments + char commentChar = '#'; + /// the character used to start an array '\0' is a default to not use + char arrayStart = '['; + /// the character used to end an array '\0' is a default to not use + char arrayEnd = ']'; + /// the character used to separate elements in an array + char arraySeparator = ','; + /// the character used separate the name from the value + char valueDelimiter = '='; + /// the character to use around strings + char stringQuote = '"'; + /// the character to use around single characters + char characterQuote = '\''; + /// the maximum number of layers to allow + uint8_t maximumLayers{255}; + /// the separator used to separator parent layers + char parentSeparatorChar{'.'}; + /// Specify the configuration index to use for arrayed sections + int16_t configIndex{-1}; + /// Specify the configuration section that should be used + std::string configSection{}; + + public: + std::string + to_config(const App * /*app*/, bool default_also, bool write_description, std::string prefix) const override; + + std::vector from_config(std::istream &input) const override; + /// Specify the configuration for comment characters + ConfigBase *comment(char cchar) { + commentChar = cchar; + return this; + } + /// Specify the start and end characters for an array + ConfigBase *arrayBounds(char aStart, char aEnd) { + arrayStart = aStart; + arrayEnd = aEnd; + return this; + } + /// Specify the delimiter character for an array + ConfigBase *arrayDelimiter(char aSep) { + arraySeparator = aSep; + return this; + } + /// Specify the delimiter between a name and value + ConfigBase *valueSeparator(char vSep) { + valueDelimiter = vSep; + return this; + } + /// Specify the quote characters used around strings and characters + ConfigBase *quoteCharacter(char qString, char qChar) { + stringQuote = qString; + characterQuote = qChar; + return this; + } + /// Specify the maximum number of parents + ConfigBase *maxLayers(uint8_t layers) { + maximumLayers = layers; + return this; + } + /// Specify the separator to use for parent layers + ConfigBase *parentSeparator(char sep) { + parentSeparatorChar = sep; + return this; + } + /// get a reference to the configuration section + std::string §ionRef() { return configSection; } + /// get the section + const std::string §ion() const { return configSection; } + /// specify a particular section of the configuration file to use + ConfigBase *section(const std::string §ionName) { + configSection = sectionName; + return this; + } + + /// get a reference to the configuration index + int16_t &indexRef() { return configIndex; } + /// get the section index + int16_t index() const { return configIndex; } + /// specify a particular index in the section to use (-1) for all sections to use + ConfigBase *index(int16_t sectionIndex) { + configIndex = sectionIndex; + return this; + } +}; + +/// the default Config is the TOML file format +using ConfigTOML = ConfigBase; + +/// ConfigINI generates a "standard" INI compliant output +class ConfigINI : public ConfigTOML { + + public: + ConfigINI() { + commentChar = ';'; + arrayStart = '\0'; + arrayEnd = '\0'; + arraySeparator = ' '; + valueDelimiter = '='; + } +}; + + + +class Option; + +/// @defgroup validator_group Validators + +/// @brief Some validators that are provided +/// +/// These are simple `std::string(const std::string&)` validators that are useful. They return +/// a string if the validation fails. A custom struct is provided, as well, with the same user +/// semantics, but with the ability to provide a new type name. +/// @{ + +/// +class Validator { + protected: + /// This is the description function, if empty the description_ will be used + std::function desc_function_{[]() { return std::string{}; }}; + + /// This is the base function that is to be called. + /// Returns a string error message if validation fails. + std::function func_{[](std::string &) { return std::string{}; }}; + /// The name for search purposes of the Validator + std::string name_{}; + /// A Validator will only apply to an indexed value (-1 is all elements) + int application_index_ = -1; + /// Enable for Validator to allow it to be disabled if need be + bool active_{true}; + /// specify that a validator should not modify the input + bool non_modifying_{false}; + + public: + Validator() = default; + /// Construct a Validator with just the description string + explicit Validator(std::string validator_desc) : desc_function_([validator_desc]() { return validator_desc; }) {} + /// Construct Validator from basic information + Validator(std::function op, std::string validator_desc, std::string validator_name = "") + : desc_function_([validator_desc]() { return validator_desc; }), func_(std::move(op)), + name_(std::move(validator_name)) {} + /// Set the Validator operation function + Validator &operation(std::function op) { + func_ = std::move(op); + return *this; + } + /// This is the required operator for a Validator - provided to help + /// users (CLI11 uses the member `func` directly) + std::string operator()(std::string &str) const { + std::string retstring; + if(active_) { + if(non_modifying_) { + std::string value = str; + retstring = func_(value); + } else { + retstring = func_(str); + } + } + return retstring; + } + + /// This is the required operator for a Validator - provided to help + /// users (CLI11 uses the member `func` directly) + std::string operator()(const std::string &str) const { + std::string value = str; + return (active_) ? func_(value) : std::string{}; + } + + /// Specify the type string + Validator &description(std::string validator_desc) { + desc_function_ = [validator_desc]() { return validator_desc; }; + return *this; + } + /// Specify the type string + Validator description(std::string validator_desc) const { + Validator newval(*this); + newval.desc_function_ = [validator_desc]() { return validator_desc; }; + return newval; + } + /// Generate type description information for the Validator + std::string get_description() const { + if(active_) { + return desc_function_(); + } + return std::string{}; + } + /// Specify the type string + Validator &name(std::string validator_name) { + name_ = std::move(validator_name); + return *this; + } + /// Specify the type string + Validator name(std::string validator_name) const { + Validator newval(*this); + newval.name_ = std::move(validator_name); + return newval; + } + /// Get the name of the Validator + const std::string &get_name() const { return name_; } + /// Specify whether the Validator is active or not + Validator &active(bool active_val = true) { + active_ = active_val; + return *this; + } + /// Specify whether the Validator is active or not + Validator active(bool active_val = true) const { + Validator newval(*this); + newval.active_ = active_val; + return newval; + } + + /// Specify whether the Validator can be modifying or not + Validator &non_modifying(bool no_modify = true) { + non_modifying_ = no_modify; + return *this; + } + /// Specify the application index of a validator + Validator &application_index(int app_index) { + application_index_ = app_index; + return *this; + } + /// Specify the application index of a validator + Validator application_index(int app_index) const { + Validator newval(*this); + newval.application_index_ = app_index; + return newval; + } + /// Get the current value of the application index + int get_application_index() const { return application_index_; } + /// Get a boolean if the validator is active + bool get_active() const { return active_; } + + /// Get a boolean if the validator is allowed to modify the input returns true if it can modify the input + bool get_modifying() const { return !non_modifying_; } + + /// Combining validators is a new validator. Type comes from left validator if function, otherwise only set if the + /// same. + Validator operator&(const Validator &other) const { + Validator newval; + + newval._merge_description(*this, other, " AND "); + + // Give references (will make a copy in lambda function) + const std::function &f1 = func_; + const std::function &f2 = other.func_; + + newval.func_ = [f1, f2](std::string &input) { + std::string s1 = f1(input); + std::string s2 = f2(input); + if(!s1.empty() && !s2.empty()) + return std::string("(") + s1 + ") AND (" + s2 + ")"; + else + return s1 + s2; + }; + + newval.active_ = (active_ & other.active_); + newval.application_index_ = application_index_; + return newval; + } + + /// Combining validators is a new validator. Type comes from left validator if function, otherwise only set if the + /// same. + Validator operator|(const Validator &other) const { + Validator newval; + + newval._merge_description(*this, other, " OR "); + + // Give references (will make a copy in lambda function) + const std::function &f1 = func_; + const std::function &f2 = other.func_; + + newval.func_ = [f1, f2](std::string &input) { + std::string s1 = f1(input); + std::string s2 = f2(input); + if(s1.empty() || s2.empty()) + return std::string(); + + return std::string("(") + s1 + ") OR (" + s2 + ")"; + }; + newval.active_ = (active_ & other.active_); + newval.application_index_ = application_index_; + return newval; + } + + /// Create a validator that fails when a given validator succeeds + Validator operator!() const { + Validator newval; + const std::function &dfunc1 = desc_function_; + newval.desc_function_ = [dfunc1]() { + auto str = dfunc1(); + return (!str.empty()) ? std::string("NOT ") + str : std::string{}; + }; + // Give references (will make a copy in lambda function) + const std::function &f1 = func_; + + newval.func_ = [f1, dfunc1](std::string &test) -> std::string { + std::string s1 = f1(test); + if(s1.empty()) { + return std::string("check ") + dfunc1() + " succeeded improperly"; + } + return std::string{}; + }; + newval.active_ = active_; + newval.application_index_ = application_index_; + return newval; + } + + private: + void _merge_description(const Validator &val1, const Validator &val2, const std::string &merger) { + + const std::function &dfunc1 = val1.desc_function_; + const std::function &dfunc2 = val2.desc_function_; + + desc_function_ = [=]() { + std::string f1 = dfunc1(); + std::string f2 = dfunc2(); + if((f1.empty()) || (f2.empty())) { + return f1 + f2; + } + return std::string(1, '(') + f1 + ')' + merger + '(' + f2 + ')'; + }; + } +}; // namespace CLI + +/// Class wrapping some of the accessors of Validator +class CustomValidator : public Validator { + public: +}; +// The implementation of the built in validators is using the Validator class; +// the user is only expected to use the const (static) versions (since there's no setup). +// Therefore, this is in detail. +namespace detail { + +/// CLI enumeration of different file types +enum class path_type { nonexistent, file, directory }; + +#if defined CLI11_HAS_FILESYSTEM && CLI11_HAS_FILESYSTEM > 0 +/// get the type of the path from a file name +inline path_type check_path(const char *file) noexcept { + std::error_code ec; + auto stat = std::filesystem::status(file, ec); + if(ec) { + return path_type::nonexistent; + } + switch(stat.type()) { + case std::filesystem::file_type::none: + case std::filesystem::file_type::not_found: + return path_type::nonexistent; + case std::filesystem::file_type::directory: + return path_type::directory; + case std::filesystem::file_type::symlink: + case std::filesystem::file_type::block: + case std::filesystem::file_type::character: + case std::filesystem::file_type::fifo: + case std::filesystem::file_type::socket: + case std::filesystem::file_type::regular: + case std::filesystem::file_type::unknown: + default: + return path_type::file; + } +} +#else +/// get the type of the path from a file name +inline path_type check_path(const char *file) noexcept { +#if defined(_MSC_VER) + struct __stat64 buffer; + if(_stat64(file, &buffer) == 0) { + return ((buffer.st_mode & S_IFDIR) != 0) ? path_type::directory : path_type::file; + } +#else + struct stat buffer; + if(stat(file, &buffer) == 0) { + return ((buffer.st_mode & S_IFDIR) != 0) ? path_type::directory : path_type::file; + } +#endif + return path_type::nonexistent; +} +#endif +/// Check for an existing file (returns error message if check fails) +class ExistingFileValidator : public Validator { + public: + ExistingFileValidator() : Validator("FILE") { + func_ = [](std::string &filename) { + auto path_result = check_path(filename.c_str()); + if(path_result == path_type::nonexistent) { + return "File does not exist: " + filename; + } + if(path_result == path_type::directory) { + return "File is actually a directory: " + filename; + } + return std::string(); + }; + } +}; + +/// Check for an existing directory (returns error message if check fails) +class ExistingDirectoryValidator : public Validator { + public: + ExistingDirectoryValidator() : Validator("DIR") { + func_ = [](std::string &filename) { + auto path_result = check_path(filename.c_str()); + if(path_result == path_type::nonexistent) { + return "Directory does not exist: " + filename; + } + if(path_result == path_type::file) { + return "Directory is actually a file: " + filename; + } + return std::string(); + }; + } +}; + +/// Check for an existing path +class ExistingPathValidator : public Validator { + public: + ExistingPathValidator() : Validator("PATH(existing)") { + func_ = [](std::string &filename) { + auto path_result = check_path(filename.c_str()); + if(path_result == path_type::nonexistent) { + return "Path does not exist: " + filename; + } + return std::string(); + }; + } +}; + +/// Check for an non-existing path +class NonexistentPathValidator : public Validator { + public: + NonexistentPathValidator() : Validator("PATH(non-existing)") { + func_ = [](std::string &filename) { + auto path_result = check_path(filename.c_str()); + if(path_result != path_type::nonexistent) { + return "Path already exists: " + filename; + } + return std::string(); + }; + } +}; + +/// Validate the given string is a legal ipv4 address +class IPV4Validator : public Validator { + public: + IPV4Validator() : Validator("IPV4") { + func_ = [](std::string &ip_addr) { + auto result = CLI::detail::split(ip_addr, '.'); + if(result.size() != 4) { + return std::string("Invalid IPV4 address must have four parts (") + ip_addr + ')'; + } + int num; + for(const auto &var : result) { + bool retval = detail::lexical_cast(var, num); + if(!retval) { + return std::string("Failed parsing number (") + var + ')'; + } + if(num < 0 || num > 255) { + return std::string("Each IP number must be between 0 and 255 ") + var; + } + } + return std::string(); + }; + } +}; + +} // namespace detail + +// Static is not needed here, because global const implies static. + +/// Check for existing file (returns error message if check fails) +const detail::ExistingFileValidator ExistingFile; + +/// Check for an existing directory (returns error message if check fails) +const detail::ExistingDirectoryValidator ExistingDirectory; + +/// Check for an existing path +const detail::ExistingPathValidator ExistingPath; + +/// Check for an non-existing path +const detail::NonexistentPathValidator NonexistentPath; + +/// Check for an IP4 address +const detail::IPV4Validator ValidIPV4; + +/// Validate the input as a particular type +template class TypeValidator : public Validator { + public: + explicit TypeValidator(const std::string &validator_name) : Validator(validator_name) { + func_ = [](std::string &input_string) { + auto val = DesiredType(); + if(!detail::lexical_cast(input_string, val)) { + return std::string("Failed parsing ") + input_string + " as a " + detail::type_name(); + } + return std::string(); + }; + } + TypeValidator() : TypeValidator(detail::type_name()) {} +}; + +/// Check for a number +const TypeValidator Number("NUMBER"); + +/// Produce a range (factory). Min and max are inclusive. +class Range : public Validator { + public: + /// This produces a range with min and max inclusive. + /// + /// Note that the constructor is templated, but the struct is not, so C++17 is not + /// needed to provide nice syntax for Range(a,b). + template + Range(T min_val, T max_val, const std::string &validator_name = std::string{}) : Validator(validator_name) { + if(validator_name.empty()) { + std::stringstream out; + out << detail::type_name() << " in [" << min_val << " - " << max_val << "]"; + description(out.str()); + } + + func_ = [min_val, max_val](std::string &input) { + T val; + bool converted = detail::lexical_cast(input, val); + if((!converted) || (val < min_val || val > max_val)) + return std::string("Value ") + input + " not in range " + std::to_string(min_val) + " to " + + std::to_string(max_val); + + return std::string{}; + }; + } + + /// Range of one value is 0 to value + template + explicit Range(T max_val, const std::string &validator_name = std::string{}) + : Range(static_cast(0), max_val, validator_name) {} +}; + +/// Check for a non negative number +const Range NonNegativeNumber((std::numeric_limits::max)(), "NONNEGATIVE"); + +/// Check for a positive valued number (val>0.0), min() her is the smallest positive number +const Range PositiveNumber((std::numeric_limits::min)(), (std::numeric_limits::max)(), "POSITIVE"); + +/// Produce a bounded range (factory). Min and max are inclusive. +class Bound : public Validator { + public: + /// This bounds a value with min and max inclusive. + /// + /// Note that the constructor is templated, but the struct is not, so C++17 is not + /// needed to provide nice syntax for Range(a,b). + template Bound(T min_val, T max_val) { + std::stringstream out; + out << detail::type_name() << " bounded to [" << min_val << " - " << max_val << "]"; + description(out.str()); + + func_ = [min_val, max_val](std::string &input) { + T val; + bool converted = detail::lexical_cast(input, val); + if(!converted) { + return std::string("Value ") + input + " could not be converted"; + } + if(val < min_val) + input = detail::to_string(min_val); + else if(val > max_val) + input = detail::to_string(max_val); + + return std::string{}; + }; + } + + /// Range of one value is 0 to value + template explicit Bound(T max_val) : Bound(static_cast(0), max_val) {} +}; + +namespace detail { +template ::type>::value, detail::enabler> = detail::dummy> +auto smart_deref(T value) -> decltype(*value) { + return *value; +} + +template < + typename T, + enable_if_t::type>::value, detail::enabler> = detail::dummy> +typename std::remove_reference::type &smart_deref(T &value) { + return value; +} +/// Generate a string representation of a set +template std::string generate_set(const T &set) { + using element_t = typename detail::element_type::type; + using iteration_type_t = typename detail::pair_adaptor::value_type; // the type of the object pair + std::string out(1, '{'); + out.append(detail::join( + detail::smart_deref(set), + [](const iteration_type_t &v) { return detail::pair_adaptor::first(v); }, + ",")); + out.push_back('}'); + return out; +} + +/// Generate a string representation of a map +template std::string generate_map(const T &map, bool key_only = false) { + using element_t = typename detail::element_type::type; + using iteration_type_t = typename detail::pair_adaptor::value_type; // the type of the object pair + std::string out(1, '{'); + out.append(detail::join( + detail::smart_deref(map), + [key_only](const iteration_type_t &v) { + std::string res{detail::to_string(detail::pair_adaptor::first(v))}; + + if(!key_only) { + res.append("->"); + res += detail::to_string(detail::pair_adaptor::second(v)); + } + return res; + }, + ",")); + out.push_back('}'); + return out; +} + +template struct has_find { + template + static auto test(int) -> decltype(std::declval().find(std::declval()), std::true_type()); + template static auto test(...) -> decltype(std::false_type()); + + static const auto value = decltype(test(0))::value; + using type = std::integral_constant; +}; + +/// A search function +template ::value, detail::enabler> = detail::dummy> +auto search(const T &set, const V &val) -> std::pair { + using element_t = typename detail::element_type::type; + auto &setref = detail::smart_deref(set); + auto it = std::find_if(std::begin(setref), std::end(setref), [&val](decltype(*std::begin(setref)) v) { + return (detail::pair_adaptor::first(v) == val); + }); + return {(it != std::end(setref)), it}; +} + +/// A search function that uses the built in find function +template ::value, detail::enabler> = detail::dummy> +auto search(const T &set, const V &val) -> std::pair { + auto &setref = detail::smart_deref(set); + auto it = setref.find(val); + return {(it != std::end(setref)), it}; +} + +/// A search function with a filter function +template +auto search(const T &set, const V &val, const std::function &filter_function) + -> std::pair { + using element_t = typename detail::element_type::type; + // do the potentially faster first search + auto res = search(set, val); + if((res.first) || (!(filter_function))) { + return res; + } + // if we haven't found it do the longer linear search with all the element translations + auto &setref = detail::smart_deref(set); + auto it = std::find_if(std::begin(setref), std::end(setref), [&](decltype(*std::begin(setref)) v) { + V a{detail::pair_adaptor::first(v)}; + a = filter_function(a); + return (a == val); + }); + return {(it != std::end(setref)), it}; +} + +// the following suggestion was made by Nikita Ofitserov(@himikof) +// done in templates to prevent compiler warnings on negation of unsigned numbers + +/// Do a check for overflow on signed numbers +template +inline typename std::enable_if::value, T>::type overflowCheck(const T &a, const T &b) { + if((a > 0) == (b > 0)) { + return ((std::numeric_limits::max)() / (std::abs)(a) < (std::abs)(b)); + } else { + return ((std::numeric_limits::min)() / (std::abs)(a) > -(std::abs)(b)); + } +} +/// Do a check for overflow on unsigned numbers +template +inline typename std::enable_if::value, T>::type overflowCheck(const T &a, const T &b) { + return ((std::numeric_limits::max)() / a < b); +} + +/// Performs a *= b; if it doesn't cause integer overflow. Returns false otherwise. +template typename std::enable_if::value, bool>::type checked_multiply(T &a, T b) { + if(a == 0 || b == 0 || a == 1 || b == 1) { + a *= b; + return true; + } + if(a == (std::numeric_limits::min)() || b == (std::numeric_limits::min)()) { + return false; + } + if(overflowCheck(a, b)) { + return false; + } + a *= b; + return true; +} + +/// Performs a *= b; if it doesn't equal infinity. Returns false otherwise. +template +typename std::enable_if::value, bool>::type checked_multiply(T &a, T b) { + T c = a * b; + if(std::isinf(c) && !std::isinf(a) && !std::isinf(b)) { + return false; + } + a = c; + return true; +} + +} // namespace detail +/// Verify items are in a set +class IsMember : public Validator { + public: + using filter_fn_t = std::function; + + /// This allows in-place construction using an initializer list + template + IsMember(std::initializer_list values, Args &&...args) + : IsMember(std::vector(values), std::forward(args)...) {} + + /// This checks to see if an item is in a set (empty function) + template explicit IsMember(T &&set) : IsMember(std::forward(set), nullptr) {} + + /// This checks to see if an item is in a set: pointer or copy version. You can pass in a function that will filter + /// both sides of the comparison before computing the comparison. + template explicit IsMember(T set, F filter_function) { + + // Get the type of the contained item - requires a container have ::value_type + // if the type does not have first_type and second_type, these are both value_type + using element_t = typename detail::element_type::type; // Removes (smart) pointers if needed + using item_t = typename detail::pair_adaptor::first_type; // Is value_type if not a map + + using local_item_t = typename IsMemberType::type; // This will convert bad types to good ones + // (const char * to std::string) + + // Make a local copy of the filter function, using a std::function if not one already + std::function filter_fn = filter_function; + + // This is the type name for help, it will take the current version of the set contents + desc_function_ = [set]() { return detail::generate_set(detail::smart_deref(set)); }; + + // This is the function that validates + // It stores a copy of the set pointer-like, so shared_ptr will stay alive + func_ = [set, filter_fn](std::string &input) { + local_item_t b; + if(!detail::lexical_cast(input, b)) { + throw ValidationError(input); // name is added later + } + if(filter_fn) { + b = filter_fn(b); + } + auto res = detail::search(set, b, filter_fn); + if(res.first) { + // Make sure the version in the input string is identical to the one in the set + if(filter_fn) { + input = detail::value_string(detail::pair_adaptor::first(*(res.second))); + } + + // Return empty error string (success) + return std::string{}; + } + + // If you reach this point, the result was not found + return input + " not in " + detail::generate_set(detail::smart_deref(set)); + }; + } + + /// You can pass in as many filter functions as you like, they nest (string only currently) + template + IsMember(T &&set, filter_fn_t filter_fn_1, filter_fn_t filter_fn_2, Args &&...other) + : IsMember( + std::forward(set), + [filter_fn_1, filter_fn_2](std::string a) { return filter_fn_2(filter_fn_1(a)); }, + other...) {} +}; + +/// definition of the default transformation object +template using TransformPairs = std::vector>; + +/// Translate named items to other or a value set +class Transformer : public Validator { + public: + using filter_fn_t = std::function; + + /// This allows in-place construction + template + Transformer(std::initializer_list> values, Args &&...args) + : Transformer(TransformPairs(values), std::forward(args)...) {} + + /// direct map of std::string to std::string + template explicit Transformer(T &&mapping) : Transformer(std::forward(mapping), nullptr) {} + + /// This checks to see if an item is in a set: pointer or copy version. You can pass in a function that will filter + /// both sides of the comparison before computing the comparison. + template explicit Transformer(T mapping, F filter_function) { + + static_assert(detail::pair_adaptor::type>::value, + "mapping must produce value pairs"); + // Get the type of the contained item - requires a container have ::value_type + // if the type does not have first_type and second_type, these are both value_type + using element_t = typename detail::element_type::type; // Removes (smart) pointers if needed + using item_t = typename detail::pair_adaptor::first_type; // Is value_type if not a map + using local_item_t = typename IsMemberType::type; // Will convert bad types to good ones + // (const char * to std::string) + + // Make a local copy of the filter function, using a std::function if not one already + std::function filter_fn = filter_function; + + // This is the type name for help, it will take the current version of the set contents + desc_function_ = [mapping]() { return detail::generate_map(detail::smart_deref(mapping)); }; + + func_ = [mapping, filter_fn](std::string &input) { + local_item_t b; + if(!detail::lexical_cast(input, b)) { + return std::string(); + // there is no possible way we can match anything in the mapping if we can't convert so just return + } + if(filter_fn) { + b = filter_fn(b); + } + auto res = detail::search(mapping, b, filter_fn); + if(res.first) { + input = detail::value_string(detail::pair_adaptor::second(*res.second)); + } + return std::string{}; + }; + } + + /// You can pass in as many filter functions as you like, they nest + template + Transformer(T &&mapping, filter_fn_t filter_fn_1, filter_fn_t filter_fn_2, Args &&...other) + : Transformer( + std::forward(mapping), + [filter_fn_1, filter_fn_2](std::string a) { return filter_fn_2(filter_fn_1(a)); }, + other...) {} +}; + +/// translate named items to other or a value set +class CheckedTransformer : public Validator { + public: + using filter_fn_t = std::function; + + /// This allows in-place construction + template + CheckedTransformer(std::initializer_list> values, Args &&...args) + : CheckedTransformer(TransformPairs(values), std::forward(args)...) {} + + /// direct map of std::string to std::string + template explicit CheckedTransformer(T mapping) : CheckedTransformer(std::move(mapping), nullptr) {} + + /// This checks to see if an item is in a set: pointer or copy version. You can pass in a function that will filter + /// both sides of the comparison before computing the comparison. + template explicit CheckedTransformer(T mapping, F filter_function) { + + static_assert(detail::pair_adaptor::type>::value, + "mapping must produce value pairs"); + // Get the type of the contained item - requires a container have ::value_type + // if the type does not have first_type and second_type, these are both value_type + using element_t = typename detail::element_type::type; // Removes (smart) pointers if needed + using item_t = typename detail::pair_adaptor::first_type; // Is value_type if not a map + using local_item_t = typename IsMemberType::type; // Will convert bad types to good ones + // (const char * to std::string) + using iteration_type_t = typename detail::pair_adaptor::value_type; // the type of the object pair + + // Make a local copy of the filter function, using a std::function if not one already + std::function filter_fn = filter_function; + + auto tfunc = [mapping]() { + std::string out("value in "); + out += detail::generate_map(detail::smart_deref(mapping)) + " OR {"; + out += detail::join( + detail::smart_deref(mapping), + [](const iteration_type_t &v) { return detail::to_string(detail::pair_adaptor::second(v)); }, + ","); + out.push_back('}'); + return out; + }; + + desc_function_ = tfunc; + + func_ = [mapping, tfunc, filter_fn](std::string &input) { + local_item_t b; + bool converted = detail::lexical_cast(input, b); + if(converted) { + if(filter_fn) { + b = filter_fn(b); + } + auto res = detail::search(mapping, b, filter_fn); + if(res.first) { + input = detail::value_string(detail::pair_adaptor::second(*res.second)); + return std::string{}; + } + } + for(const auto &v : detail::smart_deref(mapping)) { + auto output_string = detail::value_string(detail::pair_adaptor::second(v)); + if(output_string == input) { + return std::string(); + } + } + + return "Check " + input + " " + tfunc() + " FAILED"; + }; + } + + /// You can pass in as many filter functions as you like, they nest + template + CheckedTransformer(T &&mapping, filter_fn_t filter_fn_1, filter_fn_t filter_fn_2, Args &&...other) + : CheckedTransformer( + std::forward(mapping), + [filter_fn_1, filter_fn_2](std::string a) { return filter_fn_2(filter_fn_1(a)); }, + other...) {} +}; + +/// Helper function to allow ignore_case to be passed to IsMember or Transform +inline std::string ignore_case(std::string item) { return detail::to_lower(item); } + +/// Helper function to allow ignore_underscore to be passed to IsMember or Transform +inline std::string ignore_underscore(std::string item) { return detail::remove_underscore(item); } + +/// Helper function to allow checks to ignore spaces to be passed to IsMember or Transform +inline std::string ignore_space(std::string item) { + item.erase(std::remove(std::begin(item), std::end(item), ' '), std::end(item)); + item.erase(std::remove(std::begin(item), std::end(item), '\t'), std::end(item)); + return item; +} + +/// Multiply a number by a factor using given mapping. +/// Can be used to write transforms for SIZE or DURATION inputs. +/// +/// Example: +/// With mapping = `{"b"->1, "kb"->1024, "mb"->1024*1024}` +/// one can recognize inputs like "100", "12kb", "100 MB", +/// that will be automatically transformed to 100, 14448, 104857600. +/// +/// Output number type matches the type in the provided mapping. +/// Therefore, if it is required to interpret real inputs like "0.42 s", +/// the mapping should be of a type or . +class AsNumberWithUnit : public Validator { + public: + /// Adjust AsNumberWithUnit behavior. + /// CASE_SENSITIVE/CASE_INSENSITIVE controls how units are matched. + /// UNIT_OPTIONAL/UNIT_REQUIRED throws ValidationError + /// if UNIT_REQUIRED is set and unit literal is not found. + enum Options { + CASE_SENSITIVE = 0, + CASE_INSENSITIVE = 1, + UNIT_OPTIONAL = 0, + UNIT_REQUIRED = 2, + DEFAULT = CASE_INSENSITIVE | UNIT_OPTIONAL + }; + + template + explicit AsNumberWithUnit(std::map mapping, + Options opts = DEFAULT, + const std::string &unit_name = "UNIT") { + description(generate_description(unit_name, opts)); + validate_mapping(mapping, opts); + + // transform function + func_ = [mapping, opts](std::string &input) -> std::string { + Number num; + + detail::rtrim(input); + if(input.empty()) { + throw ValidationError("Input is empty"); + } + + // Find split position between number and prefix + auto unit_begin = input.end(); + while(unit_begin > input.begin() && std::isalpha(*(unit_begin - 1), std::locale())) { + --unit_begin; + } + + std::string unit{unit_begin, input.end()}; + input.resize(static_cast(std::distance(input.begin(), unit_begin))); + detail::trim(input); + + if(opts & UNIT_REQUIRED && unit.empty()) { + throw ValidationError("Missing mandatory unit"); + } + if(opts & CASE_INSENSITIVE) { + unit = detail::to_lower(unit); + } + if(unit.empty()) { + if(!detail::lexical_cast(input, num)) { + throw ValidationError(std::string("Value ") + input + " could not be converted to " + + detail::type_name()); + } + // No need to modify input if no unit passed + return {}; + } + + // find corresponding factor + auto it = mapping.find(unit); + if(it == mapping.end()) { + throw ValidationError(unit + + " unit not recognized. " + "Allowed values: " + + detail::generate_map(mapping, true)); + } + + if(!input.empty()) { + bool converted = detail::lexical_cast(input, num); + if(!converted) { + throw ValidationError(std::string("Value ") + input + " could not be converted to " + + detail::type_name()); + } + // perform safe multiplication + bool ok = detail::checked_multiply(num, it->second); + if(!ok) { + throw ValidationError(detail::to_string(num) + " multiplied by " + unit + + " factor would cause number overflow. Use smaller value."); + } + } else { + num = static_cast(it->second); + } + + input = detail::to_string(num); + + return {}; + }; + } + + private: + /// Check that mapping contains valid units. + /// Update mapping for CASE_INSENSITIVE mode. + template static void validate_mapping(std::map &mapping, Options opts) { + for(auto &kv : mapping) { + if(kv.first.empty()) { + throw ValidationError("Unit must not be empty."); + } + if(!detail::isalpha(kv.first)) { + throw ValidationError("Unit must contain only letters."); + } + } + + // make all units lowercase if CASE_INSENSITIVE + if(opts & CASE_INSENSITIVE) { + std::map lower_mapping; + for(auto &kv : mapping) { + auto s = detail::to_lower(kv.first); + if(lower_mapping.count(s)) { + throw ValidationError(std::string("Several matching lowercase unit representations are found: ") + + s); + } + lower_mapping[detail::to_lower(kv.first)] = kv.second; + } + mapping = std::move(lower_mapping); + } + } + + /// Generate description like this: NUMBER [UNIT] + template static std::string generate_description(const std::string &name, Options opts) { + std::stringstream out; + out << detail::type_name() << ' '; + if(opts & UNIT_REQUIRED) { + out << name; + } else { + out << '[' << name << ']'; + } + return out.str(); + } +}; + +/// Converts a human-readable size string (with unit literal) to uin64_t size. +/// Example: +/// "100" => 100 +/// "1 b" => 100 +/// "10Kb" => 10240 // you can configure this to be interpreted as kilobyte (*1000) or kibibyte (*1024) +/// "10 KB" => 10240 +/// "10 kb" => 10240 +/// "10 kib" => 10240 // *i, *ib are always interpreted as *bibyte (*1024) +/// "10kb" => 10240 +/// "2 MB" => 2097152 +/// "2 EiB" => 2^61 // Units up to exibyte are supported +class AsSizeValue : public AsNumberWithUnit { + public: + using result_t = std::uint64_t; + + /// If kb_is_1000 is true, + /// interpret 'kb', 'k' as 1000 and 'kib', 'ki' as 1024 + /// (same applies to higher order units as well). + /// Otherwise, interpret all literals as factors of 1024. + /// The first option is formally correct, but + /// the second interpretation is more wide-spread + /// (see https://en.wikipedia.org/wiki/Binary_prefix). + explicit AsSizeValue(bool kb_is_1000) : AsNumberWithUnit(get_mapping(kb_is_1000)) { + if(kb_is_1000) { + description("SIZE [b, kb(=1000b), kib(=1024b), ...]"); + } else { + description("SIZE [b, kb(=1024b), ...]"); + } + } + + private: + /// Get mapping + static std::map init_mapping(bool kb_is_1000) { + std::map m; + result_t k_factor = kb_is_1000 ? 1000 : 1024; + result_t ki_factor = 1024; + result_t k = 1; + result_t ki = 1; + m["b"] = 1; + for(std::string p : {"k", "m", "g", "t", "p", "e"}) { + k *= k_factor; + ki *= ki_factor; + m[p] = k; + m[p + "b"] = k; + m[p + "i"] = ki; + m[p + "ib"] = ki; + } + return m; + } + + /// Cache calculated mapping + static std::map get_mapping(bool kb_is_1000) { + if(kb_is_1000) { + static auto m = init_mapping(true); + return m; + } else { + static auto m = init_mapping(false); + return m; + } + } +}; + +namespace detail { +/// Split a string into a program name and command line arguments +/// the string is assumed to contain a file name followed by other arguments +/// the return value contains is a pair with the first argument containing the program name and the second +/// everything else. +inline std::pair split_program_name(std::string commandline) { + // try to determine the programName + std::pair vals; + trim(commandline); + auto esp = commandline.find_first_of(' ', 1); + while(detail::check_path(commandline.substr(0, esp).c_str()) != path_type::file) { + esp = commandline.find_first_of(' ', esp + 1); + if(esp == std::string::npos) { + // if we have reached the end and haven't found a valid file just assume the first argument is the + // program name + if(commandline[0] == '"' || commandline[0] == '\'' || commandline[0] == '`') { + bool embeddedQuote = false; + auto keyChar = commandline[0]; + auto end = commandline.find_first_of(keyChar, 1); + while((end != std::string::npos) && (commandline[end - 1] == '\\')) { // deal with escaped quotes + end = commandline.find_first_of(keyChar, end + 1); + embeddedQuote = true; + } + if(end != std::string::npos) { + vals.first = commandline.substr(1, end - 1); + esp = end + 1; + if(embeddedQuote) { + vals.first = find_and_replace(vals.first, std::string("\\") + keyChar, std::string(1, keyChar)); + } + } else { + esp = commandline.find_first_of(' ', 1); + } + } else { + esp = commandline.find_first_of(' ', 1); + } + + break; + } + } + if(vals.first.empty()) { + vals.first = commandline.substr(0, esp); + rtrim(vals.first); + } + + // strip the program name + vals.second = (esp != std::string::npos) ? commandline.substr(esp + 1) : std::string{}; + ltrim(vals.second); + return vals; +} + +} // namespace detail +/// @} + + + + +class Option; +class App; + +/// This enum signifies the type of help requested +/// +/// This is passed in by App; all user classes must accept this as +/// the second argument. + +enum class AppFormatMode { + Normal, ///< The normal, detailed help + All, ///< A fully expanded help + Sub, ///< Used when printed as part of expanded subcommand +}; + +/// This is the minimum requirements to run a formatter. +/// +/// A user can subclass this is if they do not care at all +/// about the structure in CLI::Formatter. +class FormatterBase { + protected: + /// @name Options + ///@{ + + /// The width of the first column + std::size_t column_width_{30}; + + /// @brief The required help printout labels (user changeable) + /// Values are Needs, Excludes, etc. + std::map labels_{}; + + ///@} + /// @name Basic + ///@{ + + public: + FormatterBase() = default; + FormatterBase(const FormatterBase &) = default; + FormatterBase(FormatterBase &&) = default; + + /// Adding a destructor in this form to work around bug in GCC 4.7 + virtual ~FormatterBase() noexcept {} // NOLINT(modernize-use-equals-default) + + /// This is the key method that puts together help + virtual std::string make_help(const App *, std::string, AppFormatMode) const = 0; + + ///@} + /// @name Setters + ///@{ + + /// Set the "REQUIRED" label + void label(std::string key, std::string val) { labels_[key] = val; } + + /// Set the column width + void column_width(std::size_t val) { column_width_ = val; } + + ///@} + /// @name Getters + ///@{ + + /// Get the current value of a name (REQUIRED, etc.) + std::string get_label(std::string key) const { + if(labels_.find(key) == labels_.end()) + return key; + else + return labels_.at(key); + } + + /// Get the current column width + std::size_t get_column_width() const { return column_width_; } + + ///@} +}; + +/// This is a specialty override for lambda functions +class FormatterLambda final : public FormatterBase { + using funct_t = std::function; + + /// The lambda to hold and run + funct_t lambda_; + + public: + /// Create a FormatterLambda with a lambda function + explicit FormatterLambda(funct_t funct) : lambda_(std::move(funct)) {} + + /// Adding a destructor (mostly to make GCC 4.7 happy) + ~FormatterLambda() noexcept override {} // NOLINT(modernize-use-equals-default) + + /// This will simply call the lambda function + std::string make_help(const App *app, std::string name, AppFormatMode mode) const override { + return lambda_(app, name, mode); + } +}; + +/// This is the default Formatter for CLI11. It pretty prints help output, and is broken into quite a few +/// overridable methods, to be highly customizable with minimal effort. +class Formatter : public FormatterBase { + public: + Formatter() = default; + Formatter(const Formatter &) = default; + Formatter(Formatter &&) = default; + + /// @name Overridables + ///@{ + + /// This prints out a group of options with title + /// + virtual std::string make_group(std::string group, bool is_positional, std::vector opts) const; + + /// This prints out just the positionals "group" + virtual std::string make_positionals(const App *app) const; + + /// This prints out all the groups of options + std::string make_groups(const App *app, AppFormatMode mode) const; + + /// This prints out all the subcommands + virtual std::string make_subcommands(const App *app, AppFormatMode mode) const; + + /// This prints out a subcommand + virtual std::string make_subcommand(const App *sub) const; + + /// This prints out a subcommand in help-all + virtual std::string make_expanded(const App *sub) const; + + /// This prints out all the groups of options + virtual std::string make_footer(const App *app) const; + + /// This displays the description line + virtual std::string make_description(const App *app) const; + + /// This displays the usage line + virtual std::string make_usage(const App *app, std::string name) const; + + /// This puts everything together + std::string make_help(const App * /*app*/, std::string, AppFormatMode) const override; + + ///@} + /// @name Options + ///@{ + + /// This prints out an option help line, either positional or optional form + virtual std::string make_option(const Option *opt, bool is_positional) const { + std::stringstream out; + detail::format_help( + out, make_option_name(opt, is_positional) + make_option_opts(opt), make_option_desc(opt), column_width_); + return out.str(); + } + + /// @brief This is the name part of an option, Default: left column + virtual std::string make_option_name(const Option *, bool) const; + + /// @brief This is the options part of the name, Default: combined into left column + virtual std::string make_option_opts(const Option *) const; + + /// @brief This is the description. Default: Right column, on new line if left column too large + virtual std::string make_option_desc(const Option *) const; + + /// @brief This is used to print the name on the USAGE line + virtual std::string make_option_usage(const Option *opt) const; + + ///@} +}; + + + + +using results_t = std::vector; +/// callback function definition +using callback_t = std::function; + +class Option; +class App; + +using Option_p = std::unique_ptr