Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
nhat-nguyen committed Jan 10, 2025
1 parent 999c080 commit 7a53f67
Show file tree
Hide file tree
Showing 32 changed files with 839 additions and 653 deletions.
2 changes: 1 addition & 1 deletion include/triton-shared/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ add_subdirectory(TritonToLinalg)
add_subdirectory(TritonToLinalgExperimental)
add_subdirectory(TritonToStructured)
add_subdirectory(TritonArithToLinalg)
add_subdirectory(FoldUnstructuredPtr)
add_subdirectory(TritonToUnstructured)
add_subdirectory(StructuredToMemref)

This file was deleted.

This file was deleted.

15 changes: 0 additions & 15 deletions include/triton-shared/Conversion/FoldUnstructuredPtr/Passes.h

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Generate the pass declarations (Passes.h.inc) from Passes.td for the
# TritonToUnstructured pass group, and expose them as a public tablegen target
# that other targets can depend on.
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToUnstructured)
add_public_tablegen_target(TritonToUnstructuredConversionPassIncGen)
15 changes: 15 additions & 0 deletions include/triton-shared/Conversion/TritonToUnstructured/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef TRITON_TO_UNSTRUCTURED_CONVERSION_PASSES_H
#define TRITON_TO_UNSTRUCTURED_CONVERSION_PASSES_H

// Declares the pass constructor referenced by the generated registration code.
#include "triton-shared/Conversion/TritonToUnstructured/TritonToUnstructured.h"

namespace mlir {
namespace triton {

// GEN_PASS_REGISTRATION must be defined before the include below: it selects
// the registration section of the tablegen-generated Passes.h.inc, which is
// emitted here inside mlir::triton.
#define GEN_PASS_REGISTRATION
#include "triton-shared/Conversion/TritonToUnstructured/Passes.h.inc"

} // namespace triton
} // namespace mlir

#endif // TRITON_TO_UNSTRUCTURED_CONVERSION_PASSES_H
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#ifndef FOLD_UNSTRUCTURED_PTR_CONVERSION_PASSES
#define FOLD_UNSTRUCTURED_PTR_CONVERSION_PASSES
#ifndef TRITON_TO_UNSTRUCTURED_CONVERSION_PASSES
#define TRITON_TO_UNSTRUCTURED_CONVERSION_PASSES

include "mlir/Pass/PassBase.td"

def FoldUnstructuredPtr : Pass<"fold-unstructured-ptr", "mlir::ModuleOp"> {
def TritonToUnstructured : Pass<"triton-to-unstructured", "mlir::ModuleOp"> {
let summary = "Transforms tt.addptr ops into offset accumulation ops";
let constructor = "triton::createFoldUnstructuredPtrPass()";
let constructor = "triton::createTritonToUnstructuredPass()";
let options = [
Option<"offsetBitWidth", "offset-bit-width", "size_t", /*default*/"32",
"Bitwidth used for the starting offset of each pointer">
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef TRITON_CONVERSION_TRITON_TO_UNSTRUCTURED_TRITON_TO_UNSTRUCTURED_H
#define TRITON_CONVERSION_TRITON_TO_UNSTRUCTURED_TRITON_TO_UNSTRUCTURED_H

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "triton/Dialect/Triton/IR/Dialect.h"

namespace mlir {
namespace triton {

/// Creates an instance of the TritonToUnstructured pass, which operates on a
/// ModuleOp. Passes.td names this function as the constructor of the
/// `triton-to-unstructured` pass ("Transforms tt.addptr ops into offset
/// accumulation ops").
std::unique_ptr<OperationPass<ModuleOp>> createTritonToUnstructuredPass();

} // namespace triton
} // namespace mlir

#endif // TRITON_CONVERSION_TRITON_TO_UNSTRUCTURED_TRITON_TO_UNSTRUCTURED_H
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
#ifndef MLIR_DIALECT_TRITON_STRUCTURED_IR_TRITON_STRUCTURED_DIALECT_H_
#define MLIR_DIALECT_TRITON_STRUCTURED_IR_TRITON_STRUCTURED_DIALECT_H_

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "triton/Dialect/Triton/IR/Dialect.h"

#include "mlir/IR/Dialect.h"
namespace mlir {
namespace tts {
namespace utils {
mlir::Value getScalarValue(mlir::Value operand, mlir::Location loc,
mlir::OpBuilder &builder);
}
} // namespace tts
} // namespace mlir

//===----------------------------------------------------------------------===//
// TritonStructured Operations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,51 @@ def TTS_MakeUnstructuredTensorPtrOp : TTS_Op<"make_unstructured_tptr", [Pure]> {
let results = (outs TT_PtrLike:$ptr);
}

// Reads from memory addressed by a base pointer `$ptr` and an integer
// `$offset`. An optional `$mask` and an optional `$other` (printed as
// `default`) may be supplied.
def TTS_GatherOp : TTS_Op<"gather", [
  MemoryEffects<[MemRead]>,
  // Two optional operands (`mask`, `other`) require explicit segment sizes.
  AttrSizedOperandSegments,
  // When present, `mask` must have the type
  // triton::getI1SameShape(type of `offset`).
  OptionalTypesMatchWith<"mask type matches offset type", "offset", "mask",
                         "triton::getI1SameShape($_self)">,
  // When present, `other` must have the type
  // triton::getPointeeType(type of `ptr`).
  OptionalTypesMatchWith<"other matches ptr type", "ptr", "other",
                         "triton::getPointeeType($_self)">
]> {
  let summary = "load data from memory at the given offsets, optionally masked";

  let arguments = (
    ins
    TT_Ptr:$ptr,
    TT_IntLike:$offset,
    Optional<TT_BoolLike>:$mask,
    Optional<TT_Type>:$other
  );

  let results = (outs TT_Type:$result);

  let assemblyFormat = [{
    $ptr `[` $offset `]` (`mask` `=` $mask^)? (`default` `=` $other^)?
    attr-dict `:` `(` type($ptr) `,` type($offset) `)` `->` type($result)
  }];
}

// Writes `$value` to memory addressed by a base pointer `$ptr` and an integer
// `$offset`. An optional `$mask` may be supplied.
def TTS_ScatterOp : TTS_Op<"scatter", [
  MemoryEffects<[MemWrite]>,
  // When present, `mask` must have the type
  // triton::getI1SameShape(type of `offset`).
  OptionalTypesMatchWith<"mask type matches offset type", "offset", "mask",
                         "triton::getI1SameShape($_self)">
]> {
  let summary = "store data to memory at the given offsets, optionally masked";

  let arguments = (
    ins
    TT_Ptr:$ptr,
    TT_IntLike:$offset,
    TT_Type:$value,
    Optional<TT_BoolLike>:$mask
  );

  let assemblyFormat = [{
    $value `into` $ptr `[` $offset `]` (`mask` `=` $mask^)?
    attr-dict `:` type($value) `into` ` ` `(` type($ptr) `,` type($offset) `)`
  }];
}

def TTS_LoadOp : TTS_Op<"load", [
MemoryEffects<[MemRead]>,
AttrSizedOperandSegments
Expand All @@ -174,7 +219,7 @@ def TTS_LoadOp : TTS_Op<"load", [
}

bool hasMask() {
return !getStaticMaskDims().empty();
return !getMixedMaskDims().empty();
}
}];

Expand Down Expand Up @@ -205,7 +250,7 @@ def TTS_StoreOp : TTS_Op<"store", [
}

bool hasMask() {
return !getStaticMaskDims().empty();
return !getMixedMaskDims().empty();
}
}];

Expand Down
74 changes: 1 addition & 73 deletions lib/AnalysisStructured/PtrAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@
#include "triton-shared/AnalysisStructured/PtrAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
Expand All @@ -33,7 +30,6 @@
#include "llvm/Support/LogicalResult.h"
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <queue>
#include <string>
Expand All @@ -42,74 +38,6 @@

namespace mlir {

// Extract a scalar value from `operand`.
//
// If `operand` is already a scalar (its type is not a ShapedType) it is
// returned directly. Otherwise the chain of defining operations is walked
// upwards through triton::SplatOp, arith::SIToFPOp and arith::TruncFOp, or
// terminated at a splat arith::ConstantOp, until a scalar is found. The
// sitofp/truncf ops seen on the way are then re-created on the scalar (with
// the element type of their original result) so the returned value has the
// element type of the original operand.
//
// If no scalar can be extracted, an error is emitted at `loc` and a null Value
// is returned. This includes an arith.constant whose value is not a splat
// DenseElementsAttr: previously a non-DenseElementsAttr constant left
// `operand` unchanged and the loop never terminated.
static Value getScalarValue(Value operand, Location loc, OpBuilder &builder) {
  // Cast ops peeled off while walking up the def chain, outermost first.
  SmallVector<Operation *> ops;

  // Emits the "unsupported producer" diagnostic and yields a null Value.
  auto failUnsupported = [&]() -> Value {
    emitError(loc) << "other value used in masked load "
                      "produced by unsupported instruction";
    return nullptr;
  };

  // Result type a cast op should have once it is applied to a scalar.
  auto scalarResultType = [](Operation *castOp) -> Type {
    Type resType = castOp->getResults()[0].getType();
    if (auto shapedType = dyn_cast<ShapedType>(resType))
      return shapedType.getElementType();
    return resType;
  };

  // Re-applies the recorded casts innermost-first, i.e. in reverse of the
  // order in which they were discovered.
  auto reconstructScalarValue = [&](Value src) {
    for (auto it = ops.rbegin(); it != ops.rend(); ++it) {
      src = TypeSwitch<Operation *, Value>(*it)
                .Case<arith::SIToFPOp>([&](Operation *castOp) {
                  return builder.create<arith::SIToFPOp>(
                      loc, scalarResultType(castOp), src);
                })
                .Case<arith::TruncFOp>([&](Operation *castOp) {
                  return builder.create<arith::TruncFOp>(
                      loc, scalarResultType(castOp), src);
                })
                .Default([](Operation *) {
                  llvm_unreachable("unsupported op in generating ");
                  return nullptr;
                });
    }
    return src;
  };

  while (true) {
    // Reached a scalar: rebuild the casts on top of it and finish.
    if (!isa<ShapedType>(operand.getType()))
      return reconstructScalarValue(operand);

    if (auto constOp = operand.getDefiningOp<arith::ConstantOp>()) {
      // Only splat dense constants have a single underlying scalar. Any other
      // attribute kind must bail out here rather than loop on the same operand.
      auto attr = dyn_cast<DenseElementsAttr>(constOp.getValue());
      if (!attr || !attr.isSplat())
        return failUnsupported();

      auto elemValue = attr.getSplatValue<Attribute>();
      auto scalarConst = arith::ConstantOp::materialize(
          builder, elemValue, attr.getElementType(), constOp.getLoc());
      return reconstructScalarValue(scalarConst.getResult());
    }

    if (auto splatOp = operand.getDefiningOp<triton::SplatOp>()) {
      operand = splatOp.getSrc();
      continue;
    }

    if (auto sitofpOp = operand.getDefiningOp<arith::SIToFPOp>()) {
      ops.push_back(sitofpOp.getOperation());
      operand = sitofpOp.getIn();
      continue;
    }

    if (auto truncfOp = operand.getDefiningOp<arith::TruncFOp>()) {
      ops.push_back(truncfOp.getOperation());
      operand = truncfOp.getIn();
      continue;
    }

    // Block argument or an op we do not know how to look through.
    return failUnsupported();
  }
}

namespace tts {

int32_t PtrState::getRank() const {
Expand Down Expand Up @@ -1126,7 +1054,7 @@ LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op) {
if (other) {
assert(mask && "other value used while no masks are specified");

scalarOther = getScalarValue(other, loc, builder);
scalarOther = utils::getScalarValue(other, loc, builder);
if (!scalarOther) {
op->emitRemark("other value used in masked load produced by "
"unsupported instruction");
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
add_subdirectory(FoldUnstructuredPtr)
add_subdirectory(TritonToLinalg)
add_subdirectory(TritonToLinalgExperimental)
add_subdirectory(TritonToStructured)
add_subdirectory(TritonToUnstructured)
add_subdirectory(TritonArithToLinalg)
add_subdirectory(StructuredToMemref)
Loading

0 comments on commit 7a53f67

Please sign in to comment.