Skip to content

Commit

Permalink
Restore PJRT to minimally functional after break with XLA. (iree-org#…
Browse files Browse the repository at this point in the history
…15105)

When imported, we removed all code that relied on static/bazel linkage
into XLA internals. This patch adds back the pure C path for device
memory layout, which is the minimum needed to run simple programs.

Next, we will need to add code generation for layout changes, and then
it is just minor ergonomic work to get back to where it was. As-is,
though, we can run things with this patch again and can start wiring it
for CI, etc.
  • Loading branch information
stellaraccident authored Oct 5, 2023
1 parent 39b3b24 commit d76a104
Show file tree
Hide file tree
Showing 7 changed files with 213 additions and 191 deletions.
2 changes: 1 addition & 1 deletion integrations/pjrt/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ python -c "import jax; a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); print

## Install the plugin of your choice (in this example 'cpu')

pip install -e -v --no-deps python_packages/iree_cpu_plugin
pip install -v --no-deps -e python_packages/iree_cpu_plugin

## Verify basic functionality

Expand Down
1 change: 0 additions & 1 deletion integrations/pjrt/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,3 @@
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
iree-compiler==20230922.653
jaxlib==0.4.17.dev20230922
-e ../jax
2 changes: 2 additions & 0 deletions integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ iree_cc_library(
"api_impl.h"
"dylib_entry_point.cc.inc"
"iree_helpers.h"
"layout_utils.h"
"platform.h"
"tensor_utils.h"
SRCS
"api_impl.cc"
"layout_utils.cc"
"platform.cc"
"tensor_utils.cc"
DEPS
Expand Down
278 changes: 105 additions & 173 deletions integrations/pjrt/src/iree_pjrt/common/api_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,87 +25,6 @@ const std::string_view kMlirFormat = "mlir";
namespace PJRTApiConverter {
namespace {

// Enum converter functions
// TODO: Remove once not using xla::Shape.
// iree_status_t MapElementTypeToXlaElementType(
// iree_hal_element_type_t element_type, xla::PrimitiveType* xla_primitive)
// {
// // TODO: Cascade on bit-field sub-types to avoid large linear scan.
// switch (element_type) {
// // TODO: How do I interpret signless?
// case IREE_HAL_ELEMENT_TYPE_BOOL_8:
// *xla_primitive = xla::PrimitiveType::PRED;
// return iree_ok_status();
// case IREE_HAL_ELEMENT_TYPE_INT_4:
// *xla_primitive = xla::PrimitiveType::S4;
// return iree_ok_status();
// case IREE_HAL_ELEMENT_TYPE_INT_8:
// *xla_primitive = xla::PrimitiveType::S8;
// return iree_ok_status();
// case IREE_HAL_ELEMENT_TYPE_INT_16:
// *xla_primitive = xla::PrimitiveType::S16;
// return iree_ok_status();
// case IREE_HAL_ELEMENT_TYPE_INT_32:
// *xla_primitive = xla::PrimitiveType::S32;
// return iree_ok_status();
// case IREE_HAL_ELEMENT_TYPE_INT_64:
// *xla_primitive = xla::PrimitiveType::S64;
// return iree_ok_status();
// case IREE_HAL_ELEMENT_TYPE_SINT_4:
// *xla_primitive = xla::PrimitiveType::S4;
// return iree_ok_status();
// case IREE_HAL_ELEMENT_TYPE_SINT_8:
// *xla_primitive = xla::PrimitiveType::S8;
// return iree_ok_status();
// case IREE_HAL_ELEMENT_TYPE_SINT_16:
// *xla_primitive = xla::PrimitiveType::S16;
// return iree_ok_status();
// case IREE_HAL_ELEMENT_TYPE_SINT_32:
// *xla_primitive = xla::PrimitiveType::S32;
// return iree_ok_status();
// case IREE_HAL_ELEMENT_TYPE_SINT_64:
// *xla_primitive = xla::PrimitiveType::S64;
// return iree_ok_status();
// case IREE_HAL_ELEMENT_TYPE_UINT_4:
// *xla_primitive = xla::PrimitiveType::U4;
// return iree_ok_status();
// case IREE_HAL_ELEMENT_TYPE_UINT_8:
// *xla_primitive = xla::PrimitiveType::U8;
// return iree_ok_status();
// case IREE_HAL_ELEMENT_TYPE_UINT_16:
// *xla_primitive = xla::PrimitiveType::U16;
// return iree_ok_status();
// case IREE_HAL_ELEMENT_TYPE_UINT_32:
// *xla_primitive = xla::PrimitiveType::U32;
// return iree_ok_status();
// case IREE_HAL_ELEMENT_TYPE_UINT_64:
// *xla_primitive = xla::PrimitiveType::U64;
// return iree_ok_status();
// case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
// *xla_primitive = xla::PrimitiveType::F16;
// return iree_ok_status();
// case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
// *xla_primitive = xla::PrimitiveType::F32;
// return iree_ok_status();
// case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
// *xla_primitive = xla::PrimitiveType::F64;
// return iree_ok_status();
// case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
// *xla_primitive = xla::PrimitiveType::BF16;
// return iree_ok_status();
// case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64:
// *xla_primitive = xla::PrimitiveType::C64;
// return iree_ok_status();
// case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128:
// *xla_primitive = xla::PrimitiveType::C128;
// return iree_ok_status();
// default:
// return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
// "conversion from unknown element type 0x%x",
// (int)element_type);
// }
// }

iree_status_t MapBufferTypeToElementType(
PJRT_Buffer_Type buffer_type, iree_hal_element_type_t* element_type) {
switch (buffer_type) {
Expand Down Expand Up @@ -281,65 +200,35 @@ const std::string& ErrorInstance::message() const {

BufferInstance::~BufferInstance() = default;

// TODO: Excise.
// iree_status_t BufferInstance::GetXlaShape(xla::Shape** out_shape) {
// if (cached_shape_) {
// *out_shape = &(*cached_shape_);
// return iree_ok_status();
// }

// iree_hal_element_type_t hal_element_type =
// iree_hal_buffer_view_element_type(buffer_view());
// xla::PrimitiveType xla_element_type;
// IREE_RETURN_IF_ERROR(PJRTApiConverter::MapElementTypeToXlaElementType(
// hal_element_type, &xla_element_type));

// size_t rank = iree_hal_buffer_view_shape_rank(buffer_view());
// const iree_hal_dim_t* dims =
// iree_hal_buffer_view_shape_dims(buffer_view()); std::array<int64_t, 9>
// xla_dims; if (rank > xla_dims.size()) {
// return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
// "rank > 9 not supported");
// }
// for (size_t i = 0; i < rank; ++i) {
// xla_dims[i] = dims[i];
// }

// cached_shape_ = xla::ShapeUtil::MakeShape(
// xla_element_type,
// absl::MakeSpan(xla_dims.begin(), xla_dims.begin() + rank));
// *out_shape = &(*cached_shape_);
// return iree_ok_status();
// }

// TODO: Excise and convert directly to a C memory layout.
// iree_status_t BufferInstance::GetLayoutData(
// ::pjrt::BufferMemoryLayoutData** out_layout_data) {
// if (!cached_layout_data_) {
// xla::Shape* shape;
// IREE_RETURN_IF_ERROR(GetXlaShape(&shape));
// if (!shape->has_layout()) {
// return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
// "Buffer shape doesn't have a layout");
// }
// xla::StatusOr<::pjrt::BufferMemoryLayoutData> status_or =
// ::pjrt::ConvertToBufferMemoryLayoutData(shape->layout());
// if (!status_or.ok()) {
// return iree_make_status(IREE_STATUS_UNKNOWN,
// "Couldn't convert layout: %s",
// std::string(status_or.status().message()).data());
// }
// cached_layout_data_.emplace(std::move(*status_or));
// }
// *out_layout_data = &(*cached_layout_data_);
// return iree_ok_status();
// }

// Wraps |buffer_view| on behalf of |device|. Construction creates the ready
// and done fences through the device (each creation is checked with
// IREE_CHECK_OK) and copies the view's shape into dims_.
BufferInstance::BufferInstance(
    DeviceInstance& device, iree::vm::ref<iree_hal_buffer_view_t> buffer_view)
    : device_(device), buffer_view_(std::move(buffer_view)) {
  IREE_CHECK_OK(device.CreateFence(&ready_fence_));
  IREE_CHECK_OK(device.CreateFence(&done_fence_));

  // Snapshot the shape into dims_ so dimension queries can be served from the
  // cached copy instead of going back to the buffer view.
  auto* view = buffer_view_.get();
  const size_t shape_rank = iree_hal_buffer_view_shape_rank(view);
  const iree_hal_dim_t* shape_dims = iree_hal_buffer_view_shape_dims(view);
  dims_.assign(shape_dims, shape_dims + shape_rank);
}

void BufferInstance::ComputeLayout() {
iree_hal_encoding_type_t encoding =
iree_hal_buffer_view_encoding_type(buffer_view_.get());
iree_hal_element_type_t element_type =
iree_hal_buffer_view_element_type(buffer_view_.get());

layout_.Reset();
if (encoding == IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR &&
iree_hal_element_is_byte_aligned(element_type)) {
// It is not documented, but PJRT only supports device buffers with a tiled
// layout.
layout_.InitializeDenseRowMajorTiled(dims_.size());
}
}

void BufferInstance::BindApi(PJRT_Api* api) {
Expand All @@ -353,40 +242,30 @@ void BufferInstance::BindApi(PJRT_Api* api) {
api->PJRT_Buffer_ElementType =
+[](PJRT_Buffer_ElementType_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Buffer_ElementType");
// TODO: Excise.
// BufferInstance* buffer = BufferInstance::Unwrap(args->buffer);
// auto impl = [&]() -> iree_status_t {
// // xla::Shape* shape;
// // TODO: don't use XLA shape at all
// // https://github.com/openxla/openxla-pjrt-plugin/issues/265
// IREE_RETURN_IF_ERROR(buffer->GetXlaShape(&shape));
// args->type = static_cast<PJRT_Buffer_Type>(shape->element_type());
// };
// return MakeError(impl());
return MakeError(
iree_make_status(IREE_STATUS_UNIMPLEMENTED, "PJRT_Buffer_ElementType"));
BufferInstance* buffer = BufferInstance::Unwrap(args->buffer);
auto element_type = buffer->element_type();
if (!element_type) {
return MakeError(iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"Unsupported PJRT buffer type"));
}
args->type = *element_type;
return nullptr;
};
api->PJRT_Buffer_Dimensions =
+[](PJRT_Buffer_Dimensions_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Buffer_Dimensions");
// auto impl = [&]() -> iree_status_t {
// BufferInstance* buffer = BufferInstance::Unwrap(args->buffer);
// xla::Shape* shape;
// // TODO: don't use XLA shape at all
// // https://github.com/openxla/openxla-pjrt-plugin/issues/265
// IREE_RETURN_IF_ERROR(buffer->GetXlaShape(&shape));
// args->dims = shape->dimensions().data();
// args->num_dims = shape->dimensions().size();
// return nullptr;
// };
return MakeError(
iree_make_status(IREE_STATUS_UNIMPLEMENTED, "PJRT_Buffer_Dimensions"));
BufferInstance* buffer = BufferInstance::Unwrap(args->buffer);
args->dims = buffer->dims();
args->num_dims = buffer->num_dims();
return nullptr;
};
api->PJRT_Buffer_UnpaddedDimensions =
+[](PJRT_Buffer_UnpaddedDimensions_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Buffer_UnpaddedDimensions");
return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"PJRT_Buffer_UnpaddedDimensions"));
BufferInstance* buffer = BufferInstance::Unwrap(args->buffer);
args->unpadded_dims = buffer->dims();
args->num_dims = buffer->num_dims();
return nullptr;
};
api->PJRT_Buffer_DynamicDimensionIndices =
+[](PJRT_Buffer_DynamicDimensionIndices_Args* args) -> PJRT_Error* {
Expand All @@ -397,18 +276,15 @@ void BufferInstance::BindApi(PJRT_Api* api) {
api->PJRT_Buffer_GetMemoryLayout =
+[](PJRT_Buffer_GetMemoryLayout_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Buffer_GetMemoryLayout");
// auto impl = [&]() -> iree_status_t {
// // TODO: Populate the C layout data directly from the buffer
// // instance.
// BufferInstance* buffer = BufferInstance::Unwrap(args->buffer);
// ::pjrt::BufferMemoryLayoutData* layout_data;
// IREE_RETURN_IF_ERROR(buffer->GetLayoutData(&layout_data));
// args->layout = layout_data->c_layout;
// return nullptr;
// };
// return MakeError(impl());
return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"PJRT_Buffer_GetMemoryLayout"));
BufferInstance* buffer = BufferInstance::Unwrap(args->buffer);
const PJRT_Buffer_MemoryLayout* layout = buffer->layout();
if (!layout) {
return MakeError(
iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"Unsupported PJRT layout for buffer view"));
}
args->layout = *layout;
return nullptr;
};
api->PJRT_Buffer_ToHostBuffer =
+[](PJRT_Buffer_ToHostBuffer_Args* args) -> PJRT_Error* {
Expand Down Expand Up @@ -659,6 +535,62 @@ iree_status_t BufferInstance::AdvanceDoneFence(iree_hal_semaphore_t* semaphore,
return IreeApi::hal_fence_insert(done_fence_.get(), semaphore, timepoint);
}

std::optional<PJRT_Buffer_Type> BufferInstance::element_type() {
iree_hal_element_type_t hal_element_type =
iree_hal_buffer_view_element_type(buffer_view());

// TODO: Cascade on bit-field sub-types to avoid large linear scan.
switch (hal_element_type) {
// TODO: How do I interpret signless?
case IREE_HAL_ELEMENT_TYPE_BOOL_8:
return PJRT_Buffer_Type_PRED;
case IREE_HAL_ELEMENT_TYPE_INT_4:
return PJRT_Buffer_Type_S4;
case IREE_HAL_ELEMENT_TYPE_INT_8:
return PJRT_Buffer_Type_S8;
case IREE_HAL_ELEMENT_TYPE_INT_16:
return PJRT_Buffer_Type_S16;
case IREE_HAL_ELEMENT_TYPE_INT_32:
return PJRT_Buffer_Type_S32;
case IREE_HAL_ELEMENT_TYPE_INT_64:
return PJRT_Buffer_Type_S64;
case IREE_HAL_ELEMENT_TYPE_SINT_4:
return PJRT_Buffer_Type_S4;
case IREE_HAL_ELEMENT_TYPE_SINT_8:
return PJRT_Buffer_Type_S8;
case IREE_HAL_ELEMENT_TYPE_SINT_16:
return PJRT_Buffer_Type_S16;
case IREE_HAL_ELEMENT_TYPE_SINT_32:
return PJRT_Buffer_Type_S32;
case IREE_HAL_ELEMENT_TYPE_SINT_64:
return PJRT_Buffer_Type_S64;
case IREE_HAL_ELEMENT_TYPE_UINT_4:
return PJRT_Buffer_Type_U4;
case IREE_HAL_ELEMENT_TYPE_UINT_8:
return PJRT_Buffer_Type_U8;
case IREE_HAL_ELEMENT_TYPE_UINT_16:
return PJRT_Buffer_Type_U16;
case IREE_HAL_ELEMENT_TYPE_UINT_32:
return PJRT_Buffer_Type_U32;
case IREE_HAL_ELEMENT_TYPE_UINT_64:
return PJRT_Buffer_Type_U64;
case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
return PJRT_Buffer_Type_F16;
case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
return PJRT_Buffer_Type_F32;
case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
return PJRT_Buffer_Type_F64;
case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
return PJRT_Buffer_Type_BF16;
case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64:
return PJRT_Buffer_Type_C64;
case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128:
return PJRT_Buffer_Type_C128;
default:
return {};
}
}

//===----------------------------------------------------------------------===//
// DeviceDescription
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit d76a104

Please sign in to comment.