Adding advanced interface #68

Merged · 16 commits · Mar 15, 2024
5 changes: 4 additions & 1 deletion CMakeLists.txt
@@ -11,9 +11,11 @@ set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF)

# Enable CUDA if requested and available
option(JAX_FINUFFT_USE_CUDA "Enable CUDA build" OFF)

if(JAX_FINUFFT_USE_CUDA)
include(CheckLanguage)
check_language(CUDA)

if(CMAKE_CUDA_COMPILER)
message(STATUS "CUDA compiler found; compiling with GPU support")
enable_language(CUDA)
@@ -28,7 +30,7 @@ else()
set(FINUFFT_USE_CUDA OFF)
endif()

if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
# TODO(dfm): OpenMP segfaults on my system - can we enable this somehow?
set(FINUFFT_USE_OPENMP OFF)
else()
@@ -63,6 +65,7 @@ if(FINUFFT_USE_CUDA)
)
pybind11_add_module(jax_finufft_gpu
${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_gpu.cc
${CMAKE_CURRENT_LIST_DIR}/lib/cufinufft_wrapper.cc
${CMAKE_CURRENT_LIST_DIR}/lib/kernels.cc.cu)
target_include_directories(jax_finufft_gpu PUBLIC ${CUFINUFFT_INCLUDE_DIRS})
target_include_directories(jax_finufft_gpu PUBLIC ${CUFINUFFT_VENDORED_INCLUDE_DIRS})
55 changes: 12 additions & 43 deletions lib/jax_finufft_gpu.h → lib/cufinufft_wrapper.cc
@@ -1,37 +1,30 @@
#ifndef _JAX_FINUFFT_GPU_H_
#define _JAX_FINUFFT_GPU_H_
#include "cufinufft_wrapper.h"

#include <complex>

#include "cufinufft.h"

namespace jax_finufft {

template <typename T>
struct plan_type;
namespace gpu {

template <>
struct plan_type<double> {
typedef cufinufft_plan type;
};
void default_opts<float>(cufinufft_opts* opts) {
cufinufft_default_opts(opts);
}

template <>
struct plan_type<float> {
typedef cufinufftf_plan type;
};

template <typename T>
void default_opts(int type, int dim, cufinufft_opts* opts, cudaStream_t stream);
void default_opts<double>(cufinufft_opts* opts) {
cufinufft_default_opts(opts);
}

template <>
void default_opts<float>(int type, int dim, cufinufft_opts* opts, cudaStream_t stream) {
cufinufft_default_opts(opts);
void update_opts<float>(cufinufft_opts* opts, int dim, cudaStream_t stream) {
opts->gpu_stream = stream;
}

template <>
void default_opts<double>(int type, int dim, cufinufft_opts* opts, cudaStream_t stream) {
cufinufft_default_opts(opts);
void update_opts<double>(cufinufft_opts* opts, int dim, cudaStream_t stream) {
opts->gpu_stream = stream;

// double precision in 3D blows out shared memory.
@@ -42,10 +35,6 @@ void default_opts<double>(int type, int dim, cufinufft_opts* opts, cudaStream_t
}
}

template <typename T>
int makeplan(int type, int dim, const int64_t nmodes[3], int iflag, int ntr, T eps,
typename plan_type<T>::type* plan, cufinufft_opts* opts);

template <>
int makeplan<float>(int type, int dim, const int64_t nmodes[3], int iflag, int ntr, float eps,
typename plan_type<float>::type* plan, cufinufft_opts* opts) {
@@ -61,10 +50,6 @@ int makeplan<double>(int type, int dim, const int64_t nmodes[3], int iflag, int
return cufinufft_makeplan(type, dim, tmp_nmodes, iflag, ntr, eps, plan, opts);
}

template <typename T>
int setpts(typename plan_type<T>::type plan, int64_t M, T* x, T* y, T* z, int64_t N, T* s, T* t,
T* u);

template <>
int setpts<float>(typename plan_type<float>::type plan, int64_t M, float* x, float* y, float* z,
int64_t N, float* s, float* t, float* u) {
@@ -77,9 +62,6 @@ int setpts<double>(typename plan_type<double>::type plan, int64_t M, double* x,
return cufinufft_setpts(plan, M, x, y, z, N, s, t, u);
}

template <typename T>
int execute(typename plan_type<T>::type plan, std::complex<T>* c, std::complex<T>* f);

template <>
int execute<float>(typename plan_type<float>::type plan, std::complex<float>* c,
std::complex<float>* f) {
@@ -96,9 +78,6 @@ int execute<double>(typename plan_type<double>::type plan, std::complex<double>*
return cufinufft_execute(plan, _c, _f);
}

template <typename T>
void destroy(typename plan_type<T>::type plan);

template <>
void destroy<float>(typename plan_type<float>::type plan) {
cufinufftf_destroy(plan);
@@ -109,11 +88,6 @@ void destroy<double>(typename plan_type<double>::type plan) {
cufinufft_destroy(plan);
}

template <int ndim, typename T>
T* y_index(T* y, int64_t index) {
return &(y[index]);
}

template <>
double* y_index<1, double>(double* y, int64_t index) {
return NULL;
@@ -124,11 +98,6 @@ float* y_index<1, float>(float* y, int64_t index) {
return NULL;
}

template <int ndim, typename T>
T* z_index(T* z, int64_t index) {
return NULL;
}

template <>
double* z_index<3, double>(double* z, int64_t index) {
return &(z[index]);
@@ -139,6 +108,6 @@ float* z_index<3, float>(float* z, int64_t index) {
return &(z[index]);
}

} // namespace jax_finufft
} // namespace gpu

#endif
} // namespace jax_finufft
71 changes: 71 additions & 0 deletions lib/cufinufft_wrapper.h
@@ -0,0 +1,71 @@
#ifndef _CUFINUFFT_WRAPPER_H_
#define _CUFINUFFT_WRAPPER_H_

#include <complex>

#include "cufinufft.h"

namespace jax_finufft {

namespace gpu {

template <typename T>
struct plan_type;

template <>
struct plan_type<double> {
typedef cufinufft_plan type;
};

template <>
struct plan_type<float> {
typedef cufinufftf_plan type;
};

template <typename T>
void default_opts(cufinufft_opts* opts);

template <typename T>
void update_opts(cufinufft_opts* opts, int dim, cudaStream_t stream);

template <typename T>
int makeplan(int type, int dim, const int64_t nmodes[3], int iflag, int ntr, T eps,
typename plan_type<T>::type* plan, cufinufft_opts* opts);

template <typename T>
int setpts(typename plan_type<T>::type plan, int64_t M, T* x, T* y, T* z, int64_t N, T* s, T* t,
T* u);

template <typename T>
int execute(typename plan_type<T>::type plan, std::complex<T>* c, std::complex<T>* f);

template <typename T>
void destroy(typename plan_type<T>::type plan);

template <int ndim, typename T>
T* y_index(T* y, int64_t index) {
return &(y[index]);
}

template <int ndim, typename T>
T* z_index(T* z, int64_t index) {
return NULL;
}

template <>
double* y_index<1, double>(double* y, int64_t index);

template <>
float* y_index<1, float>(float* y, int64_t index);

template <>
double* z_index<3, double>(double* z, int64_t index);

template <>
float* z_index<3, float>(float* z, int64_t index);

} // namespace gpu

} // namespace jax_finufft

#endif
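
For orientation, here is a minimal sketch (not part of the PR) of how this wrapper API fits together for a single-precision 1D type-1 transform, using only the signatures declared in the header above. The function name `type1_1d_example` and the device buffers `d_x`, `d_c`, `d_F` are hypothetical; the caller is assumed to have allocated them on the GPU already.

```cpp
#include <complex>

#include "cufinufft_wrapper.h"

using namespace jax_finufft::gpu;

// Hypothetical driver for a 1D type-1 transform in single precision.
void type1_1d_example(int64_t M, float* d_x, std::complex<float>* d_c,
                      int64_t n_modes, std::complex<float>* d_F,
                      cudaStream_t stream) {
  cufinufft_opts opts;
  default_opts<float>(&opts);                    // library defaults
  update_opts<float>(&opts, /*dim=*/1, stream);  // attach the CUDA stream

  const int64_t nmodes[3] = {n_modes, 1, 1};
  plan_type<float>::type plan;
  makeplan<float>(/*type=*/1, /*dim=*/1, nmodes, /*iflag=*/1, /*ntr=*/1,
                  /*eps=*/1e-6f, &plan, &opts);
  // In 1D, y and z are unused; y_index<1, T>/z_index<1, T> return NULL too.
  setpts<float>(plan, M, d_x, /*y=*/NULL, /*z=*/NULL, 0, NULL, NULL, NULL);
  execute<float>(plan, d_c, d_F);  // c -> F for a type-1 transform
  destroy<float>(plan);
}
```

Splitting these templates into `cufinufft_wrapper.h`/`.cc` means the explicit specializations are compiled once, and `jax_finufft_gpu.cc` only needs the declarations.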
21 changes: 0 additions & 21 deletions lib/jax_finufft_common.h

This file was deleted.

80 changes: 64 additions & 16 deletions lib/jax_finufft_cpu.cc
@@ -7,33 +7,32 @@
#include "pybind11_kernel_helpers.h"

using namespace jax_finufft;
using namespace jax_finufft::cpu;
namespace py = pybind11;

namespace {

template <int ndim, typename T>
void run_nufft(int type, void *desc_in, T *x, T *y, T *z, std::complex<T> *c, std::complex<T> *F) {
const NufftDescriptor<T> *descriptor = unpack_descriptor<NufftDescriptor<T>>(
reinterpret_cast<const char *>(desc_in), sizeof(NufftDescriptor<T>));
const descriptor<T> *desc = unpack_descriptor<descriptor<T>>(
reinterpret_cast<const char *>(desc_in), sizeof(descriptor<T>));
int64_t n_k = 1;
for (int d = 0; d < ndim; ++d) n_k *= descriptor->n_k[d];

finufft_opts *opts = new finufft_opts;
default_opts<T>(opts);
for (int d = 0; d < ndim; ++d) n_k *= desc->n_k[d];
finufft_opts opts = desc->opts;

typename plan_type<T>::type plan;
makeplan<T>(type, ndim, const_cast<int64_t *>(descriptor->n_k), descriptor->iflag,
descriptor->n_transf, descriptor->eps, &plan, opts);
for (int64_t index = 0; index < descriptor->n_tot; ++index) {
int64_t i = index * descriptor->n_j;
int64_t j = i * descriptor->n_transf;
int64_t k = index * n_k * descriptor->n_transf;

setpts<T>(plan, descriptor->n_j, &(x[i]), y_index<ndim, T>(y, i), z_index<ndim, T>(z, i), 0,
NULL, NULL, NULL);
makeplan<T>(type, ndim, const_cast<int64_t *>(desc->n_k), desc->iflag, desc->n_transf, desc->eps,
&plan, &opts);
for (int64_t index = 0; index < desc->n_tot; ++index) {
int64_t i = index * desc->n_j;
int64_t j = i * desc->n_transf;
int64_t k = index * n_k * desc->n_transf;

setpts<T>(plan, desc->n_j, &(x[i]), y_index<ndim, T>(y, i), z_index<ndim, T>(z, i), 0, NULL,
NULL, NULL);
execute<T>(plan, &c[j], &F[k]);
}
destroy<T>(plan);
delete opts;
}

template <int ndim, typename T>
@@ -68,6 +67,40 @@ void nufft2(void *out, void **in) {
run_nufft<ndim, T>(2, in[0], x, y, z, c, F);
}

template <typename T>
py::bytes build_descriptor(T eps, int iflag, int64_t n_tot, int n_transf, int64_t n_j,
int64_t n_k_1, int64_t n_k_2, int64_t n_k_3, finufft_opts opts) {
return pack_descriptor(
descriptor<T>{eps, iflag, n_tot, n_transf, n_j, {n_k_1, n_k_2, n_k_3}, opts});
}

template <typename T>
finufft_opts *build_opts(bool modeord, bool chkbnds, int debug, int spread_debug, bool showwarn,
int nthreads, int fftw, int spread_sort, bool spread_kerevalmeth,
bool spread_kerpad, double upsampfac, int spread_thread, int maxbatchsize,
int spread_nthr_atomic, int spread_max_sp_size) {
finufft_opts *opts = new finufft_opts;
default_opts<T>(opts);

opts->modeord = int(modeord);
opts->chkbnds = int(chkbnds);
opts->debug = debug;
opts->spread_debug = spread_debug;
opts->showwarn = int(showwarn);
opts->nthreads = nthreads;
opts->fftw = fftw;
opts->spread_sort = spread_sort;
opts->spread_kerevalmeth = int(spread_kerevalmeth);
opts->spread_kerpad = int(spread_kerpad);
opts->upsampfac = upsampfac;
opts->spread_thread = int(spread_thread);
opts->maxbatchsize = maxbatchsize;
opts->spread_nthr_atomic = spread_nthr_atomic;
opts->spread_max_sp_size = spread_max_sp_size;

return opts;
}

pybind11::dict Registrations() {
pybind11::dict dict;

@@ -92,6 +125,21 @@ PYBIND11_MODULE(jax_finufft_cpu, m) {
m.def("registrations", &Registrations);
m.def("build_descriptorf", &build_descriptor<float>);
m.def("build_descriptor", &build_descriptor<double>);

m.attr("FFTW_ESTIMATE") = py::int_(FFTW_ESTIMATE);
m.attr("FFTW_MEASURE") = py::int_(FFTW_MEASURE);
m.attr("FFTW_PATIENT") = py::int_(FFTW_PATIENT);
m.attr("FFTW_EXHAUSTIVE") = py::int_(FFTW_EXHAUSTIVE);
m.attr("FFTW_WISDOM_ONLY") = py::int_(FFTW_WISDOM_ONLY);

py::class_<finufft_opts> opts(m, "FinufftOpts");
opts.def(py::init(&build_opts<double>), py::arg("modeord") = false, py::arg("chkbnds") = true,
py::arg("debug") = 0, py::arg("spread_debug") = 0, py::arg("showwarn") = false,
py::arg("nthreads") = 0, py::arg("fftw") = int(FFTW_ESTIMATE),
py::arg("spread_sort") = 2, py::arg("spread_kerevalmeth") = true,
py::arg("spread_kerpad") = true, py::arg("upsampfac") = 0.0,
py::arg("spread_thread") = 0, py::arg("maxbatchsize") = 0,
py::arg("spread_nthr_atomic") = -1, py::arg("spread_max_sp_size") = 0);
}

} // namespace
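
A sketch of the new opts flow on the CPU side, using the signatures from this file. Illustrative only: `build_opts` and `build_descriptor` sit in an anonymous namespace and are normally reached from Python through the pybind11 bindings above; the concrete argument values here are arbitrary examples, not defaults from the PR.

```cpp
// build_opts fills a finufft_opts with user overrides; the copy then rides
// inside the packed descriptor, so run_nufft no longer allocates its own
// finufft_opts on every call.
finufft_opts* opts = build_opts<double>(
    /*modeord=*/false, /*chkbnds=*/true, /*debug=*/0, /*spread_debug=*/0,
    /*showwarn=*/false, /*nthreads=*/4, /*fftw=*/FFTW_MEASURE,
    /*spread_sort=*/2, /*spread_kerevalmeth=*/true, /*spread_kerpad=*/true,
    /*upsampfac=*/0.0, /*spread_thread=*/0, /*maxbatchsize=*/0,
    /*spread_nthr_atomic=*/-1, /*spread_max_sp_size=*/0);

// Pack a descriptor for a 1D problem with 64 modes and 1000 source points.
py::bytes desc = build_descriptor<double>(/*eps=*/1e-9, /*iflag=*/1,
                                          /*n_tot=*/1, /*n_transf=*/1,
                                          /*n_j=*/1000, /*n_k_1=*/64,
                                          /*n_k_2=*/1, /*n_k_3=*/1, *opts);
delete opts;
```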
16 changes: 16 additions & 0 deletions lib/jax_finufft_cpu.h
@@ -1,12 +1,16 @@
#ifndef _JAX_FINUFFT_H_
#define _JAX_FINUFFT_H_

#include <fftw3.h>

#include <complex>

#include "finufft.h"

namespace jax_finufft {

namespace cpu {

template <typename T>
struct plan_type;

@@ -123,6 +127,18 @@ float* z_index<3, float>(float* z, int64_t index) {
return &(z[index]);
}

template <typename T>
struct descriptor {
T eps;
int iflag;
int64_t n_tot;
int n_transf;
int64_t n_j;
int64_t n_k[3];
finufft_opts opts;
};

} // namespace cpu
} // namespace jax_finufft

#endif
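
The `pack_descriptor`/`unpack_descriptor` helpers live in `pybind11_kernel_helpers.h`, which this diff does not show. A plausible reading, stated purely as an assumption, is that `descriptor<T>` travels to the XLA custom call as raw bytes, which is why it can embed a `finufft_opts` by value:

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>
#include <type_traits>

// Sketch under assumptions; the real helpers are in pybind11_kernel_helpers.h.
template <typename T>
std::string pack_descriptor_sketch(const T& d) {
  static_assert(std::is_trivially_copyable<T>::value,
                "descriptor must be byte-copyable");
  return std::string(reinterpret_cast<const char*>(&d), sizeof(T));
}

template <typename T>
const T* unpack_descriptor_sketch(const char* bytes, std::size_t len) {
  // Mirrors the size check implied by the call
  // unpack_descriptor<descriptor<T>>(..., sizeof(descriptor<T>)) above.
  if (len != sizeof(T)) throw std::runtime_error("descriptor size mismatch");
  return reinterpret_cast<const T*>(bytes);
}
```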