Skip to content

Commit

Permalink
Add some canonicalizations for ConcatOp and ExtractRefOp (NVIDIA#574)
Browse files Browse the repository at this point in the history
* Add some canonicalizations for ConcatOp and ExtractRefOp

Signed-off-by: Alex McCaskey <[email protected]>
  • Loading branch information
amccaskey authored Aug 24, 2023
1 parent f440d60 commit 310d268
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 10 deletions.
70 changes: 67 additions & 3 deletions lib/Optimizer/Dialect/Quake/QuakeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,38 @@ void quake::AllocaOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
//===----------------------------------------------------------------------===//

namespace {
// %7 = quake.concat %4 : (!quake.veq<2>) -> !quake.veq<2>
// ───────────────────────────────────────────
// removed
struct ConcatNoOpPattern : public OpRewritePattern<quake::ConcatOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(quake::ConcatOp concat,
PatternRewriter &rewriter) const override {
// Remove concat veq<N> -> veq<N>
// or
// concat ref -> ref
auto qubitsToConcat = concat.getQbits();
if (qubitsToConcat.size() > 1)
return failure();

// We only want to handle veq -> veq here.
if (isa<quake::RefType>(qubitsToConcat.front().getType())) {
return failure();
}

// Do not handle anything where we don't know the sizes.
auto retTy = concat.getResult().getType();
if (auto veqTy = dyn_cast<quake::VeqType>(retTy))
if (!veqTy.hasSpecifiedSize())
// This could be a folded quake.relax_size op.
return failure();

rewriter.replaceOp(concat, qubitsToConcat);
return success();
}
};

struct ConcatSizePattern : public OpRewritePattern<quake::ConcatOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -130,7 +162,7 @@ struct ConcatSizePattern : public OpRewritePattern<quake::ConcatOp> {

void quake::ConcatOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<ConcatSizePattern>(context);
patterns.add<ConcatSizePattern, ConcatNoOpPattern>(context);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -168,6 +200,38 @@ static void printRawIndex(OpAsmPrinter &printer, quake::ExtractRefOp refOp,
}

namespace {
// %4 = quake.concat %2, %3 : (!quake.ref, !quake.ref) -> !quake.veq<2>
// %7 = quake.extract_ref %4[0] : (!quake.veq<2>) -> !quake.ref
// ───────────────────────────────────────────
// replace all use with %2
struct ForwardConcatExtractPattern
: public OpRewritePattern<quake::ExtractRefOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(quake::ExtractRefOp extract,
PatternRewriter &rewriter) const override {
auto veq = extract.getVeq();
auto concatOp = veq.getDefiningOp<quake::ConcatOp>();
if (concatOp && extract.hasConstantIndex()) {
// Don't run this canonicalization if any of the operands
// to concat are of type veq.
auto concatQubits = concatOp.getQbits();
for (auto qOp : concatQubits)
if (isa<quake::VeqType>(qOp.getType()))
return failure();

// concat only has ref type operands.
auto index = extract.getConstantIndex();
if (index < concatQubits.size()) {
auto qOpValue = concatQubits[index];
if (isa<quake::RefType>(qOpValue.getType()))
rewriter.replaceOp(extract, {qOpValue});
}
}
return success();
}
};

// %2 = quake.concat %1 : (!quake.ref) -> !quake.veq<1>
// %3 = quake.extract_ref %2[0] : (!quake.veq<1>) -> !quake.ref
// quake.* %3 ...
Expand Down Expand Up @@ -198,8 +262,8 @@ struct ForwardConcatExtractSingleton

void quake::ExtractRefOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<FuseConstantToExtractRefPattern, ForwardConcatExtractSingleton>(
context);
patterns.add<FuseConstantToExtractRefPattern, ForwardConcatExtractSingleton,
ForwardConcatExtractPattern>(context);
}

LogicalResult quake::ExtractRefOp::verify() {
Expand Down
19 changes: 12 additions & 7 deletions lib/Target/OpenQASM/TranslateToOpenQASM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,14 +235,19 @@ static LogicalResult emitOperation(Emitter &emitter, func::CallOp callOp) {

static LogicalResult emitOperation(Emitter &emitter,
quake::OperatorInterface optor) {
// TODO: Handle adjoint for T and S
if (optor.isAdj())
return optor.emitError("cannot convert adjoint operations to OpenQASM 2.0");

StringRef name;
// Handle adjoint for T and S
StringRef name = "";
if (failed(translateOperatorName(optor, name)))
return optor.emitError("cannot convert operation to OpenQASM 2.0");
emitter.os << name;
return optor.emitError("cannot convert operation to OpenQASM 2.0.");

if (optor.isAdj()) {
std::vector<std::string> validAdjointOps{"s", "t"};
if (std::find(validAdjointOps.begin(), validAdjointOps.end(), name.str()) ==
validAdjointOps.end())
return optor.emitError("cannot create adjoint for this operation.");
emitter.os << name << "dg";
} else
emitter.os << name;

if (failed(printParameters(emitter, optor.getParameters())))
return optor.emitError("failed to emit parameters");
Expand Down

0 comments on commit 310d268

Please sign in to comment.