Integrating OpenBLAS for matrix multiplication
parsifal-47 committed Jan 4, 2025
1 parent a7ffd7d commit 260c07f
Showing 12 changed files with 404 additions and 2 deletions.
12 changes: 11 additions & 1 deletion backend/compiler.py
@@ -12,13 +12,21 @@
import functools
from pathlib import Path


def _get_triton_shared_opt_path() -> str:
    path = os.getenv("TRITON_SHARED_OPT_PATH", "")
    if path == "":
        raise Exception("TRITON_SHARED_OPT_PATH is not set.")
    return path


+# Because of the way Triton loads backends, this helper is duplicated
+# in compiler.py and driver.py.
+def _get_triton_shared_use_openblas() -> bool:
+    use_blas = os.getenv("TRITON_SHARED_USE_OPENBLAS", "")
+    return use_blas != ""


def _get_llvm_bin_path(bin_name: str) -> str:
    path = os.getenv("LLVM_BINARY_DIR", "")
    if path == "":
@@ -42,7 +50,9 @@ def _ttir_to_ttsharedir(mod):
        dst_path = os.path.join(tmpdir, "ttshared.mlir")
        Path(src_path).write_text(ttir_code)
        triton_shared_opt_path = _get_triton_shared_opt_path()
-        subprocess.check_call([triton_shared_opt_path, src_path, "--triton-to-linalg-experimental", "--mlir-print-debuginfo", "-o", dst_path])
+        extra_pass = ["--linalg-to-linear-algebra-subprograms"] if _get_triton_shared_use_openblas() else []
+        subprocess.check_call([triton_shared_opt_path, src_path, "--triton-to-linalg-experimental"] + \
+            extra_pass + ["--mlir-print-debuginfo", "-o", dst_path])
        _dump_ir_if_needed([src_path])
        return Path(dst_path).read_text()
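Note: the OpenBLAS lowering is opt-in. A minimal usage sketch (not part of this commit; assumes the triton-shared CPU backend is in use):

import os

# Any non-empty value enables both the extra lowering pass here in compiler.py
# and the -lopenblas link flag in driver.py; set it before kernels compile.
os.environ["TRITON_SHARED_USE_OPENBLAS"] = "1"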

11 changes: 10 additions & 1 deletion backend/driver.py
@@ -12,6 +12,14 @@
from triton.backends.driver import DriverBase
from triton.backends.compiler import GPUTarget


+# Because of the way Triton loads backends, this helper is duplicated
+# in compiler.py and driver.py.
+def _get_triton_shared_use_openblas() -> bool:
+    use_blas = os.getenv("TRITON_SHARED_USE_OPENBLAS", "")
+    return use_blas != ""


# -------------------- Launcher ----------------------------
def _ty_to_cpp(ty):
    if ty[0] == '*':
@@ -253,11 +261,12 @@ def launch(
        so_path = os.path.join(tmpdir, "kernel.so")
        Path(asm_src_path).write_bytes(asm_src)
        Path(launcher_src_path).write_text(src)
+        extra_lib = ["-lopenblas"] if _get_triton_shared_use_openblas() else []
        # Compile it together.
        subprocess.check_call([
            "g++", "-std=c++17", launcher_src_path, asm_src_path,
            f"-I{py_include_dir}", f"-I{include_dir}", f"-L{py_lib_dir}",
-            "-shared", f"-l{py_lib}", "-fPIC", "-o", so_path
+            "-shared", "-fPIC"] + extra_lib + ["-o", so_path
        ])

        with open(so_path, "rb") as f:
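Note: with the flag set, the launcher is linked against -lopenblas, so the library must be resolvable at link and load time. A standalone sanity check (not part of this commit):

import ctypes.util

# find_library searches the standard system library paths
if ctypes.util.find_library("openblas") is None:
    raise RuntimeError("libopenblas not found; install OpenBLAS or unset "
                       "TRITON_SHARED_USE_OPENBLAS")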
1 change: 1 addition & 0 deletions include/triton-shared/Conversion/CMakeLists.txt
@@ -3,3 +3,4 @@ add_subdirectory(TritonToLinalgExperimental)
add_subdirectory(TritonToStructured)
add_subdirectory(TritonArithToLinalg)
add_subdirectory(StructuredToMemref)
+add_subdirectory(LinalgToLinearAlgebraSubprograms)
3 changes: 3 additions & 0 deletions include/triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/CMakeLists.txt
@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name LinalgToLinearAlgebraSubprograms)
add_public_tablegen_target(LinalgToLinearAlgebraSubprogramsConversionPassIncGen)
25 changes: 25 additions & 0 deletions include/triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/LinalgToLinearAlgebraSubprograms.h
@@ -0,0 +1,25 @@
#ifndef LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_H
#define LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_H

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "triton/Dialect/Triton/IR/Dialect.h"

namespace mlir {
namespace triton {

#define GEN_PASS_DECL
#include "triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/Passes.h.inc"

void populateLinalgToLinearAlgebraSubprogramsConversionPatterns(
    bool pidsToFuncArgs, bool addptrToLinalg, bool assertToCf,
    RewritePatternSet &patterns);

std::unique_ptr<OperationPass<ModuleOp>> createLinalgToLinearAlgebraSubprogramsPass();

} // namespace triton
} // namespace mlir

#endif // LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_H
15 changes: 15 additions & 0 deletions include/triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/Passes.h
@@ -0,0 +1,15 @@
#ifndef LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES_H
#define LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES_H

#include "triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/LinalgToLinearAlgebraSubprograms.h"

namespace mlir {
namespace triton {

#define GEN_PASS_REGISTRATION
#include "triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/Passes.h.inc"

} // namespace triton
} // namespace mlir

#endif // LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES_H
10 changes: 10 additions & 0 deletions include/triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/Passes.td
@@ -0,0 +1,10 @@
#ifndef LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES
#define LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES

include "mlir/Pass/PassBase.td"

def LinalgToLinearAlgebraSubprograms : Pass<"linalg-to-linear-algebra-subprograms", "mlir::ModuleOp"> {
  let summary = "Convert Linalg operations to library calls";
}

#endif // LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES
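Note: the flag string in the Pass definition above is what backend/compiler.py appends to the pipeline, so the conversion can also be exercised standalone. A sketch (not part of this commit; assumes triton-shared-opt is built and on PATH, and that in.mlir holds Triton IR):

import subprocess

# Mirrors the pipeline compiler.py builds when TRITON_SHARED_USE_OPENBLAS is set
subprocess.check_call([
    "triton-shared-opt", "in.mlir",
    "--triton-to-linalg-experimental",
    "--linalg-to-linear-algebra-subprograms",
    "-o", "out.mlir",
])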
1 change: 1 addition & 0 deletions lib/Conversion/CMakeLists.txt
@@ -3,3 +3,4 @@ add_subdirectory(TritonToLinalgExperimental)
add_subdirectory(TritonToStructured)
add_subdirectory(TritonArithToLinalg)
add_subdirectory(StructuredToMemref)
+add_subdirectory(LinalgToLinearAlgebraSubprograms)
21 changes: 21 additions & 0 deletions lib/Conversion/LinalgToLinearAlgebraSubprograms/CMakeLists.txt
@@ -0,0 +1,21 @@
add_triton_library(LinalgToLinearAlgebraSubprograms
  LinalgToLinearAlgebraSubprogramsPass.cpp

  DEPENDS
  LinalgToLinearAlgebraSubprogramsConversionPassIncGen

  LINK_LIBS PUBLIC
  MLIRLinalgTransforms
  MLIRArithDialect
  MLIRDialectUtils
  MLIRIR
  MLIRMathDialect
  MLIRPass
  MLIRTensorDialect
  MLIRTransforms
  MLIRSupport
  TritonIR
  TritonTransforms
  TritonTilingExtIR
  TritonStructuredIR
)
239 changes: 239 additions & 0 deletions lib/Conversion/LinalgToLinearAlgebraSubprograms/LinalgToLinearAlgebraSubprogramsPass.cpp
@@ -0,0 +1,239 @@
//===----------------------------------------------------------------------===//
//
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
//===----------------------------------------------------------------------===//

#include "triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/LinalgToLinearAlgebraSubprograms.h"
#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-to-las"

using namespace mlir;
using namespace triton;

namespace mlir {
namespace triton {
#define GEN_PASS_DEF_LINALGTOLINEARALGEBRASUBPROGRAMS
#include "triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/Passes.h.inc"
} // namespace triton
} // namespace mlir

namespace {

struct MatmulConverter : public OpConversionPattern<linalg::MatmulOp> {
  using OpConversionPattern<linalg::MatmulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(linalg::MatmulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    if (op.getInputs().size() != 2) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Cannot replace, must be exactly two input matrices\n");
      return failure();
    }

    Operation *resultOp;

    Value A = op.getInputs()[0];
    Value B = op.getInputs()[1];
    Value C;

    Value matmulResult = op.getResults()[0];
    bool otherUsers = false;
    bool found = false;

    // Look for a single arith.addf consumer of the matmul result: that add
    // can be folded into the GEMM's beta accumulation, so one BLAS call
    // replaces both the matmul and the addition.
    for (Operation *user : matmulResult.getUsers()) {
      if (auto addFOp = dyn_cast<arith::AddFOp>(user)) {
        if (!found) {
          found = true;
          C = addFOp.getLhs() == matmulResult ? addFOp.getRhs()
                                              : addFOp.getLhs();
          resultOp = addFOp;
          continue;
        }
      }
      otherUsers = true;
    }

    // If no foldable add was found, fall back to the matmul's own init
    // operand as C and replace only the matmul itself.
    bool replacingFOp = true;
    if (otherUsers || !found) {
      C = op.getOutputs()[0];
      resultOp = op;
      replacingFOp = false;
    }

    auto tensorA = cast<RankedTensorType>(A.getType());
    auto tensorB = cast<RankedTensorType>(B.getType());
    auto tensorC = cast<RankedTensorType>(C.getType());

    if (tensorA.getElementType() != tensorB.getElementType() ||
        tensorC.getElementType() != tensorB.getElementType()) {
      LLVM_DEBUG(llvm::dbgs() << "Cannot replace, different element types\n");
      return failure();
    }

    if (!tensorA.getElementType().isF32() &&
        !tensorA.getElementType().isF64()) {
      LLVM_DEBUG(llvm::dbgs() << "Cannot replace, unsupported type\n");
      return failure();
    }

    auto floatType = tensorA.getElementType();

    // Since tensors are immutable, allocate a mutable buffer for the result
    // and copy C into it; the BLAS call updates the buffer in place.
    Value memrefConst = rewriter.create<bufferization::ToMemrefOp>(
        loc, MemRefType::get(tensorC.getShape(), tensorC.getElementType()), C);
    auto memrefType = MemRefType::get(tensorC.getShape(), floatType);
    Value memrefC = rewriter.create<memref::AllocOp>(loc, memrefType);
    auto copyOp = rewriter.create<linalg::CopyOp>(loc, ValueRange{memrefConst},
                                                  ValueRange{memrefC});

    ModuleOp module = op->getParentOfType<ModuleOp>();

    auto intType = rewriter.getI32Type();
    auto int64Type = rewriter.getI64Type();
    auto ptrType =
        LLVM::LLVMPointerType::get(op.getContext(), 0); // default address space

    // Matches the CBLAS GEMM signature: layout/transpose enums, M, N, K,
    // alpha, A, lda, B, ldb, beta, C, ldc.
    auto funcType = FunctionType::get(
        op.getContext(),
        {intType, intType, intType, intType, intType, intType, floatType,
         ptrType, intType, ptrType, intType, floatType, ptrType, intType},
        {});

    bool usingF64 = floatType.isF64();
    const char *funcName = usingF64 ? "cblas_dgemm" : "cblas_sgemm";
    auto func = module.lookupSymbol<func::FuncOp>(funcName);
    if (!func) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(module.getBody());
      func = rewriter.create<func::FuncOp>(loc, funcName, funcType);
      func.setVisibility(SymbolTable::Visibility::Private);
    }

    // Lower a memref to a raw pointer for the BLAS call: extract the aligned
    // base pointer as an index, cast it to i64, then to an LLVM pointer.
    auto memrefToPointer = [&rewriter, &loc, &int64Type,
                            &ptrType](Value &memref) {
      auto indexPtr =
          rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(loc, memref);
      auto castOp = rewriter.create<arith::IndexCastOp>(loc, int64Type, indexPtr);
      return rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, castOp);
    };

    // Reuse the underlying buffer when the tensor was produced by a to_tensor
    // op; otherwise materialize a memref for it first.
    auto tensorToPointer = [&rewriter, &loc, &memrefToPointer](
                               Value &V, RankedTensorType &T) {
      if (auto tensorOp = V.getDefiningOp<bufferization::ToTensorOp>()) {
        Value ref = tensorOp.getMemref();
        return memrefToPointer(ref);
      }

      Value memref = rewriter.create<bufferization::ToMemrefOp>(
          loc, MemRefType::get(T.getShape(), T.getElementType()), V);
      return memrefToPointer(memref);
    };

    Value ptrA = tensorToPointer(A, tensorA);
    Value ptrB = tensorToPointer(B, tensorB);
    Value ptrC = memrefToPointer(memrefC);

    int32_t M = tensorA.getShape()[0];
    int32_t K = tensorA.getShape()[1];
    int32_t N = tensorB.getShape()[1];

    Value alpha = rewriter.create<arith::ConstantOp>(
        loc, floatType,
        usingF64 ? rewriter.getF64FloatAttr(1.0)
                 : rewriter.getF32FloatAttr(1.0));
    Value beta = alpha;

    auto constOp = [&rewriter, &loc, &intType](int32_t V) {
      return rewriter.create<arith::ConstantOp>(loc, intType,
                                                rewriter.getI32IntegerAttr(V));
    };

    // The constants below come from the OpenBLAS CBLAS interface; the variable
    // names match the corresponding enum names. For details see:
    // https://github.com/OpenMathLib/OpenBLAS/blob/develop/cblas.h
    Value CblasRowMajor = constOp(101), CblasNoTrans = constOp(111);
    Value MVal = constOp(M), NVal = constOp(N), KVal = constOp(K);
    Value LDA = KVal, LDB = NVal, LDC = NVal;

    auto funcOp = rewriter.create<func::CallOp>(
        loc, func,
        ValueRange{CblasRowMajor, CblasNoTrans, CblasNoTrans, MVal, NVal, KVal,
                   alpha, ptrA, LDA, ptrB, LDB, beta, ptrC, LDC});

    auto toTensorOp = rewriter.create<bufferization::ToTensorOp>(
        loc, tensorC, memrefC, true /*restrict*/, true /*writable*/);

    if (!replacingFOp) {
      rewriter.replaceOp(op, toTensorOp);
    } else {
      // The addf was folded into the beta term: erase the matmul and replace
      // the addition with the BLAS result.
      rewriter.eraseOp(op);
      rewriter.replaceOp(resultOp, toTensorOp);
    }

    return success();
  }
};

class LinalgToLinearAlgebraSubprogramsPass
    : public triton::impl::LinalgToLinearAlgebraSubprogramsBase<
          LinalgToLinearAlgebraSubprogramsPass> {
  using LinalgToLinearAlgebraSubprogramsBase<
      LinalgToLinearAlgebraSubprogramsPass>::LinalgToLinearAlgebraSubprogramsBase;

public:
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<linalg::LinalgDialect, func::FuncDialect,
                    arith::ArithDialect, math::MathDialect,
                    bufferization::BufferizationDialect, affine::AffineDialect,
                    scf::SCFDialect, tensor::TensorDialect, LLVM::LLVMDialect>();
  }

  void runOnOperation() override {
    auto moduleOp = getOperation();
    RewritePatternSet patterns(&getContext());
    ConversionTarget target(getContext());

    patterns.add<MatmulConverter>(patterns.getContext());

    target.addLegalDialect<
        func::FuncDialect, arith::ArithDialect, math::MathDialect,
        affine::AffineDialect, scf::SCFDialect, linalg::LinalgDialect,
        cf::ControlFlowDialect, tensor::TensorDialect,
        bufferization::BufferizationDialect, memref::MemRefDialect,
        LLVM::LLVMDialect>();

    target.addDynamicallyLegalOp<linalg::MatmulOp>([](linalg::MatmulOp op) {
      Value A = op.getInputs()[0];
      Value B = op.getInputs()[1];

      auto tensorA = cast<RankedTensorType>(A.getType());
      auto tensorB = cast<RankedTensorType>(B.getType());

      if (tensorA.getElementType() != tensorB.getElementType()) {
        // Mixed element types are not handled; leave the op legal.
        return true;
      }

      if (!tensorA.getElementType().isF32() &&
          !tensorA.getElementType().isF64()) {
        // Unsupported element types; leave the op legal.
        return true;
      }

      return false; // MatmulOp is illegal, so the conversion must run on it
    });

    if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
triton::createLinalgToLinearAlgebraSubprogramsPass() {
  return std::make_unique<LinalgToLinearAlgebraSubprogramsPass>();
}
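Note: the call materialized by the pattern above corresponds to the standard CBLAS GEMM entry point, with alpha = beta = 1.0 so the folded arith.addf becomes the beta accumulation. A self-contained Python sketch of the equivalent row-major SGEMM (not part of this commit; NumPy and a locally installed OpenBLAS are assumptions):

import ctypes
import ctypes.util

import numpy as np

# Assumes libopenblas is installed and on the default library search path
blas = ctypes.CDLL(ctypes.util.find_library("openblas"))

CblasRowMajor, CblasNoTrans = 101, 111  # same enum values the pass hardcodes

M, N, K = 4, 5, 3
A = np.random.rand(M, K).astype(np.float32)
B = np.random.rand(K, N).astype(np.float32)
C = np.random.rand(M, N).astype(np.float32)
expected = A @ B + C  # alpha = beta = 1.0, mirroring the fused matmul + addf

# cblas_sgemm updates C in place: C = alpha * A @ B + beta * C
blas.cblas_sgemm(
    CblasRowMajor, CblasNoTrans, CblasNoTrans,
    M, N, K,
    ctypes.c_float(1.0),                                   # alpha
    A.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), K,   # lda = K (row-major, no transpose)
    B.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), N,   # ldb = N
    ctypes.c_float(1.0),                                   # beta
    C.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), N,   # ldc = N
)
assert np.allclose(C, expected, atol=1e-5)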
(2 more changed files not shown.)
