diff --git a/lib/Optimizer/Dialect/Quake/QuakeOps.cpp b/lib/Optimizer/Dialect/Quake/QuakeOps.cpp index 2889baccef..9933bfdacd 100644 --- a/lib/Optimizer/Dialect/Quake/QuakeOps.cpp +++ b/lib/Optimizer/Dialect/Quake/QuakeOps.cpp @@ -94,6 +94,38 @@ void quake::AllocaOp::getCanonicalizationPatterns(RewritePatternSet &patterns, //===----------------------------------------------------------------------===// namespace { +// %7 = quake.concat %4 : (!quake.veq<2>) -> !quake.veq<2> +// ─────────────────────────────────────────── +// removed +struct ConcatNoOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(quake::ConcatOp concat, + PatternRewriter &rewriter) const override { + // Remove concat veq -> veq + // or + // concat ref -> ref + auto qubitsToConcat = concat.getQbits(); + if (qubitsToConcat.size() > 1) + return failure(); + + // We only want to handle veq -> veq here. + if (isa(qubitsToConcat.front().getType())) { + return failure(); + } + + // Do not handle anything where we don't know the sizes. + auto retTy = concat.getResult().getType(); + if (auto veqTy = dyn_cast(retTy)) + if (!veqTy.hasSpecifiedSize()) + // This could be a folded quake.relax_size op. + return failure(); + + rewriter.replaceOp(concat, qubitsToConcat); + return success(); + } +}; + struct ConcatSizePattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -130,7 +162,7 @@ struct ConcatSizePattern : public OpRewritePattern { void quake::ConcatOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.add(context); + patterns.add(context); } //===----------------------------------------------------------------------===// @@ -168,6 +200,38 @@ static void printRawIndex(OpAsmPrinter &printer, quake::ExtractRefOp refOp, } namespace { +// %4 = quake.concat %2, %3 : (!quake.ref, !quake.ref) -> !quake.veq<2> +// %7 = quake.extract_ref %4[0] : (!quake.veq<2>) -> !quake.ref +// ─────────────────────────────────────────── +// replace all use with %2 +struct ForwardConcatExtractPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(quake::ExtractRefOp extract, + PatternRewriter &rewriter) const override { + auto veq = extract.getVeq(); + auto concatOp = veq.getDefiningOp(); + if (concatOp && extract.hasConstantIndex()) { + // Don't run this canonicalization if any of the operands + // to concat are of type veq. + auto concatQubits = concatOp.getQbits(); + for (auto qOp : concatQubits) + if (isa(qOp.getType())) + return failure(); + + // concat only has ref type operands. + auto index = extract.getConstantIndex(); + if (index < concatQubits.size()) { + auto qOpValue = concatQubits[index]; + if (isa(qOpValue.getType())) + rewriter.replaceOp(extract, {qOpValue}); + } + } + return success(); + } +}; + // %2 = quake.concat %1 : (!quake.ref) -> !quake.veq<1> // %3 = quake.extract_ref %2[0] : (!quake.veq<1>) -> !quake.ref // quake.* %3 ... @@ -198,8 +262,8 @@ struct ForwardConcatExtractSingleton void quake::ExtractRefOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { - patterns.add( - context); + patterns.add(context); } LogicalResult quake::ExtractRefOp::verify() { diff --git a/lib/Target/OpenQASM/TranslateToOpenQASM.cpp b/lib/Target/OpenQASM/TranslateToOpenQASM.cpp index f9e2f5fa96..1ca000027b 100644 --- a/lib/Target/OpenQASM/TranslateToOpenQASM.cpp +++ b/lib/Target/OpenQASM/TranslateToOpenQASM.cpp @@ -235,14 +235,19 @@ static LogicalResult emitOperation(Emitter &emitter, func::CallOp callOp) { static LogicalResult emitOperation(Emitter &emitter, quake::OperatorInterface optor) { - // TODO: Handle adjoint for T and S - if (optor.isAdj()) - return optor.emitError("cannot convert adjoint operations to OpenQASM 2.0"); - - StringRef name; + // Handle adjoint for T and S + StringRef name = ""; if (failed(translateOperatorName(optor, name))) - return optor.emitError("cannot convert operation to OpenQASM 2.0"); - emitter.os << name; + return optor.emitError("cannot convert operation to OpenQASM 2.0."); + + if (optor.isAdj()) { + std::vector validAdjointOps{"s", "t"}; + if (std::find(validAdjointOps.begin(), validAdjointOps.end(), name.str()) == + validAdjointOps.end()) + return optor.emitError("cannot create adjoint for this operation."); + emitter.os << name << "dg"; + } else + emitter.os << name; if (failed(printParameters(emitter, optor.getParameters()))) return optor.emitError("failed to emit parameters");