From d76a104596d0d646fd2b2546f6cdfa811ceb6001 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 4 Oct 2023 17:26:30 -0700 Subject: [PATCH] Restore PJRT to minimally functional after break with XLA. (#15105) When imported, we removed all code that relied on static/bazel linkage into XLA internals. This patch adds back the pure C path for device memory layout, which is the minimum needed to run simple programs. Next, we will need to add code generation for layout changes, and then it is just minor ergonomic work to get back to where it was. As-is, though, we can run things with this patch again and can start wiring it for CI, etc. --- integrations/pjrt/README.md | 2 +- integrations/pjrt/requirements.txt | 1 - .../pjrt/src/iree_pjrt/common/CMakeLists.txt | 2 + .../pjrt/src/iree_pjrt/common/api_impl.cc | 278 +++++++----------- .../pjrt/src/iree_pjrt/common/api_impl.h | 38 ++- .../pjrt/src/iree_pjrt/common/layout_utils.cc | 47 +++ .../pjrt/src/iree_pjrt/common/layout_utils.h | 36 +++ 7 files changed, 213 insertions(+), 191 deletions(-) create mode 100644 integrations/pjrt/src/iree_pjrt/common/layout_utils.cc create mode 100644 integrations/pjrt/src/iree_pjrt/common/layout_utils.h diff --git a/integrations/pjrt/README.md b/integrations/pjrt/README.md index 78008acaf827..04f8aaaa4c2a 100644 --- a/integrations/pjrt/README.md +++ b/integrations/pjrt/README.md @@ -29,7 +29,7 @@ python -c "import jax; a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); print ## Install the plugin of your choice (in this example 'cpu') -pip install -e -v --no-deps python_packages/iree_cpu_plugin +pip install -v --no-deps -e python_packages/iree_cpu_plugin ## Verify basic functionality diff --git a/integrations/pjrt/requirements.txt b/integrations/pjrt/requirements.txt index 92804f923b2a..19853a92bbf3 100644 --- a/integrations/pjrt/requirements.txt +++ b/integrations/pjrt/requirements.txt @@ -2,4 +2,3 @@ -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html iree-compiler==20230922.653 jaxlib==0.4.17.dev20230922 --e ../jax diff --git a/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt b/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt index c10770383c36..46828f14ebba 100644 --- a/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt +++ b/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt @@ -11,10 +11,12 @@ iree_cc_library( "api_impl.h" "dylib_entry_point.cc.inc" "iree_helpers.h" + "layout_utils.h" "platform.h" "tensor_utils.h" SRCS "api_impl.cc" + "layout_utils.cc" "platform.cc" "tensor_utils.cc" DEPS diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc index 10228cd5694c..2c65500236b8 100644 --- a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc +++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc @@ -25,87 +25,6 @@ const std::string_view kMlirFormat = "mlir"; namespace PJRTApiConverter { namespace { -// Enum converter functions -// TODO: Remove once not using xla::Shape. -// iree_status_t MapElementTypeToXlaElementType( -// iree_hal_element_type_t element_type, xla::PrimitiveType* xla_primitive) -// { -// // TODO: Cascade on bit-field sub-types to avoid large linear scan. -// switch (element_type) { -// // TODO: How do I interpret signless? -// case IREE_HAL_ELEMENT_TYPE_BOOL_8: -// *xla_primitive = xla::PrimitiveType::PRED; -// return iree_ok_status(); -// case IREE_HAL_ELEMENT_TYPE_INT_4: -// *xla_primitive = xla::PrimitiveType::S4; -// return iree_ok_status(); -// case IREE_HAL_ELEMENT_TYPE_INT_8: -// *xla_primitive = xla::PrimitiveType::S8; -// return iree_ok_status(); -// case IREE_HAL_ELEMENT_TYPE_INT_16: -// *xla_primitive = xla::PrimitiveType::S16; -// return iree_ok_status(); -// case IREE_HAL_ELEMENT_TYPE_INT_32: -// *xla_primitive = xla::PrimitiveType::S32; -// return iree_ok_status(); -// case IREE_HAL_ELEMENT_TYPE_INT_64: -// *xla_primitive = xla::PrimitiveType::S64; -// return iree_ok_status(); -// case IREE_HAL_ELEMENT_TYPE_SINT_4: -// *xla_primitive = xla::PrimitiveType::S4; -// return iree_ok_status(); -// case IREE_HAL_ELEMENT_TYPE_SINT_8: -// *xla_primitive = xla::PrimitiveType::S8; -// return iree_ok_status(); -// case IREE_HAL_ELEMENT_TYPE_SINT_16: -// *xla_primitive = xla::PrimitiveType::S16; -// return iree_ok_status(); -// case IREE_HAL_ELEMENT_TYPE_SINT_32: -// *xla_primitive = xla::PrimitiveType::S32; -// return iree_ok_status(); -// case IREE_HAL_ELEMENT_TYPE_SINT_64: -// *xla_primitive = xla::PrimitiveType::S64; -// return iree_ok_status(); -// case IREE_HAL_ELEMENT_TYPE_UINT_4: -// *xla_primitive = xla::PrimitiveType::U4; -// return iree_ok_status(); -// case IREE_HAL_ELEMENT_TYPE_UINT_8: -// *xla_primitive = xla::PrimitiveType::U8; -// return iree_ok_status(); -// case IREE_HAL_ELEMENT_TYPE_UINT_16: -// *xla_primitive = xla::PrimitiveType::U16; -// return iree_ok_status(); -// case IREE_HAL_ELEMENT_TYPE_UINT_32: -// *xla_primitive = xla::PrimitiveType::U32; -// return iree_ok_status(); -// case IREE_HAL_ELEMENT_TYPE_UINT_64: -// *xla_primitive = xla::PrimitiveType::U64; -// return iree_ok_status(); -// case IREE_HAL_ELEMENT_TYPE_FLOAT_16: -// *xla_primitive = xla::PrimitiveType::F16; -// return iree_ok_status(); -// case IREE_HAL_ELEMENT_TYPE_FLOAT_32: -// *xla_primitive = xla::PrimitiveType::F32; -// return iree_ok_status(); -// case IREE_HAL_ELEMENT_TYPE_FLOAT_64: -// *xla_primitive = xla::PrimitiveType::F64; -// return iree_ok_status(); -// case IREE_HAL_ELEMENT_TYPE_BFLOAT_16: -// *xla_primitive = xla::PrimitiveType::BF16; -// return iree_ok_status(); -// case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64: -// *xla_primitive = xla::PrimitiveType::C64; -// return iree_ok_status(); -// case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128: -// *xla_primitive = xla::PrimitiveType::C128; -// return iree_ok_status(); -// default: -// return iree_make_status(IREE_STATUS_UNIMPLEMENTED, -// "conversion from unknown element type 0x%x", -// (int)element_type); -// } -// } - iree_status_t MapBufferTypeToElementType( PJRT_Buffer_Type buffer_type, iree_hal_element_type_t* element_type) { switch (buffer_type) { @@ -281,65 +200,35 @@ const std::string& ErrorInstance::message() const { BufferInstance::~BufferInstance() = default; -// TODO: Excise. -// iree_status_t BufferInstance::GetXlaShape(xla::Shape** out_shape) { -// if (cached_shape_) { -// *out_shape = &(*cached_shape_); -// return iree_ok_status(); -// } - -// iree_hal_element_type_t hal_element_type = -// iree_hal_buffer_view_element_type(buffer_view()); -// xla::PrimitiveType xla_element_type; -// IREE_RETURN_IF_ERROR(PJRTApiConverter::MapElementTypeToXlaElementType( -// hal_element_type, &xla_element_type)); - -// size_t rank = iree_hal_buffer_view_shape_rank(buffer_view()); -// const iree_hal_dim_t* dims = -// iree_hal_buffer_view_shape_dims(buffer_view()); std::array -// xla_dims; if (rank > xla_dims.size()) { -// return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, -// "rank > 9 not supported"); -// } -// for (size_t i = 0; i < rank; ++i) { -// xla_dims[i] = dims[i]; -// } - -// cached_shape_ = xla::ShapeUtil::MakeShape( -// xla_element_type, -// absl::MakeSpan(xla_dims.begin(), xla_dims.begin() + rank)); -// *out_shape = &(*cached_shape_); -// return iree_ok_status(); -// } - -// TODO: Excise and convert directly to a C memory layout. -// iree_status_t BufferInstance::GetLayoutData( -// ::pjrt::BufferMemoryLayoutData** out_layout_data) { -// if (!cached_layout_data_) { -// xla::Shape* shape; -// IREE_RETURN_IF_ERROR(GetXlaShape(&shape)); -// if (!shape->has_layout()) { -// return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, -// "Buffer shape doesn't have a layout"); -// } -// xla::StatusOr<::pjrt::BufferMemoryLayoutData> status_or = -// ::pjrt::ConvertToBufferMemoryLayoutData(shape->layout()); -// if (!status_or.ok()) { -// return iree_make_status(IREE_STATUS_UNKNOWN, -// "Couldn't convert layout: %s", -// std::string(status_or.status().message()).data()); -// } -// cached_layout_data_.emplace(std::move(*status_or)); -// } -// *out_layout_data = &(*cached_layout_data_); -// return iree_ok_status(); -// } - BufferInstance::BufferInstance( DeviceInstance& device, iree::vm::ref buffer_view) : device_(device), buffer_view_(std::move(buffer_view)) { IREE_CHECK_OK(device.CreateFence(&ready_fence_)); IREE_CHECK_OK(device.CreateFence(&done_fence_)); + + // Cache the dims. + size_t rank = iree_hal_buffer_view_shape_rank(buffer_view_.get()); + const iree_hal_dim_t* dims = + iree_hal_buffer_view_shape_dims(buffer_view_.get()); + dims_.resize(rank); + for (size_t i = 0; i < rank; ++i) { + dims_[i] = dims[i]; + } +} + +void BufferInstance::ComputeLayout() { + iree_hal_encoding_type_t encoding = + iree_hal_buffer_view_encoding_type(buffer_view_.get()); + iree_hal_element_type_t element_type = + iree_hal_buffer_view_element_type(buffer_view_.get()); + + layout_.Reset(); + if (encoding == IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR && + iree_hal_element_is_byte_aligned(element_type)) { + // It is not documented, but PJRT only supports device buffers with a tiled + // layout. + layout_.InitializeDenseRowMajorTiled(dims_.size()); + } } void BufferInstance::BindApi(PJRT_Api* api) { @@ -353,40 +242,30 @@ void BufferInstance::BindApi(PJRT_Api* api) { api->PJRT_Buffer_ElementType = +[](PJRT_Buffer_ElementType_Args* args) -> PJRT_Error* { IREE_TRACE_SCOPE_NAMED("PJRT_Buffer_ElementType"); - // TODO: Excise. - // BufferInstance* buffer = BufferInstance::Unwrap(args->buffer); - // auto impl = [&]() -> iree_status_t { - // // xla::Shape* shape; - // // TODO: don't use XLA shape at all - // // https://github.com/openxla/openxla-pjrt-plugin/issues/265 - // IREE_RETURN_IF_ERROR(buffer->GetXlaShape(&shape)); - // args->type = static_cast(shape->element_type()); - // }; - // return MakeError(impl()); - return MakeError( - iree_make_status(IREE_STATUS_UNIMPLEMENTED, "PJRT_Buffer_ElementType")); + BufferInstance* buffer = BufferInstance::Unwrap(args->buffer); + auto element_type = buffer->element_type(); + if (!element_type) { + return MakeError(iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "Unsupported PJRT buffer type")); + } + args->type = *element_type; + return nullptr; }; api->PJRT_Buffer_Dimensions = +[](PJRT_Buffer_Dimensions_Args* args) -> PJRT_Error* { IREE_TRACE_SCOPE_NAMED("PJRT_Buffer_Dimensions"); - // auto impl = [&]() -> iree_status_t { - // BufferInstance* buffer = BufferInstance::Unwrap(args->buffer); - // xla::Shape* shape; - // // TODO: don't use XLA shape at all - // // https://github.com/openxla/openxla-pjrt-plugin/issues/265 - // IREE_RETURN_IF_ERROR(buffer->GetXlaShape(&shape)); - // args->dims = shape->dimensions().data(); - // args->num_dims = shape->dimensions().size(); - // return nullptr; - // }; - return MakeError( - iree_make_status(IREE_STATUS_UNIMPLEMENTED, "PJRT_Buffer_Dimensions")); + BufferInstance* buffer = BufferInstance::Unwrap(args->buffer); + args->dims = buffer->dims(); + args->num_dims = buffer->num_dims(); + return nullptr; }; api->PJRT_Buffer_UnpaddedDimensions = +[](PJRT_Buffer_UnpaddedDimensions_Args* args) -> PJRT_Error* { IREE_TRACE_SCOPE_NAMED("PJRT_Buffer_UnpaddedDimensions"); - return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED, - "PJRT_Buffer_UnpaddedDimensions")); + BufferInstance* buffer = BufferInstance::Unwrap(args->buffer); + args->unpadded_dims = buffer->dims(); + args->num_dims = buffer->num_dims(); + return nullptr; }; api->PJRT_Buffer_DynamicDimensionIndices = +[](PJRT_Buffer_DynamicDimensionIndices_Args* args) -> PJRT_Error* { @@ -397,18 +276,15 @@ void BufferInstance::BindApi(PJRT_Api* api) { api->PJRT_Buffer_GetMemoryLayout = +[](PJRT_Buffer_GetMemoryLayout_Args* args) -> PJRT_Error* { IREE_TRACE_SCOPE_NAMED("PJRT_Buffer_GetMemoryLayout"); - // auto impl = [&]() -> iree_status_t { - // // TODO: Populate the C layout data directly from the buffer - // // instance. - // BufferInstance* buffer = BufferInstance::Unwrap(args->buffer); - // ::pjrt::BufferMemoryLayoutData* layout_data; - // IREE_RETURN_IF_ERROR(buffer->GetLayoutData(&layout_data)); - // args->layout = layout_data->c_layout; - // return nullptr; - // }; - // return MakeError(impl()); - return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED, - "PJRT_Buffer_GetMemoryLayout")); + BufferInstance* buffer = BufferInstance::Unwrap(args->buffer); + const PJRT_Buffer_MemoryLayout* layout = buffer->layout(); + if (!layout) { + return MakeError( + iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "Unsupported PJRT layout for buffer view")); + } + args->layout = *layout; + return nullptr; }; api->PJRT_Buffer_ToHostBuffer = +[](PJRT_Buffer_ToHostBuffer_Args* args) -> PJRT_Error* { @@ -659,6 +535,62 @@ iree_status_t BufferInstance::AdvanceDoneFence(iree_hal_semaphore_t* semaphore, return IreeApi::hal_fence_insert(done_fence_.get(), semaphore, timepoint); } +std::optional BufferInstance::element_type() { + iree_hal_element_type_t hal_element_type = + iree_hal_buffer_view_element_type(buffer_view()); + + // TODO: Cascade on bit-field sub-types to avoid large linear scan. + switch (hal_element_type) { + // TODO: How do I interpret signless? + case IREE_HAL_ELEMENT_TYPE_BOOL_8: + return PJRT_Buffer_Type_PRED; + case IREE_HAL_ELEMENT_TYPE_INT_4: + return PJRT_Buffer_Type_S4; + case IREE_HAL_ELEMENT_TYPE_INT_8: + return PJRT_Buffer_Type_S8; + case IREE_HAL_ELEMENT_TYPE_INT_16: + return PJRT_Buffer_Type_S16; + case IREE_HAL_ELEMENT_TYPE_INT_32: + return PJRT_Buffer_Type_S32; + case IREE_HAL_ELEMENT_TYPE_INT_64: + return PJRT_Buffer_Type_S64; + case IREE_HAL_ELEMENT_TYPE_SINT_4: + return PJRT_Buffer_Type_S4; + case IREE_HAL_ELEMENT_TYPE_SINT_8: + return PJRT_Buffer_Type_S8; + case IREE_HAL_ELEMENT_TYPE_SINT_16: + return PJRT_Buffer_Type_S16; + case IREE_HAL_ELEMENT_TYPE_SINT_32: + return PJRT_Buffer_Type_S32; + case IREE_HAL_ELEMENT_TYPE_SINT_64: + return PJRT_Buffer_Type_S64; + case IREE_HAL_ELEMENT_TYPE_UINT_4: + return PJRT_Buffer_Type_U4; + case IREE_HAL_ELEMENT_TYPE_UINT_8: + return PJRT_Buffer_Type_U8; + case IREE_HAL_ELEMENT_TYPE_UINT_16: + return PJRT_Buffer_Type_U16; + case IREE_HAL_ELEMENT_TYPE_UINT_32: + return PJRT_Buffer_Type_U32; + case IREE_HAL_ELEMENT_TYPE_UINT_64: + return PJRT_Buffer_Type_U64; + case IREE_HAL_ELEMENT_TYPE_FLOAT_16: + return PJRT_Buffer_Type_F16; + case IREE_HAL_ELEMENT_TYPE_FLOAT_32: + return PJRT_Buffer_Type_F32; + case IREE_HAL_ELEMENT_TYPE_FLOAT_64: + return PJRT_Buffer_Type_F64; + case IREE_HAL_ELEMENT_TYPE_BFLOAT_16: + return PJRT_Buffer_Type_BF16; + case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64: + return PJRT_Buffer_Type_C64; + case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128: + return PJRT_Buffer_Type_C128; + default: + return {}; + } +} + //===----------------------------------------------------------------------===// // DeviceDescription //===----------------------------------------------------------------------===// diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.h b/integrations/pjrt/src/iree_pjrt/common/api_impl.h index 3edcc2213e20..d56a0bd3de1a 100644 --- a/integrations/pjrt/src/iree_pjrt/common/api_impl.h +++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -21,14 +22,9 @@ #include "iree/vm/api.h" #include "iree/vm/bytecode/module.h" #include "iree_pjrt/common/compiler.h" +#include "iree_pjrt/common/layout_utils.h" #include "iree_pjrt/common/platform.h" #include "xla/pjrt/c/pjrt_c_api.h" -// TODO: Excise. Various deep dependencies on XLA internals. -// #include "xla/pjrt/c/pjrt_c_api_helpers.h" -// TODO: Excise. Various deep dependencies on XLA internals. -// #include "xla/pjrt/pjrt_executable.h" -// TODO: Excise. Various deep dependencies on XLA internals. -// #include "xla/shape_util.h" namespace iree::pjrt { @@ -95,11 +91,6 @@ class BufferInstance { // the hook to get an unsafe pointer (avoids a copy). return false; } - // TODO: Excise. - // iree_status_t GetXlaShape(xla::Shape** out_shape); - // TODO: Excise. - // iree_status_t GetLayoutData(::pjrt::BufferMemoryLayoutData** - // out_layout_data); // Gets the required host size in bytes to copy to host. iree_status_t GetHostSizeInBytes(iree_host_size_t* host_size); @@ -115,14 +106,25 @@ class BufferInstance { iree_hal_fence_t* ready_fence() { return ready_fence_.get(); } iree_hal_fence_t* done_fence() { return done_fence_.get(); } + const int64_t* dims() { return dims_.data(); } + size_t num_dims() { return dims_.size(); } + std::optional element_type(); + const PJRT_Buffer_MemoryLayout* layout() { + if (!layout_.is_valid()) { + ComputeLayout(); + } + if (layout_.is_valid()) { + return &layout_.c_layout(); + } else { + return nullptr; + } + } + private: + void ComputeLayout(); + DeviceInstance& device_; iree::vm::ref buffer_view_; - // Various things require XLA's idea of shapes, layouts, etc. - // We keep one around for such cases. - // TODO: Excise. - // std::optional cached_shape_; - // std::optional<::pjrt::BufferMemoryLayoutData> cached_layout_data_; // When the buffer resource gets freed, this is set to true. bool is_deleted_ = false; // Fences. @@ -132,6 +134,10 @@ class BufferInstance { // Consumers should advance this fence when using it. iree::vm::ref ready_fence_; iree::vm::ref done_fence_; + + // API elements that must have the same lifetime as BufferInstance. + std::vector dims_; + ApiMemoryLayout layout_; }; //===----------------------------------------------------------------------===// diff --git a/integrations/pjrt/src/iree_pjrt/common/layout_utils.cc b/integrations/pjrt/src/iree_pjrt/common/layout_utils.cc new file mode 100644 index 000000000000..cc5c5828fdab --- /dev/null +++ b/integrations/pjrt/src/iree_pjrt/common/layout_utils.cc @@ -0,0 +1,47 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree_pjrt/common/layout_utils.h" + +#include + +namespace iree::pjrt { + +void ApiMemoryLayout::InitializeDenseRowMajorStrided(size_t rank, + const int64_t *dims, + size_t unit_stride_bytes) { + memset(&c_layout_, 0, sizeof(c_layout_)); + int64_t stride = unit_stride_bytes; + storage1_.resize(rank); + for (size_t pos = 0; pos < rank; ++pos) { + storage1_[rank - pos - 1] = stride; + stride *= dims[pos]; + } + + c_layout_.struct_size = sizeof(c_layout_); + c_layout_.type = PJRT_Buffer_MemoryLayout_Type_Strides; + c_layout_.strides.struct_size = sizeof(c_layout_.strides); + c_layout_.strides.byte_strides = storage1_.data(); + c_layout_.strides.num_byte_strides = storage1_.size(); + valid_ = true; +} + +void ApiMemoryLayout::InitializeDenseRowMajorTiled(int64_t rank) { + memset(&c_layout_, 0, sizeof(c_layout_)); + // Set minor_to_major. See SetDefaultLayoutToContainer in LayoutUtil.h + storage1_.resize(rank, 0); + for (int64_t i = 0; i < rank; ++i) { + storage1_[i] = rank - 1 - i; + } + c_layout_.struct_size = sizeof(c_layout_); + c_layout_.type = PJRT_Buffer_MemoryLayout_Type_Tiled; + c_layout_.tiled.struct_size = sizeof(c_layout_.tiled); + c_layout_.tiled.minor_to_major = storage1_.data(); + c_layout_.tiled.minor_to_major_size = storage1_.size(); + valid_ = true; +} + +} // namespace iree::pjrt diff --git a/integrations/pjrt/src/iree_pjrt/common/layout_utils.h b/integrations/pjrt/src/iree_pjrt/common/layout_utils.h new file mode 100644 index 000000000000..8402807888b4 --- /dev/null +++ b/integrations/pjrt/src/iree_pjrt/common/layout_utils.h @@ -0,0 +1,36 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include "iree/hal/buffer_view.h" +#include "xla/pjrt/c/pjrt_c_api.h" + +namespace iree::pjrt { + +class ApiMemoryLayout { + public: + ApiMemoryLayout() = default; + + void InitializeDenseRowMajorStrided(size_t rank, const int64_t *dims, + size_t unit_stride_bytes); + void InitializeDenseRowMajorTiled(int64_t rank); + void Reset() { valid_ = false; } + + bool is_valid() const { return valid_; } + const PJRT_Buffer_MemoryLayout &c_layout() const { return c_layout_; }; + + private: + PJRT_Buffer_MemoryLayout c_layout_; + + // Retained vector of ints. + std::vector storage1_; + + bool valid_ = false; +}; + +} // namespace iree::pjrt