Skip to content

Commit

Permalink
Merge pull request #71 from bluescarni/pr/taylor_dc
Browse files Browse the repository at this point in the history
Extension to the Taylor decomposition
  • Loading branch information
bluescarni authored Jan 5, 2021
2 parents 0c66b28 + cf56114 commit 62d956c
Show file tree
Hide file tree
Showing 41 changed files with 1,944 additions and 590 deletions.
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ endif()
# List of source files.
set(HEYOKA_SRC_FILES
"${CMAKE_CURRENT_SOURCE_DIR}/src/llvm_state.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/number.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/binary_operator.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/number.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/variable.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/func.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/param.cpp"
Expand All @@ -178,6 +178,7 @@ set(HEYOKA_SRC_FILES
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/sin.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/sqrt.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/square.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/tan.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/detail/string_conv.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/detail/math_wrappers.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/detail/llvm_helpers.cpp"
Expand Down
6 changes: 5 additions & 1 deletion doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@ Changelog
New
~~~

- Extend the Taylor decomposition machinery to work
on more general classes of functions, and add
``tan()``
(`#71 <https://github.com/bluescarni/heyoka/pull/71>`__).
- Implement support for runtime parameters
(`#68 <https://github.com/bluescarni/heyoka/pull/68>`__).
- Initial tutorials and various documentation additions
(`#63 <https://github.com/bluescarni/heyoka/pull/63>`__).
- Add stream operator for the ``taylor_outcome`` enum
- Add a stream operator for the ``taylor_outcome`` enum
(`#63 <https://github.com/bluescarni/heyoka/pull/63>`__).

Changes
Expand Down
2 changes: 1 addition & 1 deletion doc/tut_expression_system.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Note that support for extended-precision floating-point types
In addition to the standard mathematical operators, heyoka's expression system
also supports the following elementary functions (with more to come in the near future):

* sine and cosine,
* sine, cosine and tangent,
* logarithm and exponential,
* exponentiation,
* square root.
Expand Down
32 changes: 18 additions & 14 deletions include/heyoka/binary_operator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#if defined(HEYOKA_HAVE_REAL128)
Expand Down Expand Up @@ -98,37 +99,40 @@ HEYOKA_DLL_PUBLIC void update_grad_dbl(std::unordered_map<std::string, double> &
const std::unordered_map<std::string, double> &, const std::vector<double> &,
const std::vector<std::vector<std::size_t>> &, std::size_t &, double);

HEYOKA_DLL_PUBLIC std::vector<expression>::size_type taylor_decompose_in_place(binary_operator &&,
std::vector<expression> &);
HEYOKA_DLL_PUBLIC std::vector<std::pair<expression, std::vector<std::uint32_t>>>::size_type
taylor_decompose_in_place(binary_operator &&, std::vector<std::pair<expression, std::vector<std::uint32_t>>> &);

HEYOKA_DLL_PUBLIC llvm::Value *taylor_diff_dbl(llvm_state &, const binary_operator &,
const std::vector<llvm::Value *> &, llvm::Value *, std::uint32_t,
std::uint32_t, std::uint32_t, std::uint32_t);
const std::vector<std::uint32_t> &, const std::vector<llvm::Value *> &,
llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t,
std::uint32_t);

HEYOKA_DLL_PUBLIC llvm::Value *taylor_diff_ldbl(llvm_state &, const binary_operator &,
const std::vector<llvm::Value *> &, llvm::Value *, std::uint32_t,
std::uint32_t, std::uint32_t, std::uint32_t);
const std::vector<std::uint32_t> &, const std::vector<llvm::Value *> &,
llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t,
std::uint32_t);

#if defined(HEYOKA_HAVE_REAL128)

HEYOKA_DLL_PUBLIC llvm::Value *taylor_diff_f128(llvm_state &, const binary_operator &,
const std::vector<llvm::Value *> &, llvm::Value *, std::uint32_t,
std::uint32_t, std::uint32_t, std::uint32_t);
const std::vector<std::uint32_t> &, const std::vector<llvm::Value *> &,
llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t,
std::uint32_t);

#endif

template <typename T>
inline llvm::Value *taylor_diff(llvm_state &s, const binary_operator &bo, const std::vector<llvm::Value *> &arr,
llvm::Value *par_ptr, std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
std::uint32_t batch_size)
inline llvm::Value *taylor_diff(llvm_state &s, const binary_operator &bo, const std::vector<std::uint32_t> &deps,
const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, std::uint32_t n_uvars,
std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size)
{
if constexpr (std::is_same_v<T, double>) {
return taylor_diff_dbl(s, bo, arr, par_ptr, n_uvars, order, idx, batch_size);
return taylor_diff_dbl(s, bo, deps, arr, par_ptr, n_uvars, order, idx, batch_size);
} else if constexpr (std::is_same_v<T, long double>) {
return taylor_diff_ldbl(s, bo, arr, par_ptr, n_uvars, order, idx, batch_size);
return taylor_diff_ldbl(s, bo, deps, arr, par_ptr, n_uvars, order, idx, batch_size);
#if defined(HEYOKA_HAVE_REAL128)
} else if constexpr (std::is_same_v<T, mppp::real128>) {
return taylor_diff_f128(s, bo, arr, par_ptr, n_uvars, order, idx, batch_size);
return taylor_diff_f128(s, bo, deps, arr, par_ptr, n_uvars, order, idx, batch_size);
#endif
} else {
static_assert(detail::always_false_v<T>, "Unhandled type.");
Expand Down
38 changes: 0 additions & 38 deletions include/heyoka/detail/math_wrappers.hpp

This file was deleted.

6 changes: 5 additions & 1 deletion include/heyoka/detail/taylor_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include <string>
#include <vector>

#include <boost/numeric/conversion/cast.hpp>

#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/Function.h>
Expand All @@ -36,7 +38,7 @@ namespace heyoka::detail
template <typename T, typename F, typename U>
inline llvm::Function *taylor_c_diff_func_unary_num_det(llvm_state &s, const F &fn, const U &n,
std::uint32_t batch_size, const std::string &fname,
const std::string &desc)
const std::string &desc, std::uint32_t n_deps = 0)
{
auto &module = s.module();
auto &builder = s.builder();
Expand All @@ -54,6 +56,8 @@ inline llvm::Function *taylor_c_diff_func_unary_num_det(llvm_state &s, const F &
std::vector<llvm::Type *> fargs{
llvm::Type::getInt32Ty(context), llvm::Type::getInt32Ty(context), llvm::PointerType::getUnqual(val_t),
llvm::PointerType::getUnqual(to_llvm_type<T>(context)), taylor_c_diff_numparam_argtype<T>(s, n)};
// Add the hidden deps at the end.
fargs.insert(fargs.end(), boost::numeric_cast<decltype(fargs.size())>(n_deps), llvm::Type::getInt32Ty(context));

// Try to see if we already created the function.
auto f = module.getFunction(fname);
Expand Down
54 changes: 35 additions & 19 deletions include/heyoka/expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <cstdint>
#include <functional>
#include <ostream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
Expand Down Expand Up @@ -105,9 +106,9 @@ HEYOKA_DLL_PUBLIC expression operator""_var(const char *, std::size_t);
namespace detail
{

// NOTE: this needs to go here because
// NOTE: these need to go here because
// the definition of expression must be visible
// in order for this to be well-formed.
// in order for these to be well-formed.
template <typename T>
inline expression func_inner<T>::diff(const std::string &s) const
{
Expand All @@ -118,6 +119,21 @@ inline expression func_inner<T>::diff(const std::string &s) const
}
}

template <typename T>
inline std::vector<std::pair<expression, std::vector<std::uint32_t>>>::size_type
func_inner<T>::taylor_decompose(std::vector<std::pair<expression, std::vector<std::uint32_t>>> &u_vars_defs) &&
{
if constexpr (func_has_taylor_decompose_v<T>) {
return std::move(m_value).taylor_decompose(u_vars_defs);
} else {
func_default_td_impl(static_cast<func_base &>(m_value), u_vars_defs);

u_vars_defs.emplace_back(func{std::move(m_value)}, std::vector<std::uint32_t>{});

return u_vars_defs.size() - 1u;
}
}

struct HEYOKA_DLL_PUBLIC prime_wrapper {
std::string m_str;

Expand Down Expand Up @@ -267,43 +283,43 @@ HEYOKA_DLL_PUBLIC void update_grad_dbl(std::unordered_map<std::string, double> &
const std::unordered_map<std::string, double> &, const std::vector<double> &,
const std::vector<std::vector<std::size_t>> &, std::size_t &, double = 1.);

HEYOKA_DLL_PUBLIC std::vector<expression>::size_type taylor_decompose_in_place(expression &&,
std::vector<expression> &);
HEYOKA_DLL_PUBLIC std::vector<std::pair<expression, std::vector<std::uint32_t>>>::size_type
taylor_decompose_in_place(expression &&, std::vector<std::pair<expression, std::vector<std::uint32_t>>> &);

template <typename... Args>
inline std::array<expression, sizeof...(Args)> make_vars(const Args &...strs)
{
return std::array{expression{variable{strs}}...};
}

HEYOKA_DLL_PUBLIC llvm::Value *taylor_diff_dbl(llvm_state &, const expression &, const std::vector<llvm::Value *> &,
llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t,
std::uint32_t);
HEYOKA_DLL_PUBLIC llvm::Value *taylor_diff_dbl(llvm_state &, const expression &, const std::vector<std::uint32_t> &,
const std::vector<llvm::Value *> &, llvm::Value *, std::uint32_t,
std::uint32_t, std::uint32_t, std::uint32_t);

HEYOKA_DLL_PUBLIC llvm::Value *taylor_diff_ldbl(llvm_state &, const expression &, const std::vector<llvm::Value *> &,
llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t,
std::uint32_t);
HEYOKA_DLL_PUBLIC llvm::Value *taylor_diff_ldbl(llvm_state &, const expression &, const std::vector<std::uint32_t> &,
const std::vector<llvm::Value *> &, llvm::Value *, std::uint32_t,
std::uint32_t, std::uint32_t, std::uint32_t);

#if defined(HEYOKA_HAVE_REAL128)

HEYOKA_DLL_PUBLIC llvm::Value *taylor_diff_f128(llvm_state &, const expression &, const std::vector<llvm::Value *> &,
llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t,
std::uint32_t);
HEYOKA_DLL_PUBLIC llvm::Value *taylor_diff_f128(llvm_state &, const expression &, const std::vector<std::uint32_t> &,
const std::vector<llvm::Value *> &, llvm::Value *, std::uint32_t,
std::uint32_t, std::uint32_t, std::uint32_t);

#endif

template <typename T>
inline llvm::Value *taylor_diff(llvm_state &s, const expression &ex, const std::vector<llvm::Value *> &arr,
llvm::Value *par_ptr, std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
std::uint32_t batch_size)
inline llvm::Value *taylor_diff(llvm_state &s, const expression &ex, const std::vector<std::uint32_t> &deps,
const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, std::uint32_t n_uvars,
std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size)
{
if constexpr (std::is_same_v<T, double>) {
return taylor_diff_dbl(s, ex, arr, par_ptr, n_uvars, order, idx, batch_size);
return taylor_diff_dbl(s, ex, deps, arr, par_ptr, n_uvars, order, idx, batch_size);
} else if constexpr (std::is_same_v<T, long double>) {
return taylor_diff_ldbl(s, ex, arr, par_ptr, n_uvars, order, idx, batch_size);
return taylor_diff_ldbl(s, ex, deps, arr, par_ptr, n_uvars, order, idx, batch_size);
#if defined(HEYOKA_HAVE_REAL128)
} else if constexpr (std::is_same_v<T, mppp::real128>) {
return taylor_diff_f128(s, ex, arr, par_ptr, n_uvars, order, idx, batch_size);
return taylor_diff_f128(s, ex, deps, arr, par_ptr, n_uvars, order, idx, batch_size);
#endif
} else {
static_assert(detail::always_false_v<T>, "Unhandled type.");
Expand Down
Loading

0 comments on commit 62d956c

Please sign in to comment.