Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Atomic Load and Store operations for Triton (tl.atomic_store/tl.atomic_load) #5187

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ def TT_LoadOp : TT_Op<"load", [

DefaultValuedAttr<DenseI32ArrayAttr, "::llvm::ArrayRef<int32_t>{}">:$boundaryCheck,
OptionalAttr<TT_PaddingOptionAttr>:$padding,
OptionalAttr<TT_MemSemanticAttr>:$sem, OptionalAttr<TT_MemSyncScopeAttr>:$scope,
DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
DefaultValuedAttr<TT_EvictionPolicyAttr, "::mlir::triton::EvictionPolicy::NORMAL">:$evict,
DefaultValuedAttr<BoolAttr, "false">:$isVolatile
Expand All @@ -263,22 +264,35 @@ def TT_LoadOp : TT_Op<"load", [

let builders = [
// A tensor of pointers or a pointer to a scalar
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
OpBuilder<(ins "Value":$ptr,
"std::optional<triton::MemSemantic>":$sem,
"std::optional<triton::MemSyncScope>":$scope,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
// A tensor pointer with boundary check and padding
OpBuilder<(ins "Value":$ptr, "ArrayRef<int32_t>":$boundaryCheck,
"std::optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
"std::optional<triton::PaddingOption>":$padding,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
// A tensor of pointers or a pointer to a scalar with mask
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
OpBuilder<(ins "Value":$ptr, "Value":$mask,
"std::optional<triton::MemSemantic>":$sem,
"std::optional<triton::MemSyncScope>":$scope,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
// A tensor of pointers or a pointer to a scalar with mask and other
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache,
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other,
"std::optional<triton::MemSemantic>":$sem,
"std::optional<triton::MemSyncScope>":$scope,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
// A utility function to build the operation with all attributes
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other,
"ArrayRef<int32_t>":$boundaryCheck,
"std::optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
"std::optional<triton::PaddingOption>":$padding,
"std::optional<triton::MemSemantic>":$sem,
"std::optional<triton::MemSyncScope>":$scope,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>
];

Expand All @@ -299,7 +313,9 @@ def TT_LoadOp : TT_Op<"load", [
$ptr (`,` $mask^)? (`,` $other^)?
oilist(
`cacheModifier` `=` $cache |
`evictionPolicy` `=` $evict
`evictionPolicy` `=` $evict |
`memSemantic` `=` $sem |
`memSyncScope` `=` $scope
)
attr-dict `:` type($ptr)
}];
Expand All @@ -325,18 +341,26 @@ def TT_StoreOp : TT_Op<"store", [
TT_Type:$value,
Optional<TT_BoolLike>:$mask,
DefaultValuedAttr<DenseI32ArrayAttr, "::llvm::ArrayRef<int32_t>{}">:$boundaryCheck,
OptionalAttr<TT_MemSemanticAttr>:$sem, OptionalAttr<TT_MemSyncScopeAttr>:$scope,
DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache,
DefaultValuedAttr<TT_EvictionPolicyAttr, "triton::EvictionPolicy::NORMAL">:$evict
);

let builders = [
// A tensor of pointers or a pointer to a scalar
OpBuilder<(ins "Value":$ptr, "Value":$value, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict)>,
OpBuilder<(ins "Value":$ptr, "Value":$value,
"std::optional<triton::MemSemantic>":$sem,
"std::optional<triton::MemSyncScope>":$scope,
"triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict)>,
// A tensor of pointers or a pointer to a scalar with mask
OpBuilder<(ins "Value":$ptr, "Value":$value, "Value":$mask, "triton::CacheModifier":$cache,
OpBuilder<(ins "Value":$ptr, "Value":$value, "Value":$mask,
"std::optional<triton::MemSemantic>":$sem,
"std::optional<triton::MemSyncScope>":$scope,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict)>,
// A tensor pointer with boundary check
OpBuilder<(ins "Value":$ptr, "Value":$value, "ArrayRef<int32_t>":$boundaryCheck, "triton::CacheModifier":$cache,
OpBuilder<(ins "Value":$ptr, "Value":$value, "ArrayRef<int32_t>":$boundaryCheck,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict)>
];

Expand All @@ -348,7 +372,9 @@ def TT_StoreOp : TT_Op<"store", [
// due to limitations in MLIR's asm parser.
let assemblyFormat = [{
$ptr `,` $value (`,` $mask^)?
oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict)
oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict |
`memSemantic` `=` $sem | `memSyncScope` `=` $scope
)
attr-dict `:` type($ptr)
}];

Expand Down
75 changes: 55 additions & 20 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,47 +37,61 @@ namespace triton {

//-- LoadOp --
void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
CacheModifier cache, EvictionPolicy evict, bool isVolatile) {
std::optional<MemSemantic> sem,
std::optional<MemSyncScope> scope, CacheModifier cache,
EvictionPolicy evict, bool isVolatile) {
LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{},
/*boundaryCheck=*/ArrayRef<int32_t>{}, /*padding=*/std::nullopt,
cache, evict, isVolatile);
sem, scope, cache, evict, isVolatile);
}

void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
ArrayRef<int32_t> boundaryCheck,
std::optional<PaddingOption> padding, CacheModifier cache,
EvictionPolicy evict, bool isVolatile) {
LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, boundaryCheck,
padding, cache, evict, isVolatile);
padding, std::nullopt, std::nullopt, cache, evict, isVolatile);
}

void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
Value mask, CacheModifier cache, EvictionPolicy evict,
bool isVolatile) {
Value mask, std::optional<MemSemantic> sem,
std::optional<MemSyncScope> scope, CacheModifier cache,
EvictionPolicy evict, bool isVolatile) {
LoadOp::build(builder, state, ptr, mask, /*other=*/{},
/*boundaryCheck=*/ArrayRef<int32_t>{},
/*padding=*/std::nullopt, cache, evict, isVolatile);
/*padding=*/std::nullopt, sem, scope, cache, evict, isVolatile);
}

void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
Value mask, Value other, CacheModifier cache,
Value mask, Value other, std::optional<MemSemantic> sem,
std::optional<MemSyncScope> scope, CacheModifier cache,
EvictionPolicy evict, bool isVolatile) {
LoadOp::build(builder, state, ptr, mask, other,
/*boundaryCheck=*/ArrayRef<int32_t>{},
/*padding=*/std::nullopt, cache, evict, isVolatile);
/*padding=*/std::nullopt, sem, scope, cache, evict, isVolatile);
}

void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
Value mask, Value other, ArrayRef<int32_t> boundaryCheck,
std::optional<PaddingOption> padding, CacheModifier cache,
std::optional<PaddingOption> padding,
std::optional<MemSemantic> sem,
std::optional<MemSyncScope> scope, CacheModifier cache,
EvictionPolicy evict, bool isVolatile) {
auto paddingAttr =
padding.has_value()
? PaddingOptionAttr::get(builder.getContext(), padding.value())
: PaddingOptionAttr();
auto semAttr = sem.has_value()
? MemSemanticAttr::get(builder.getContext(), sem.value())
: MemSemanticAttr();
auto scopeAttr =
scope.has_value()
? MemSyncScopeAttr::get(builder.getContext(), scope.value())
: MemSyncScopeAttr();

LoadOp::build(builder, state, ptr, mask, other,
builder.getDenseI32ArrayAttr(boundaryCheck), paddingAttr, cache,
evict, isVolatile);
builder.getDenseI32ArrayAttr(boundaryCheck), paddingAttr,
semAttr, scopeAttr, cache, evict, isVolatile);
}

// load(ptr, splat(1), ...) -> load(ptr, ...)
Expand Down Expand Up @@ -105,7 +119,8 @@ struct CanonicalizeMaskedLoadPattern : public OpRewritePattern<LoadOp> {
rewriter.replaceOpWithNewOp<LoadOp>(
loadOp, loadOp.getType(), loadOp.getPtr(), Value(), Value(),
loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(),
loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
loadOp.getSemAttr(), loadOp.getScopeAttr(), loadOp.getCache(),
loadOp.getEvict(), loadOp.getIsVolatile());
} else {
// mask = splat(0)

Expand All @@ -127,24 +142,44 @@ void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,

//-- StoreOp --
void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr,
Value value, CacheModifier cache, EvictionPolicy evict) {
Value value, std::optional<MemSemantic> sem,
std::optional<MemSyncScope> scope, CacheModifier cache,
EvictionPolicy evict) {
auto semAttr = sem.has_value()
? MemSemanticAttr::get(builder.getContext(), sem.value())
: MemSemanticAttr();
auto scopeAttr =
scope.has_value()
? MemSyncScopeAttr::get(builder.getContext(), scope.value())
: MemSyncScopeAttr();
return StoreOp::build(builder, state, ptr, value, /*mask=*/{},
/*boundaryCheck=*/{}, cache, evict);
/*boundaryCheck=*/{}, semAttr, scopeAttr, cache, evict);
}

void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr,
Value value, Value mask, CacheModifier cache,
Value value, Value mask, std::optional<MemSemantic> sem,
std::optional<MemSyncScope> scope, CacheModifier cache,
EvictionPolicy evict) {
auto semAttr = sem.has_value()
? MemSemanticAttr::get(builder.getContext(), sem.value())
: MemSemanticAttr();
auto scopeAttr =
scope.has_value()
? MemSyncScopeAttr::get(builder.getContext(), scope.value())
: MemSyncScopeAttr();

return StoreOp::build(builder, state, ptr, value, mask, /*boundaryCheck=*/{},
cache, evict);
semAttr, scopeAttr, cache, evict);
}

void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr,
Value value, ArrayRef<int32_t> boundaryCheck,
CacheModifier cache, EvictionPolicy evict) {
auto semAttr = MemSemanticAttr();
auto scopeAttr = MemSyncScopeAttr();
return StoreOp::build(builder, state, ptr, value, /*mask=*/{},
builder.getDenseI32ArrayAttr(boundaryCheck), cache,
evict);
builder.getDenseI32ArrayAttr(boundaryCheck), semAttr,
scopeAttr, cache, evict);
}

// store(ptr, value, splat(1), ...) -> store(ptr, value, ...)
Expand All @@ -170,8 +205,8 @@ struct CanonicalizeMaskedStorePattern : public OpRewritePattern<StoreOp> {
if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
// mask = splat(1)
rewriter.replaceOpWithNewOp<StoreOp>(
storeOp, storeOp.getPtr(), storeOp.getValue(), storeOp.getCache(),
storeOp.getEvict());
storeOp, storeOp.getPtr(), storeOp.getValue(), storeOp.getSem(),
storeOp.getScope(), storeOp.getCache(), storeOp.getEvict());
} else {
// mask = splat(0)
rewriter.eraseOp(storeOp);
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/Triton/Transforms/Combine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ class CombineSelectMaskedLoadPattern : public RewritePattern {
rewriter.replaceOpWithNewOp<LoadOp>(
op, loadOp.getPtr(), loadOp.getMask(), /*other=*/falseValue,
loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(),
loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
loadOp.getSemAttr(), loadOp.getScopeAttr(), loadOp.getCache(),
loadOp.getEvict(), loadOp.getIsVolatile());
return success();
}
};
Expand Down
6 changes: 4 additions & 2 deletions lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,12 +311,14 @@ class RewriteTensorPointerPass
// Create a new operation
if (auto loadOp = dyn_cast<triton::LoadOp>(op)) {
auto newResult = builder.create<triton::LoadOp>(
loadOp.getLoc(), newPtr, newMask, newOther, loadOp.getCache(),
loadOp.getEvict(), loadOp.getIsVolatile());
loadOp.getLoc(), newPtr, newMask, newOther, loadOp.getSem(),
loadOp.getScope(), loadOp.getCache(), loadOp.getEvict(),
loadOp.getIsVolatile());
op->getResult(0).replaceAllUsesWith(newResult);
} else if (auto storeOp = dyn_cast<triton::StoreOp>(op)) {
builder.create<triton::StoreOp>(storeOp.getLoc(), newPtr,
storeOp.getValue(), newMask,
storeOp.getSem(), storeOp.getScope(),
storeOp.getCache(), storeOp.getEvict());
}

Expand Down
24 changes: 15 additions & 9 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1276,16 +1276,20 @@ void init_triton_ir(py::module &&m) {
})
// Input/Output
.def("create_load",
[](TritonOpBuilder &self, Value &ptrs, CacheModifier cacheModifier,
EvictionPolicy evictionPolicy, bool isVolatile) -> Value {
return self.create<LoadOp>(ptrs, cacheModifier, evictionPolicy,
isVolatile);
[](TritonOpBuilder &self, Value &ptrs,
std::optional<MemSemantic> sem, std::optional<MemSyncScope> scope,
CacheModifier cacheModifier, EvictionPolicy evictionPolicy,
bool isVolatile) -> Value {
return self.create<LoadOp>(ptrs, sem, scope, cacheModifier,
evictionPolicy, isVolatile);
})
.def("create_store",
[](TritonOpBuilder &self, Value &ptrs, Value &value,
std::optional<MemSemantic> sem, std::optional<MemSyncScope> scope,
CacheModifier cacheModifier,
EvictionPolicy evictionPolicy) -> void {
self.create<StoreOp>(ptrs, value, cacheModifier, evictionPolicy);
self.create<StoreOp>(ptrs, value, sem, scope, cacheModifier,
evictionPolicy);
})
.def("create_tensor_pointer_load",
[](TritonOpBuilder &self, Value &ptr,
Expand All @@ -1306,17 +1310,19 @@ void init_triton_ir(py::module &&m) {
})
.def("create_masked_load",
[](TritonOpBuilder &self, Value &ptrs, Value &mask,
std::optional<Value> &other, CacheModifier cacheModifier,
std::optional<Value> &other, std::optional<MemSemantic> sem,
std::optional<MemSyncScope> scope, CacheModifier cacheModifier,
EvictionPolicy evictionPolicy, bool isVolatile) -> Value {
return self.create<LoadOp>(ptrs, mask, other.value_or(Value()),
cacheModifier, evictionPolicy,
isVolatile);
sem, scope, cacheModifier,
evictionPolicy, isVolatile);
})
.def("create_masked_store",
[](TritonOpBuilder &self, Value &ptrs, Value &val, Value &mask,
std::optional<MemSemantic> sem, std::optional<MemSyncScope> scope,
CacheModifier cacheModifier,
EvictionPolicy evictionPolicy) -> void {
self.create<StoreOp>(ptrs, val, mask, cacheModifier,
self.create<StoreOp>(ptrs, val, mask, sem, scope, cacheModifier,
evictionPolicy);
})
.def("create_reinterpret_tensor_descriptor",
Expand Down
Loading