Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom dtype improvements #742

Merged
merged 1 commit into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Version 2.2.0 (TBA)
`#732 <https://github.com/wjakob/nanobind/pull/732>`__)

* A refactor of :cpp:class:`nb::ndarray\<...\> <ndarray>` was an opportunity to
realize two usability improvements:
realize three usability improvements:

1. The constructor used to return new nd-arrays from C++ now considers
all template arguments:
Expand Down Expand Up @@ -96,6 +96,12 @@ Version 2.2.0 (TBA)
values. This is useful to :ref:`return temporaries (e.g. stack-allocated
memory) <ndarray-temporaries>` from functions.

3. Added a new and more general mechanism ``nanobind::detail::dtype_traits<T>``
to declare custom ndarray data types like ``float16`` or ``bfloat16``. The old
interface (``nanobind::ndarray_traits<T>``) still exists but is deprecated
and will be removed in the next major release. See the :ref:`documentation
<ndarray-nonstandard>` for details.

There are two minor but potentially breaking changes:

1. The nd-array type caster now interprets the
Expand Down
29 changes: 18 additions & 11 deletions docs/ndarray.rst
Original file line number Diff line number Diff line change
Expand Up @@ -519,22 +519,29 @@ Nonstandard arithmetic types
----------------------------

Low or extended-precision arithmetic types (e.g., ``int128``, ``float16``,
``bfloat``) are sometimes used but don't have standardized C++ equivalents. If
you wish to exchange arrays based on such types, you must register a partial
overload of ``nanobind::ndarray_traits`` to inform nanobind about it.
``bfloat16``) are sometimes used but don't have standardized C++ equivalents.
If you wish to exchange arrays based on such types, you must register a partial
overload of ``nanobind::detail::dtype_traits`` to inform nanobind about it.

You are expressively allowed to create partial overloads of this class despite
it being in the ``nanobind::detail`` namespace.

For example, the following snippet makes ``__fp16`` (half-precision type on
``aarch64``) available:
``aarch64``) available by providing

1. ``value``, a DLPack ``nanobind::dlpack::dtype`` type descriptor, and
2. ``name``, a type name for use in docstrings and error messages.

.. code-block:: cpp

namespace nanobind {
template <> struct ndarray_traits<__fp16> {
static constexpr bool is_complex = false;
static constexpr bool is_float = true;
static constexpr bool is_bool = false;
static constexpr bool is_int = false;
static constexpr bool is_signed = true;
namespace nanobind::detail {
template <> struct dtype_traits<__fp16> {
static constexpr dlpack::dtype value {
(uint8_t) dlpack::dtype_code::Float, // type code
16, // size in bits
1 // lanes (simd), usually set to 1
};
static constexpr auto name = const_name("float16");
};
}

Expand Down
77 changes: 39 additions & 38 deletions include/nanobind/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
The API below is based on the DLPack project
(https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h)
*/

#pragma once

#include <nanobind/nanobind.h>
Expand Down Expand Up @@ -106,6 +107,41 @@ template <typename T> struct ndarray_traits {

NAMESPACE_BEGIN(detail)

template <typename T, typename /* SFINAE */ = int> struct dtype_traits {
using traits = ndarray_traits<T>;

static constexpr int matches = traits::is_bool + traits::is_complex +
traits::is_float + traits::is_int;
static_assert(matches <= 1, "dtype matches multiple type categories!");

static constexpr dlpack::dtype value{
(uint8_t) ((traits::is_bool ? (int) dlpack::dtype_code::Bool : 0) +
(traits::is_complex ? (int) dlpack::dtype_code::Complex : 0) +
(traits::is_float ? (int) dlpack::dtype_code::Float : 0) +
(traits::is_int && traits::is_signed ? (int) dlpack::dtype_code::Int : 0) +
(traits::is_int && !traits::is_signed ? (int) dlpack::dtype_code::UInt : 0)),
(uint8_t) matches ? sizeof(T) * 8 : 0,
matches ? 1 : 0
};

static constexpr auto name =
const_name<traits::is_complex>("complex", "") +
const_name<traits::is_int && traits::is_signed>("int", "") +
const_name<traits::is_int && !traits::is_signed>("uint", "") +
const_name<traits::is_float>("float", "") +
const_name<traits::is_bool>(const_name("bool"), const_name<sizeof(T) * 8>());
};

template <> struct dtype_traits<void> {
static constexpr dlpack::dtype value{ 0, 0, 0 };
static constexpr auto name = descr<0>();
};

template <> struct dtype_traits<const void> {
static constexpr dlpack::dtype value{ 0, 0, 0 };
static constexpr auto name = descr<0>();
};

template <ssize_t... Is> struct shape {
static constexpr auto name =
const_name("shape=(") +
Expand All @@ -130,9 +166,7 @@ template <ssize_t... Is> struct shape {
};

template <typename T>
constexpr bool is_ndarray_scalar_v =
ndarray_traits<T>::is_float || ndarray_traits<T>::is_int ||
ndarray_traits<T>::is_bool || ndarray_traits<T>::is_complex;
constexpr bool is_ndarray_scalar_v = dtype_traits<T>::value.bits != 0;

template <typename> struct ndim_shape;
template <size_t... S> struct ndim_shape<std::index_sequence<S...>> {
Expand All @@ -149,32 +183,7 @@ template <size_t N>
using ndim = typename detail::ndim_shape<std::make_index_sequence<N>>::type;

template <typename T> constexpr dlpack::dtype dtype() {
using traits = ndarray_traits<T>;

static_assert(
detail::is_ndarray_scalar_v<T> || std::is_void_v<T>,
"nanobind::dtype<T>: T must be a floating point or integer type!"
);

dlpack::dtype result;

if constexpr (!std::is_void_v<T>) {
if constexpr (traits::is_float)
result.code = (uint8_t) dlpack::dtype_code::Float;
else if constexpr (traits::is_complex)
result.code = (uint8_t) dlpack::dtype_code::Complex;
else if constexpr (traits::is_bool)
result.code = (uint8_t) dlpack::dtype_code::Bool;
else if constexpr (traits::is_signed)
result.code = (uint8_t) dlpack::dtype_code::Int;
else
result.code = (uint8_t) dlpack::dtype_code::UInt;

result.bits = sizeof(T) * 8;
result.lanes = 1;
}

return result;
return detail::dtype_traits<T>::value;
}

NAMESPACE_BEGIN(detail)
Expand Down Expand Up @@ -509,15 +518,7 @@ inline bool ndarray_check(handle h) { return detail::ndarray_check(h.ptr()); }
NAMESPACE_BEGIN(detail)

template <typename T> struct dtype_name {
using traits = ndarray_traits<T>;

static constexpr auto name =
const_name("dtype=") +
const_name<traits::is_complex>("complex", "") +
const_name<traits::is_int && traits::is_signed>("int", "") +
const_name<traits::is_int && !traits::is_signed>("uint", "") +
const_name<traits::is_float>("float", "") +
const_name<traits::is_bool>(const_name("bool"), const_name<sizeof(T) * 8>());
static constexpr auto name = detail::const_name("dtype=") + dtype_traits<T>::name;
};

template <> struct dtype_name<void> : unused { };
Expand Down
17 changes: 9 additions & 8 deletions tests/test_ndarray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@ static float f_global[] { 1, 2, 3, 4, 5, 6, 7, 8 };
static int i_global[] { 1, 2, 3, 4, 5, 6, 7, 8 };

#if defined(__aarch64__)
namespace nanobind {
template <> struct ndarray_traits<__fp16> {
static constexpr bool is_complex = false;
static constexpr bool is_float = true;
static constexpr bool is_bool = false;
static constexpr bool is_int = false;
static constexpr bool is_signed = true;
};
namespace nanobind::detail {
template <> struct dtype_traits<__fp16> {
static constexpr dlpack::dtype value {
(uint8_t) dlpack::dtype_code::Float, // type code
16, // size in bits
1 // lanes (simd)
};
static constexpr auto name = const_name("float16");
};
}
#endif

Expand Down