[XLA:GPU] Fix bug in slice matching logic for dynamic-slice-fusions.
This should resolve the issue at jax-ml/jax#23854.
I took advantage of working on this fix to try to document the expectations and
logic a little more.

PiperOrigin-RevId: 687271334
bchetioui authored and tensorflower-gardener committed Oct 18, 2024
1 parent b4d04a9 commit 5cb7826
Showing 3 changed files with 221 additions and 12 deletions.
135 changes: 123 additions & 12 deletions third_party/xla/xla/service/gpu/fusions/custom.cc
@@ -16,6 +16,7 @@ limitations under the License.

#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include <optional>
#include <string>
@@ -27,6 +28,7 @@ limitations under the License.
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/AsmParser/AsmParser.h"
@@ -251,26 +253,86 @@ absl::Status CollectSliceInfo(
return absl::OkStatus();
}

// This function assumes that the computation graph for `fusion_instr` looks
// like:
//
//   ...
//   root_tuple_operand = (... ty[shape], ...) ...
//   ROOT root_tuple = (... (... ty[shape], ...), ...)
//                       tuple(... root_tuple_operand, ...)
//
// Given such a pattern and a (complete) index into `root_tuple_operand`, we
// recover the slice of `root_tuple` that corresponds to that index.
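// For example, if `root_tuple_operand` is operand #1 of `root_tuple` and the
// index into it is {0}, the returned slice is the one assigned to ShapeIndex
// {1, 0} of the fusion's result.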
absl::StatusOr<BufferAllocation::Slice> GetResultSliceForPartiallyUnnestedTuple(
const BufferAssignment& buffer_assignment,
const HloFusionInstruction& fusion_instr,
const HloInstruction& root_tuple_operand,
const ShapeIndex& root_tuple_operand_shape_idx,
const HloInstruction& root_tuple) {
int64_t operand_index = root_tuple.operand_index(&root_tuple_operand);
ShapeIndex slice_shape_index;
slice_shape_index.push_back(operand_index);
absl::c_copy(root_tuple_operand_shape_idx,
std::back_inserter(slice_shape_index));
return GetAllocationSlice(buffer_assignment, &fusion_instr,
slice_shape_index);
}

absl::StatusOr<BufferAllocation::Slice> GetResultSlice(
const BufferAssignment& buffer_assignment, const HloFusionAdaptor& adaptor,
-const HloInstruction& fusion_instr, const HloInstruction& start_instr,
const HloFusionInstruction& fusion_instr, const HloInstruction& start_instr,
std::vector<HloInstruction*>& slice_instrs, const ShapeIndex& shape_idx,
unsigned arg_idx) {
auto* start = const_cast<HloInstruction*>(&start_instr);
if (start->IsRoot()) {
return GetAllocationSlice(buffer_assignment, &fusion_instr, shape_idx);
}

// Walk through ShapeIndex to find the real "user" (i.e. not get-tuple-element
// user). Otherwise one sliced element will mark all buffers of all other
// elements "sliced" too.
if (start->shape().IsTuple()) {
-for (auto idx : shape_idx) {
-  std::vector<HloGetTupleElementInstruction*> gte_users(
-      start->shape().tuple_shapes_size(), nullptr);
-  for (auto* user : start->users())
-    if (auto* gte = DynCast<HloGetTupleElementInstruction>(user))
-      gte_users[gte->tuple_index()] = gte;

-  start = static_cast<HloInstruction*>(gte_users[idx]);
-  if (start == nullptr)
-    return GetAllocationSlice(buffer_assignment, &fusion_instr, shape_idx);
for (auto [index_nesting_level, index_in_shape] :
llvm::enumerate(shape_idx)) {
HloInstruction* gte_user = nullptr;
for (auto* user : start->users()) {
if (user->opcode() == HloOpcode::kGetTupleElement &&
user->tuple_index() == index_in_shape) {
gte_user = user;
break;
}
}

if (gte_user == nullptr) {
// At this point, two things are known:
// 1. `start` was not the root instruction of the fusion at the
// beginning of this function call;
// 2. `start` still has a tuple shape because we haven't managed to
// unwrap the entire shape index.
// We also know, by definition of the surrounding pass, that all the
// results of the custom call must be materialized at the output of
// the fusion, which indicates that `start` is currently *not* the
// root. Since we can't slice/bitcast/reshape a tuple, then the
// only possible consumer should be a `tuple` instruction, which
// logically should be the root of the fusion.
HloInstruction* start_user = start->users().front();
if (start->user_count() != 1 ||
start_user->opcode() != HloOpcode::kTuple ||
!start_user->IsRoot()) {
return absl::InternalError(
"Expected the user of a nested tuple shape to be a root tuple "
"instruction."
"Expected a single user of the tuple-shaped instruction");
}

ShapeIndex remaining_shape_index(
shape_idx.begin() + index_nesting_level, shape_idx.end());
return GetResultSliceForPartiallyUnnestedTuple(
buffer_assignment, fusion_instr, *start, remaining_shape_index,
*start_user);
}

start = gte_user;
}
}

@@ -296,7 +358,56 @@ absl::StatusOr<BufferAllocation::Slice> GetResultSlice(
}
}

-return GetAllocationSlice(buffer_assignment, &fusion_instr, shape_idx);
constexpr absl::string_view kNonContiguousDynamicUpdateSliceError =
"DynamicSliceFusion only handles contiguous slices currently";

// At this point, we've fully unfolded a tuple that was not the root of the
// computation. There are two options: either the root is a tuple, or it is
// not.
//
// If the root is not a tuple, we can simply get the buffer slice assigned to
// the fusion itself---there is nothing else to choose from.
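// For example, if the fusion root is
//   ROOT %dus = f16[8,8] dynamic-update-slice(...)
// with a contiguous update, the returned slice is simply the buffer assigned
// to the fusion at the empty ShapeIndex {}.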
if (fusion_instr.shape().IsArray()) {
HloInstruction* root = fusion_instr.fused_expression_root();
if (root->opcode() == HloOpcode::kDynamicUpdateSlice &&
!IsContiguousSlice(*root)) {
return absl::InternalError(kNonContiguousDynamicUpdateSliceError);
}
return GetAllocationSlice(buffer_assignment, &fusion_instr, {});
}

// If the root is a tuple, however, it may be a nested tuple. Go all the way
// to the root to figure out the index that our array occupies within that
// tuple.
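// For example, in the nested-tuple test added with this change, the workspace
// feeds a nested tuple as operand 0, which in turn feeds the root tuple as
// operand 1; the collected indices {0, 1} reverse to the ShapeIndex {1, 0}.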
HloInstruction* current_hlo = start;
std::vector<int64_t> reversed_shape_index;
do {
TF_RET_CHECK(current_hlo->user_count() == 1);
HloInstruction* user = current_hlo->users().front();
// We may encounter three ops here: dynamic-update-slice, tuple, or bitcast.
switch (user->opcode()) {
case HloOpcode::kBitcast:
break;
case HloOpcode::kDynamicUpdateSlice:
if (!IsContiguousSlice(*user)) {
return absl::InternalError(kNonContiguousDynamicUpdateSliceError);
}
break;
case HloOpcode::kTuple:
reversed_shape_index.push_back(user->operand_index(current_hlo));
break;
default:
return absl::InternalError(
absl::StrCat("Unexpected opcode while processing the epilogue of a "
"DynamicSliceFusion: ",
HloOpcodeString(user->opcode())));
};
current_hlo = user;
} while (!current_hlo->IsRoot());

return GetAllocationSlice(
buffer_assignment, &fusion_instr,
ShapeIndex(reversed_shape_index.rbegin(), reversed_shape_index.rend()));
}

absl::StatusOr<FusionEmissionResult> EmitGemm(
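To make the new epilogue walk concrete, here is a small, self-contained C++ sketch of the index bookkeeping only. It is illustrative and uses toy stand-ins (`Node`, `Op`, `RootRelativeIndex`) rather than the real `HloInstruction` and `ShapeIndex` types: follow the unique chain of users from a custom-call result up to the fusion root, record the operand position at every tuple along the way, and reverse the collected indices to obtain the root-relative shape index.

// Toy sketch of the root-relative index computation; `Node`, `Op`, and
// `RootRelativeIndex` are illustrative stand-ins, not XLA types.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

enum class Op { kOther, kTuple };

struct Node {
  Op op;
  Node* user = nullptr;               // assumes a single user, like the pass
  std::vector<const Node*> operands;  // only populated for kTuple nodes

  int64_t OperandIndex(const Node* operand) const {
    return std::find(operands.begin(), operands.end(), operand) -
           operands.begin();
  }
};

// Follows the unique user chain from `start` to the root, recording the
// operand position at every tuple, then reverses the result so that the
// outermost tuple index comes first.
std::vector<int64_t> RootRelativeIndex(const Node* start) {
  std::vector<int64_t> reversed;
  for (const Node* current = start; current->user != nullptr;
       current = current->user) {
    if (current->user->op == Op::kTuple) {
      reversed.push_back(current->user->OperandIndex(current));
    }
  }
  std::reverse(reversed.begin(), reversed.end());
  return reversed;
}

int main() {
  // Mirrors the nested-tuple test below: workspace -> nested tuple
  // (operand 0) -> root tuple (operand 1), i.e. root index {1, 0}.
  Node workspace{Op::kOther};
  Node result{Op::kOther};
  Node nested{Op::kTuple};
  Node root{Op::kTuple};
  nested.operands = {&workspace};
  root.operands = {&result, &nested};
  workspace.user = &nested;
  result.user = &root;
  nested.user = &root;
  assert((RootRelativeIndex(&workspace) == std::vector<int64_t>{1, 0}));
  assert((RootRelativeIndex(&result) == std::vector<int64_t>{0}));
  return 0;
}

The real implementation additionally checks, at every step of the walk, that any dynamic-update-slice user is contiguous and rejects opcodes other than bitcast, dynamic-update-slice, and tuple.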
10 changes: 10 additions & 0 deletions third_party/xla/xla/service/gpu/fusions/custom.h
@@ -50,6 +50,16 @@ class CustomFusion : public FusionInterface {
// compile-time instead of allocating a new buffer for it at runtime by
// translating the static slice into offset + size of the original buffer passed
// into the custom call `%gemm`.
//
// It is possible to embed the results of the custom call within a larger
// array. In that case, the affected results are each fed into a
// `dynamic-update-slice` operation, whose result is one of the fusion's
// outputs.
//
// The pass assumes that, for each one of the custom call's outputs, there is
// exactly one path to the fusion root. The resulting shape for the
// dynamic slice fusion may be an unwrapped array, a flat tuple, or even a
// nested tuple.
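// For example, a fusion wrapping a cuBLAS GEMM with a workspace may return a
// nested tuple such as (f16[8,8], (s8[256])), where the first element is the
// GEMM result and the inner tuple wraps the workspace buffer.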
class DynamicSliceFusion : public FusionInterface {
public:
explicit DynamicSliceFusion(const HloFusionAnalysis& analysis)
@@ -282,6 +282,94 @@ TEST_F(DynamicSliceFusionTest, CublasGemmWithWorkspace) {
/*run_hlo_passes=*/false));
}

TEST_F(DynamicSliceFusionTest, NestedTupleOutputForCublasGemmWithWorkspace) {
ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3};

const char* hlo_ref = R"(
HloModule nested_tuple
ENTRY main {
p0 = f16[2,8,8]{2,1,0} parameter(0)
p1 = f16[2,8,8]{2,1,0} parameter(1)
slice_1 = f16[1,8,8]{2,1,0} slice(p0), slice={[1:2], [0:8], [0:8]}
bitcast_1 = f16[8,8]{1,0} bitcast(slice_1)
slice_2 = f16[1,8,8]{2,1,0} slice(p1), slice={[1:2], [0:8], [0:8]}
bitcast_2 = f16[8,8]{1,0} bitcast(slice_2)
custom-call = (f16[8,8]{1,0}, s8[256]{0}) custom-call(bitcast_1, bitcast_2),
custom_call_target="__cublas$gemm",
backend_config={"gemm_backend_config":{
"alpha_real":1,
"beta":0,
"dot_dimension_numbers":{
"lhs_contracting_dimensions":["1"],
"rhs_contracting_dimensions":["0"],
"lhs_batch_dimensions":[],
"rhs_batch_dimensions":[]
},
"alpha_imag":0,
"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
"epilogue":"DEFAULT",
"lhs_stride":"64",
"rhs_stride":"64",
"grad_x":false,
"grad_y":false
}}
result = f16[8,8]{1,0} get-tuple-element(custom-call), index=0
workspace = s8[256]{0} get-tuple-element(custom-call), index=1
nested_tuple = (s8[256]{0}) tuple(workspace)
ROOT tuple = (f16[8,8]{1,0}, (s8[256]{0})) tuple(result, nested_tuple)
})";

const char* hlo_opt = R"(
HloModule jit_slice
fused_computation {
p0 = f16[2,8,8]{2,1,0} parameter(0)
p1 = f16[2,8,8]{2,1,0} parameter(1)
slice_1 = f16[1,8,8]{2,1,0} slice(p0), slice={[1:2], [0:8], [0:8]}
bitcast_1 = f16[8,8]{1,0} bitcast(slice_1)
slice_2 = f16[1,8,8]{2,1,0} slice(p1), slice={[1:2], [0:8], [0:8]}
bitcast_2 = f16[8,8]{1,0} bitcast(slice_2)
custom-call = (f16[8,8]{1,0}, s8[256]{0}) custom-call(bitcast_1, bitcast_2),
custom_call_target="__cublas$gemm",
backend_config={"gemm_backend_config":{
"alpha_real":1,
"beta":0,
"dot_dimension_numbers":{
"lhs_contracting_dimensions":["1"],
"rhs_contracting_dimensions":["0"],
"lhs_batch_dimensions":[],
"rhs_batch_dimensions":[]
},
"alpha_imag":0,
"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
"epilogue":"DEFAULT",
"lhs_stride":"64",
"rhs_stride":"64",
"grad_x":false,
"grad_y":false
}}
result = f16[8,8]{1,0} get-tuple-element(custom-call), index=0
workspace = s8[256]{0} get-tuple-element(custom-call), index=1
nested_tuple = (s8[256]{0}) tuple(workspace)
ROOT tuple = (f16[8,8]{1,0}, (s8[256]{0})) tuple(result, nested_tuple)
}
ENTRY main.9 {
p0 = f16[2,8,8]{2,1,0} parameter(0)
p1 = f16[2,8,8]{2,1,0} parameter(1)
ROOT fusion = (f16[8,8]{1,0}, (s8[256]{0})) fusion(p0, p1), kind=kCustom, calls=fused_computation,
backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}}
})";

EXPECT_TRUE(RunAndCompareTwoModules(
hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(),
GetModuleConfigWithDeterministicOps(), error_spec,
/*run_hlo_passes=*/false));
}

TEST_F(DynamicSliceFusionTest, ContiguousSlice) {
ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3};

