Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace hard-wired DFT functions with IDFT #136

Merged
merged 46 commits into from
Dec 10, 2021
Merged
Show file tree
Hide file tree
Changes from 44 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
c7ba58f
Initial draft of a DFT interface and FFTW implementation
brettviren Nov 1, 2021
59be25c
Tell Boost to shut up with the internal deprecation warnings
brettviren Nov 2, 2021
a9787bb
Tell Boost to shut up with the internal deprecation warnings
brettviren Nov 2, 2021
08815ba
Add stack trace to exception what()
brettviren Nov 2, 2021
a7b6f97
Throw instead of returning garbage on garbage input.
brettviren Nov 2, 2021
dd4dffb
Start on higher-level dft functions
brettviren Nov 2, 2021
dbcbe9b
Allow to optionally provide data and metadata in constructor
brettviren Nov 5, 2021
2ffbb84
More work toward DFT as a service
brettviren Nov 5, 2021
c248bab
Elaborate on comments
brettviren Nov 8, 2021
cd3369d
Flesh out and test DftTools interface
brettviren Nov 8, 2021
a83cba8
Add axis, transpose, tests
brettviren Nov 15, 2021
e8e0625
Improve tests
brettviren Nov 16, 2021
28d9ec9
More testing
brettviren Nov 16, 2021
c798cd4
Make a semaphore interface and implement with what is in util
brettviren Nov 16, 2021
fc8dc58
Move common code to a 'context' mix-in, add initial torch imp of IDFT
brettviren Nov 16, 2021
22d5ea0
Work out brain bugs in understanding torch tensor storage
brettviren Nov 17, 2021
47e67dd
Typo in plugin name fixed
brettviren Nov 18, 2021
d90bdcc
Initial draft of an IDFT benchmarker
brettviren Nov 18, 2021
a46fd9a
Improve timing measurements
brettviren Nov 18, 2021
3c65ffb
Rename, too slow to run each time
brettviren Nov 18, 2021
365c21c
Make this more globally accessible so test/check programs can use it
brettviren Nov 19, 2021
359b50f
Improve benchmark
brettviren Nov 19, 2021
bbe851d
Fix regression in cli arg parsing
brettviren Nov 22, 2021
0a49ef6
Remove obsolete 'zipper' based sim
brettviren Nov 22, 2021
2bedce6
Pass IDFT::pointer by const ref
brettviren Nov 22, 2021
ee42c72
Port gen to IDFT
brettviren Nov 22, 2021
c499a53
First draft removal of hard-wired DFT in favor of IDFT.
brettviren Nov 23, 2021
927d8ef
Fix lack of storing device name.
brettviren Nov 29, 2021
fedd276
Fix the ignored base class config methods by removing/moving the base…
brettviren Nov 29, 2021
28af349
Measure first exec+plan separate from subsequent execs
brettviren Nov 29, 2021
3411d42
Add single precision complex float
brettviren Nov 30, 2021
f31dcaa
A start on a correctness test for IDFT
brettviren Nov 30, 2021
b774679
Add r2c and c2r helpers
brettviren Dec 1, 2021
e89f0c9
Replace cli11 with boost po and get checker fleshed out
brettviren Dec 1, 2021
df869f3
Fix row/col-wise monkeys
brettviren Dec 1, 2021
8160700
More eigen row/col-wise testing
brettviren Dec 1, 2021
db16d54
Fix one wrong conversion to IDFT, improve logging along the way to fi…
brettviren Dec 2, 2021
07d8a3d
Remove noisy
brettviren Dec 2, 2021
a9e7eb8
Fix typo in dispatch
brettviren Dec 2, 2021
5cf456b
Allow passing dft config down
brettviren Dec 2, 2021
f4f02c4
More passing down of IDFT
brettviren Dec 2, 2021
f581274
Add DFT throughout config.
brettviren Dec 3, 2021
1f282d5
Remove previously commented-out old-style dft()/idft() calls
brettviren Dec 3, 2021
2c13e3c
Add construction of IDFT for tests
brettviren Dec 3, 2021
d8ddb6e
Remove include of removed header
brettviren Dec 8, 2021
847fb83
Few fixes found by Haiwang in review
brettviren Dec 9, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions aux/inc/WireCellAux/DftTools.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/**
This provides std::vector and Eigen::Array typed interface to an
IDFT.
*/

#ifndef WIRECELL_AUX_DFTTOOLS
#define WIRECELL_AUX_DFTTOOLS

#include "WireCellIface/IDFT.h"
#include <vector>
#include <Eigen/Core>

namespace WireCell::Aux {

using complex_t = IDFT::complex_t;

// std::vector based functions

using real_vector_t = std::vector<float>;
using complex_vector_t = std::vector<complex_t>;

// 1D with vectors

// Perform forward c2c transform on vector.
inline complex_vector_t fwd(const IDFT::pointer& dft, const complex_vector_t& seq)
{
complex_vector_t ret(seq.size());
dft->fwd1d(seq.data(), ret.data(), ret.size());
return ret;
}

// Perform forward r2c transform on vector.
inline complex_vector_t fwd_r2c(const IDFT::pointer& dft, const real_vector_t& vec)
{
complex_vector_t cvec(vec.size());
std::transform(vec.begin(), vec.end(), cvec.begin(),
[](float re) { return Aux::complex_t(re,0.0); } );
return fwd(dft, cvec);
}

// Perform inverse c2c transform on vector.
inline complex_vector_t inv(const IDFT::pointer& dft, const complex_vector_t& spec)
{
complex_vector_t ret(spec.size());
dft->inv1d(spec.data(), ret.data(), ret.size());
return ret;
}

// Perform inverse c2r transform on vector.
inline real_vector_t inv_c2r(const IDFT::pointer& dft, const complex_vector_t& spec)
{
auto cvec = inv(dft, spec);
real_vector_t rvec(cvec.size());
std::transform(cvec.begin(), cvec.end(), rvec.begin(),
[](const Aux::complex_t& c) { return std::real(c); });
return rvec;
}

// 1D high-level interface

/// Convovle in1 and in2. Returned vecgtor has size sum of sizes
/// of in1 and in2 less one element in order to assure no periodic
/// aliasing. Caller need not (should not) pad either input.
/// Caller is free to truncate result as required.
real_vector_t convolve(const IDFT::pointer& dft,
const real_vector_t& in1,
const real_vector_t& in2);


/// Replace response res1 in meas with response res2.
///
/// This will compute the FFT of all three, in frequency space will form:
///
/// meas * resp2 / resp1
///
/// apply the inverse FFT and return its real part.
///
/// The output vector is long enough to assure no periodic
/// aliasing. In general, caller should NOT pre-pad any input.
/// Any subsequent truncation of result is up to caller.
real_vector_t replace(const IDFT::pointer& dft,
const real_vector_t& meas,
const real_vector_t& res1,
const real_vector_t& res2);


// Eigen array based functions

/// 2D array types. Note, use Array::cast<complex_t>() if you
/// need to convert rom real or arr.real() to convert to real.
using real_array_t = Eigen::ArrayXXf;
using complex_array_t = Eigen::ArrayXXcf;

// 2D with Eigen arrays. Use eg arr.cast<complex_>() to provde
// from real or arr.real()() to convert result to real.

// Transform both dimesions.
complex_array_t fwd(const IDFT::pointer& dft, const complex_array_t& arr);
complex_array_t inv(const IDFT::pointer& dft, const complex_array_t& arr);

// As above but internally convert input or output. These are
// just syntactic sugar hiding a .cast<complex_t>() or a .real()
// call.
complex_array_t fwd_r2c(const IDFT::pointer& dft, const real_array_t& arr);
real_array_t inv_c2r(const IDFT::pointer& dft, const complex_array_t& arr);

// Transform a 2D array along one axis.
//
// The axis identifies the logical array "dimension" over which
// the transform is applied. For example, axis=1 means the
// transforms are applied along columns (ie, on a per-row basis).
// Note: this is the same convention as held by numpy.fft.
//
// The axis is interpreted in the "logical" sense Eigen arrays
// indexed as array(irow, icol). Ie, the dimension traversing
// rows is axis 0 and the dimension traversing columns is axis 1.
// Note: internal storage order of an Eigen array may differ from
// the logical order and indeed that of the array template type
// order. Neither is pertinent in setting the axis.
complex_array_t fwd(const IDFT::pointer& dft, const complex_array_t& arr, int axis);
complex_array_t inv(const IDFT::pointer& dft, const complex_array_t& arr, int axis);


// Fixme: possible additions
// - superposition of 2 reals for 2x speedup
// - r2c / c2r for 1b

}

#endif
57 changes: 57 additions & 0 deletions aux/inc/WireCellAux/FftwDFT.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#ifndef WIRECELLAUX_FFTWDFT
#define WIRECELLAUX_FFTWDFT

#include "WireCellIface/IDFT.h"

namespace WireCell::Aux {

/**
The FftwDFT component provides IDFT based on FFTW3.

All instances share a common thread-safe plan cache. There is
no benefit to using more than one instance in a process.

See IDFT.h for important comments.
*/
class FftwDFT : public IDFT {
public:

FftwDFT();
virtual ~FftwDFT();

// 1d

virtual
void fwd1d(const complex_t* in, complex_t* out,
int size) const;

virtual
void inv1d(const complex_t* in, complex_t* out,
int size) const;

virtual
void fwd1b(const complex_t* in, complex_t* out,
int nrows, int ncols, int axis) const;

virtual
void inv1b(const complex_t* in, complex_t* out,
int nrows, int ncols, int axis) const;

virtual
void fwd2d(const complex_t* in, complex_t* out,
int nrows, int ncols) const;
virtual
void inv2d(const complex_t* in, complex_t* out,
int nrows, int ncols) const;

virtual
void transpose(const scalar_t* in, scalar_t* out,
int nrows, int ncols) const;
virtual
void transpose(const complex_t* in, complex_t* out,
int nrows, int ncols) const;

};
}

#endif
34 changes: 34 additions & 0 deletions aux/inc/WireCellAux/Semaphore.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/** Implement a semaphore component interace. */

#ifndef WIRECELLAUX_SEMAPHORE
#define WIRECELLAUX_SEMAPHORE

#include "WireCellIface/IConfigurable.h"
#include "WireCellIface/ISemaphore.h"
#include "WireCellUtil/Semaphore.h"


namespace WireCell::Aux {
class Semaphore : public ISemaphore,
public IConfigurable
{
public:
Semaphore();
virtual ~Semaphore();

// IConfigurable interface
virtual void configure(const WireCell::Configuration& config);
virtual WireCell::Configuration default_configuration() const;

// ISemaphore
virtual void acquire() const;
virtual void release() const;

private:

mutable FastSemaphore m_sem;

};
} // namespace WireCell::Pytorch

#endif // WIRECELLPYTORCH_TORCHSERVICE
22 changes: 18 additions & 4 deletions aux/inc/WireCellAux/SimpleTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
#define WIRECELL_AUX_SIMPLETENSOR

#include "WireCellIface/ITensor.h"

#include <boost/multi_array.hpp>
#include <cstring>

namespace WireCell {

Expand All @@ -13,14 +15,26 @@ namespace WireCell {
public:
typedef ElementType element_t;

SimpleTensor(const shape_t& shape)
// Create simple tensor, allocating space for data. If
// data given it must have at least as many elements as
// implied by shape and that span will be copied into
// allocated memory.
SimpleTensor(const shape_t& shape,
const element_t* data=nullptr,
const Configuration& md = Configuration())
{
size_t nbytes = element_size();
for (const auto& s : shape) {
m_shape = shape;
for (const auto& s : m_shape) {
nbytes *= s;
}
m_store.resize(nbytes);
m_shape = shape;
if (data) {
const std::byte* bytes = reinterpret_cast<const std::byte*>(data);
m_store.assign(bytes, bytes+nbytes);
}
else {
m_store.resize(nbytes);
}
}
virtual ~SimpleTensor() {}

Expand Down
77 changes: 77 additions & 0 deletions aux/inc/WireCellAux/TensorTools.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#ifndef WIRECELL_AUX_TENSORTOOLS
#define WIRECELL_AUX_TENSORTOOLS

#include "WireCellIface/ITensor.h"
#include "WireCellIface/IDFT.h"
#include "WireCellUtil/Exceptions.h"

#include <Eigen/Core>
#include <complex>

namespace WireCell::Aux {

bool is_row_major(const ITensor::pointer& ten) {
if (ten->order().empty() or ten->order()[0] == 1) {
return true;
}
return false;
}

template<typename scalar_t>
bool is_type(const ITensor::pointer& ten) {
return (ten->element_type() == typeid(scalar_t));
}


// Extract the underlying data array from the tensor as a vector.
// Caution: this ignores storage order hints and 1D or 2D will be
// flattened assuming C-ordering, aka row-major (if 2D). It
// throws ValueError on type mismatch.
template<typename element_type>
std::vector<element_type> asvec(const ITensor::pointer& ten)
{
if (ten->element_type() != typeid(element_type)) {
THROW(ValueError() << errmsg{"element type mismatch"});
}
const element_type* data = (const element_type*)ten->data();
const size_t nelems = ten->size()/sizeof(element_type);
return std::vector<element_type>(data, data+nelems);
}

// Extract the tensor data as an Eigen array.
template<typename element_type>
Eigen::Array<element_type, Eigen::Dynamic, Eigen::Dynamic> // this default is column-wise
asarray(const ITensor::pointer& tens)
{
if (tens->element_type() != typeid(element_type)) {
THROW(ValueError() << errmsg{"element type mismatch"});
}
using ROWM = Eigen::Array<element_type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using COLM = Eigen::Array<element_type, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>;

auto shape = tens->shape();
int nrows, ncols;
if (shape.size() == 1) {
nrows = 1;
ncols = shape[0];
}
else {
nrows = shape[0];
ncols = shape[1];
}

// Eigen::Map is a non-const view of data but a copy happens
// on return. We need to temporarily break const correctness.
const element_type* cdata = reinterpret_cast<const element_type*>(tens->data());
element_type* mdata = const_cast<element_type*>(cdata);

if (is_row_major(tens)) {
return Eigen::Map<ROWM>(mdata, nrows, ncols);
}
// column-major
return Eigen::Map<COLM>(mdata, nrows, ncols);
}

}

#endif
Loading