Skip to content

Commit

Permalink
fix(compiler): [GPU backend] Restrict SDFG generation and batching to…
Browse files Browse the repository at this point in the history
… bootstrapping subgraphs to limit overhead. Restricts loop parallelism in loop nests where SDFG put/get operations occur as they have side effects.
  • Loading branch information
antoniupop committed Apr 25, 2024
1 parent e442a46 commit 0414c0f
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/SCF/IR/SCF.h>

#include <concretelang/Dialect/Concrete/IR/ConcreteDialect.h>
#include <concretelang/Dialect/Concrete/IR/ConcreteOps.h>
#include <concretelang/Dialect/Concrete/IR/ConcreteTypes.h>

namespace SDFG = mlir::concretelang::SDFG;
namespace Concrete = mlir::concretelang::Concrete;

namespace {
enum class StreamMappingKind { ON_DEVICE, TO_DEVICE, SPLICE, TO_HOST, NONE };
Expand Down Expand Up @@ -82,16 +87,6 @@ void unrollLoopsWithSDFGConvertibleOps(mlir::func::FuncOp func) {
}
}

void restrictParallelLoopsWithSDFGConvertibleOps(mlir::func::FuncOp func,
mlir::IRRewriter &rewriter) {
func.walk([&](SDFG::SDFGConvertibleOpInterface convertible) {
for (mlir::Operation *parent = convertible->getParentOp(); parent;
parent = parent->getParentOp())
if (mlir::scf::ForOp forOp = llvm::dyn_cast<mlir::scf::ForOp>(parent))
forOp->setAttr("parallel", rewriter.getBoolAttr(false));
});
}

StreamMappingKind determineStreamMappingKind(mlir::Value v) {
// Determine stream type for operands:
//
Expand Down Expand Up @@ -158,8 +153,6 @@ struct ExtractSDFGOpsPass : public ExtractSDFGOpsBase<ExtractSDFGOpsPass> {

if (unroll)
unrollLoopsWithSDFGConvertibleOps(func);
else
restrictParallelLoopsWithSDFGConvertibleOps(func, rewriter);

mlir::DenseMap<mlir::Value, SDFG::MakeStream> processOutMapping;
mlir::DenseMap<mlir::Value, SDFG::MakeStream> processInMapping;
Expand All @@ -169,8 +162,44 @@ struct ExtractSDFGOpsPass : public ExtractSDFGOpsBase<ExtractSDFGOpsPass> {

unsigned streamNumber = 0;

// Restrict SDFG conversion to cases where the SDFG graph includes
// operations with sufficient computational complexity to benefit
// from offloading to an accelerator.
auto isOpInBootstrappingSDFG = [&](mlir::Operation *op) -> bool {
mlir::scf::ForOp loopParent =
llvm::dyn_cast_or_null<mlir::scf::ForOp>(op->getParentOp());
if (loopParent) {
for (mlir::Operation &bop : loopParent.getBody()->getOperations())
if (llvm::isa<Concrete::BatchedBootstrapLweTensorOp,
Concrete::BatchedMappedBootstrapLweTensorOp,
Concrete::BatchedKeySwitchLweTensorOp>(bop))
return true;
return false;
} else {
return true;
}
};
// If we will generate SDFG ops for a loop, then the loop and
// all enclosing loops must not be parallelized as SDFG access
// ops (put/get) need to be issued in serial order.
auto restrictParallelLoopsWithSDFGConvertibleOps =
[&](mlir::Operation *op) {
mlir::scf::ForOp loopParent =
llvm::dyn_cast_or_null<mlir::scf::ForOp>(op->getParentOp());
if (loopParent) {
loopParent->setAttr("parallel", rewriter.getBoolAttr(false));
for (mlir::Operation *parent = loopParent->getParentOp(); parent;
parent = parent->getParentOp())
if (mlir::scf::ForOp forOp =
llvm::dyn_cast<mlir::scf::ForOp>(parent))
forOp->setAttr("parallel", rewriter.getBoolAttr(false));
}
};
func.walk([&](SDFG::SDFGConvertibleOpInterface op) {
convertibleOps.push_back(op);
if (isOpInBootstrappingSDFG(op)) {
restrictParallelLoopsWithSDFGConvertibleOps(op);
convertibleOps.push_back(op);
}
});

if (convertibleOps.size() == 0)
Expand Down
17 changes: 16 additions & 1 deletion compilers/concrete-compiler/compiler/lib/Transforms/Batching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
// https://github.com/zama-ai/concrete/blob/main/LICENSE.txt
// for license information.

#include <concretelang/Dialect/TFHE/IR/TFHEDialect.h>
#include <concretelang/Dialect/TFHE/IR/TFHEOps.h>
#include <concretelang/Dialect/TFHE/IR/TFHETypes.h>
#include <functional>
#include <limits>
#include <llvm/ADT/STLExtras.h>
Expand Down Expand Up @@ -1047,7 +1050,19 @@ class BatchingPattern : public mlir::OpRewritePattern<mlir::func::FuncOp> {
// Predicate checking whether an scf.for op is a valid candidate
// to expand the loop nest upwards towards the outermost loop
auto isCandidateLoop = [](mlir::scf::ForOp forOp) -> bool {
return isStaticLoop(forOp);
std::function<bool(mlir::scf::ForOp forOp)> hasKSorBS =
[&](mlir::scf::ForOp forOp) -> bool {
for (mlir::Operation &op : forOp.getBody()->getOperations()) {
if (llvm::isa<TFHE::KeySwitchGLWEOp, TFHE::BootstrapGLWEOp>(op))
return true;
if (auto nested = llvm::dyn_cast_or_null<mlir::scf::ForOp>(op);
nested)
if (hasKSorBS(nested))
return true;
}
return false;
};
return hasKSorBS(forOp) && isStaticLoop(forOp);
};

// Only batchable operations within at least one loop are of
Expand Down

0 comments on commit 0414c0f

Please sign in to comment.