Skip to content

Commit

Permalink
Numpy Support + Debug ops (#304)
Browse files Browse the repository at this point in the history
# Links

Addresses this issue: [(shortfin) Need debug-print support for f16
arrays and other types not supported array.array
types](#292)

# Description

This PR adds support for directly converting device_arrays to numpy
arrays. We also implement some common functions to be used in the debug
context, to achieve better visibility into the values of a device_array
at runtime.

We implement the
[ArrayProtocol](https://numpy.org/doc/stable/reference/arrays.interface.html)
when numpy is available to make casting between device_arrays and numpy
arrays easier.

We define a module named `shortfin.array.nputils` to access these
features, which is only available in public API if numpy is installed.

# Supported Functions

- debug_dump_array: Dump contents of array to debug logs
- debug_fill_array: Convert to np array, and fill with specific values
- _find_mode: Find the mode of an np.array, along a specified dimension
- debug_log_tensor_stats: Log the following stats for a tensor:
    - NaN Count
    - Shape & dtype
    - Min
    - Max
    - Mode
    - First 10 Elements
    - Last 10 Elements

---------

Co-authored-by: Xida Ren (Cedar) <[email protected]>
Co-authored-by: Stella Laurenzo <[email protected]>
  • Loading branch information
3 people authored Oct 30, 2024
1 parent 16a1bea commit 5c6ba0e
Show file tree
Hide file tree
Showing 4 changed files with 441 additions and 5 deletions.
62 changes: 57 additions & 5 deletions shortfin/python/array_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ class PyMapping {
auto it = table.find(*dtype());
if (it == table.end()) {
throw std::invalid_argument(
fmt::format("Python array.array type code not know for dtype "
fmt::format("Python array.array type code not known for dtype "
"{}: Cannot access items",
dtype()->name()));
}
Expand Down Expand Up @@ -296,7 +296,7 @@ class PyMapping {
auto it = table.find(*dtype());
if (it == table.end()) {
throw std::invalid_argument(
fmt::format("Python array.array type code not know for dtype "
fmt::format("Python array.array type code not known for dtype "
"{}: Cannot access items",
dtype()->name()));
}
Expand All @@ -316,7 +316,7 @@ class PyMapping {
auto it = table.find(*dtype());
if (it == table.end()) {
throw std::invalid_argument(
fmt::format("Python array.array type code not know for dtype "
fmt::format("Python array.array type code not known for dtype "
"{}: Cannot access items",
dtype()->name()));
}
Expand Down Expand Up @@ -495,6 +495,7 @@ void BindArray(py::module_ &m) {
py::class_<base_array>(m, "base_array")
.def_prop_ro("dtype", &base_array::dtype)
.def_prop_ro("shape", &base_array::shape);

py::class_<device_array, base_array>(m, "device_array")
.def("__init__", [](py::args, py::kwargs) {})
.def_static(
Expand Down Expand Up @@ -527,14 +528,16 @@ void BindArray(py::module_ &m) {
[](local::ScopedDevice &device, std::span<const size_t> shape,
DType dtype) {
return custom_new_keep_alive<device_array>(
py::type<device_array>(), /*keep_alive=*/device.fiber(),
py::type<device_array>(),
/*keep_alive=*/device.fiber(),
device_array::for_device(device, shape, dtype));
})
.def_static("for_host",
[](local::ScopedDevice &device, std::span<const size_t> shape,
DType dtype) {
return custom_new_keep_alive<device_array>(
py::type<device_array>(), /*keep_alive=*/device.fiber(),
py::type<device_array>(),
/*keep_alive=*/device.fiber(),
device_array::for_host(device, shape, dtype));
})
.def("for_transfer",
Expand Down Expand Up @@ -599,7 +602,56 @@ void BindArray(py::module_ &m) {
return mapping.SetItems(refs.get(), initializer);
},
DOCSTRING_ARRAY_ITEMS)
.def_prop_ro(
"__array_interface__",
[refs](device_array &self) {
py::dict interface;
interface["version"] = 3;
interface["strides"] = py::none();

auto shape = self.shape();
py::list shapeList;
for (size_t i = 0; i < shape.size(); ++i) {
shapeList.append(shape[i]);
}
py::tuple shape_tuple(py::cast(shapeList));
interface["shape"] = shape_tuple;

auto &table = refs->element_type_array_type_code_table;
auto it = table.find(self.dtype());
if (it == table.end()) {
throw std::invalid_argument(fmt::format(
"Python array.array type code not known for dtype "
"{}: Cannot access items",
self.dtype().name()));
}

auto typeString = py::str();
if (it->first == DType::float16()) {
typeString = py::str("float16");
} else {
typeString = py::str(it->second);
}
interface["typestr"] = typeString;

PyMapping *mapping;
py::object mapping_obj = CreateMappingObject(&mapping);
mapping->set_dtype(self.dtype());
self.storage().map_explicit(
mapping->mapping(), static_cast<iree_hal_memory_access_bits_t>(
IREE_HAL_MEMORY_ACCESS_READ));
auto items = mapping->GetItems(mapping_obj, refs.get());

// Obtain pointer to first element in items
Py_buffer buffer;
if (PyObject_GetBuffer(items.ptr(), &buffer, PyBUF_SIMPLE) != 0) {
throw std::runtime_error("Failed to get buffer from items");
}
void *itemsPtr = buffer.buf;
interface["data"] =
py::make_tuple(reinterpret_cast<intptr_t>(itemsPtr), false);
return interface;
})
.def("__repr__", &device_array::to_s)
.def("__str__", [](device_array &self) -> std::string {
auto contents = self.contents_to_s();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import importlib.util

from _shortfin import lib as _sfl

# All dtype aliases.
Expand Down Expand Up @@ -83,3 +85,10 @@
"fill_randn",
"RandomGenerator",
]

# Import nputils if numpy is present.
np_present = importlib.util.find_spec("numpy") is not None
if np_present:
from . import _nputils as nputils

__all__.append("nputils")
104 changes: 104 additions & 0 deletions shortfin/python/shortfin/array/_nputils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import logging

import numpy as np

from shortfin import array as sfnp

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def debug_dump_array(tensor: sfnp.device_array) -> None:
"""Dump the contents of a device array to the debug log.
Args:
tensor (sfnp.device_array): The device array to dump.
"""
np_array = np.array(tensor)
logger.debug(np_array)


def debug_fill_array(tensor: sfnp.device_array, fill_value: int | float) -> np.ndarray:
"""Fill a device array with a given value and return the resulting numpy array.
Args:
tensor (sfnp.device_array): The device array to fill.
fill_value (int | float): The value to fill the array with.
Returns:
np.ndarray: The filled numpy array.
"""
np_array = np.array(tensor)
np_array.fill(fill_value)
return np_array


def _find_mode(
arr: np.ndarray, axis=0, keepdims=False
) -> tuple[np.ndarray, np.ndarray]:
"""
Find the mode of an array along a given axis.
Args:
arr: The input array.
axis: The axis along which to find the mode.
keepdims: If True, the output shape is the same as arr except along the specified axis.
Returns:
tuple: A tuple containing the mode values and the count of the mode values.
"""

def _mode(arr):
if arr.size == 0:
return np.nan, 0

unique, counts = np.unique(arr, return_counts=True)
max_counts = counts.max()

mode = unique[counts == max_counts][0]
return mode, max_counts

result = np.apply_along_axis(_mode, axis, arr)
mode_values, mode_count = result[..., 0], result[..., 1]

if keepdims:
mode_values = np.expand_dims(mode_values, axis)
mode_count = np.expand_dims(mode_count, axis)

return mode_values, mode_count


def debug_log_tensor_stats(tensor: sfnp.device_array) -> None:
"""Log statistics about a device array to the debug log.
The following statistics are logged:
- NaN count
- Shape, dtype
- Min, max, mean, mode (excluding NaN values)
- First 10 elements
- Last 10 elements
Args:
tensor (sfnp.device_array): The device array to log statistics for.
"""

np_array = np.array(tensor)

nan_count = np.isnan(np_array).sum()

# Remove NaN values
np_array_no_nan = np_array[~np.isnan(np_array)]

logger.debug(f"NaN count: {nan_count} / {np_array.size}")
logger.debug(f"Shape: {np_array.shape}, dtype: {np_array.dtype}")

if len(np_array_no_nan) > 0:
mode = _find_mode(np_array_no_nan)[0]
logger.debug(f"Min (excluding NaN): {np_array_no_nan.min()}")
logger.debug(f"Max (excluding NaN): {np_array_no_nan.max()}")
logger.debug(f"Mean (excluding NaN): {np_array_no_nan.mean()}")
logger.debug(f"Mode (excluding NaN): {mode}")
logger.debug(f"First 10 elements: {np_array_no_nan.flatten()[:10]}")
logger.debug(f"Last 10 elements: {np_array_no_nan.flatten()[-10:]}")
else:
logger.warning(f"All values are NaN")
Loading

0 comments on commit 5c6ba0e

Please sign in to comment.