Skip to content

Commit

Permalink
Implement lowering for AccumulateOp.
Browse files Browse the repository at this point in the history
  • Loading branch information
ingomueller-net committed Nov 15, 2022
1 parent fac45e8 commit fca969e
Show file tree
Hide file tree
Showing 4 changed files with 421 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,20 @@ class StateTypeComputer {
TypeConverter typeConverter;
};

/// The state of AccumulateOp consists of the state of its upstream iterator,
/// i.e., the state of the iterator that produces its input stream, and a
/// Boolean indicating whether the iterator has returned a result already (which
/// is initialized to false and set to true in the first call to next in order
/// to ensure that only a single result is returned).
template <>
StateType
StateTypeComputer::operator()(AccumulateOp op,
llvm::SmallVector<StateType> upstreamStateTypes) {
MLIRContext *context = op->getContext();
Type hasReturned = IntegerType::get(context, /*width=*/1);
return StateType::get(context, {upstreamStateTypes[0], hasReturned});
}

/// The state of ConstantStreamOp consists of a single number that corresponds
/// to the index of the next struct returned by the iterator.
template <>
Expand Down Expand Up @@ -168,6 +182,7 @@ mlir::iterators::IteratorAnalysis::IteratorAnalysis(
// TODO: Verify that operands do not come from bbArgs.
.Case<
// clang-format off
AccumulateOp,
ConstantStreamOp,
FilterOp,
MapOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,253 @@ struct PrintOpLowering : public OpConversionPattern<PrintOp> {
}
};

//===----------------------------------------------------------------------===//
// AccumulateOp.
//===----------------------------------------------------------------------===//

/// Builds IR that opens the nested upstream iterator and sets `hasReturned` to
/// false. Possible output:
///
/// %0 = iterators.extractvalue %arg0[0] :
/// <!upstream_state, i1> -> !upstream_state
/// %1 = call @iterators.upstream.open.0(%0) :
/// (!upstream_state) -> !upstream_state
/// %2 = iterators.insertvalue %arg0[0] (%1 : !upstream_state) :
/// <!upstream_state, i1>
/// %false = arith.constant false
/// %3 = iterators.insertvalue %false into %2[1] :
/// !iterators.state<!upstream_state, i1>
static Value buildOpenBody(AccumulateOp op, OpBuilder &builder,
Value initialState,
ArrayRef<IteratorInfo> upstreamInfos) {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);

Type upstreamStateType = upstreamInfos[0].stateType;

// Extract upstream state.
Value initialUpstreamState = b.create<iterators::ExtractValueOp>(
upstreamStateType, initialState, b.getIndexAttr(0));

// Call Open on upstream.
SymbolRefAttr openFunc = upstreamInfos[0].openFunc;
auto openCallOp =
b.create<func::CallOp>(openFunc, upstreamStateType, initialUpstreamState);

// Update upstream state.
Value updatedUpstreamState = openCallOp->getResult(0);
Value updatedState = b.create<iterators::InsertValueOp>(
initialState, b.getIndexAttr(0), updatedUpstreamState);

// Reset hasReturned to false.
Value constFalse = b.create<arith::ConstantIntOp>(/*value=*/0, /*width=*/1);
updatedState = b.create<iterators::InsertValueOp>(
updatedState, b.getIndexAttr(1), constFalse);

return updatedState;
}

/// Builds IR that consumes all elements of the upstream iterator and combines
/// them into a single one using the given accumulate function. Pseudo-code:
///
/// if hasReturned: return {}
/// hasReturned = True
/// accumulator = initFuncRef()
/// while (next = upstream->Next()):
/// accumulator = accumulate(accumulator, next)
/// return accumulator
///
/// Possible output:
///
/// %0 = iterators.extractvalue %arg0[0] :
/// <!upstream_state, i1> -> !upstream_state
/// %1 = iterators.extractvalue %arg0[1] : !iterators.state<!upstream_state, i1>
/// %2:2 = scf.if %1 -> (!upstream_state, !element_type) {
/// %6 = llvm.mlir.undef : !element_type
/// scf.yield %0, %6 : !upstream_state, !element_type
/// } else {
/// %6 = func.call @zero_struct() : () -> !element_type
/// %7:3 = scf.while (%arg1 = %0, %arg2 = %6) :
/// (!upstream_state, !element_type) ->
/// (!upstream_state, !element_type, !element_type) {
/// %8:3 = func.call @iterators.upstream.next.0(%arg1) :
/// (!upstream_state) -> (!upstream_state, i1, !element_type)
/// scf.condition(%8#1) %8#0, %arg2, %8#2 :
/// !upstream_state, !element_type, !element_type
//// } do {
/// ^bb0(%arg1: !upstream_state, %arg2: !element_type, %arg3: !element_type):
/// %8 = func.call @accumulate_func(%arg2, %arg3) :
/// (!element_type, !element_type) -> !element_type
/// scf.yield %arg1, %8 : !upstream_state, !element_type
/// }
/// scf.yield %7#0, %7#1 : !upstream_state, !element_type
/// }
/// %3 = iterators.insertvalue %arg0[0] (%2#0 : !upstream_state) :
/// <!upstream_state, i1>
/// %true = arith.constant true
/// %4 = arith.xori %true, %1 : i1
/// %5 = iterators.insertvalue %true into %3[1] :
/// !iterators.state<!upstream_state, i1>
static llvm::SmallVector<Value, 4>
buildNextBody(AccumulateOp op, OpBuilder &builder, Value initialState,
ArrayRef<IteratorInfo> upstreamInfos, Type elementType) {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);
Type i1 = b.getI1Type();

// Extract input element type.
StreamType inputStreamType = op.input().getType().cast<StreamType>();
Type inputElementType = inputStreamType.getElementType();

// Extract upstream state.
Type upstreamStateType = upstreamInfos[0].stateType;
Value initialUpstreamState = b.create<iterators::ExtractValueOp>(
upstreamStateType, initialState, b.getIndexAttr(0));

// Check if the iterator has returned an element already (since it should
// return one only in the first call to next).
Value hasReturned =
b.create<iterators::ExtractValueOp>(i1, initialState, b.getIndexAttr(1));
SmallVector<Type> ifReturnTypes{upstreamStateType, elementType};
auto ifOp = b.create<scf::IfOp>(
ifReturnTypes, hasReturned,
/*thenBuilder=*/
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);

// Don't modify state; return undef element.
Value nextElement = b.create<UndefOp>(elementType);
b.create<scf::YieldOp>(ValueRange{initialUpstreamState, nextElement});
},
/*elseBuilder=*/
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);

// Initialize accumulator with init value.
FuncOp initFunc = op.getInitFunc();
Value initValue = b.create<func::CallOp>(initFunc)->getResult(0);

// Create while loop.
SmallVector<Value> whileInputs = {initialUpstreamState, initValue};
SmallVector<Type> whileResultTypes = {
upstreamStateType, // Updated upstream state.
elementType, // Accumulator.
inputElementType // Element from last next call.
};
scf::WhileOp whileOp = scf::createWhileOp(
b, whileResultTypes, whileInputs,
/*beforeBuilder=*/
[&](OpBuilder &builder, Location loc,
Block::BlockArgListType args) {
ImplicitLocOpBuilder b(loc, builder);

Value upstreamState = args[0];
Value accumulator = args[1];

// Call next function.
SmallVector<Type> nextResultTypes = {upstreamStateType, i1,
inputElementType};
SymbolRefAttr nextFunc = upstreamInfos[0].nextFunc;
auto nextCall = b.create<func::CallOp>(nextFunc, nextResultTypes,
upstreamState);

Value updatedUpstreamState = nextCall->getResult(0);
Value hasNext = nextCall->getResult(1);
Value maybeNextElement = nextCall->getResult(2);
b.create<scf::ConditionOp>(
hasNext, ValueRange{updatedUpstreamState, accumulator,
maybeNextElement});
},
/*afterBuilder=*/
[&](OpBuilder &builder, Location loc,
Block::BlockArgListType args) {
ImplicitLocOpBuilder b(loc, builder);

Value upstreamState = args[0];
Value accumulator = args[1];
Value nextElement = args[2];

// Call accumulate function.
auto accumulateCall =
b.create<func::CallOp>(elementType, op.accumulateFuncRef(),
ValueRange{accumulator, nextElement});
Value newAccumulator = accumulateCall->getResult(0);

b.create<scf::YieldOp>(ValueRange{upstreamState, newAccumulator});
});

Value updatedState = whileOp->getResult(0);
Value accumulator = whileOp->getResult(1);

b.create<scf::YieldOp>(ValueRange{updatedState, accumulator});
});

// Compute hasNext: we have an element iff we have not returned before, i.e.,
// iff "not hasReturend". We simulate "not" with "xor true".
Value constTrue = b.create<arith::ConstantIntOp>(/*value=*/1, /*width=*/1);
Value hasNext = b.create<arith::XOrIOp>(constTrue, hasReturned);

// Update state.
Value finalUpstreamState = ifOp->getResult(0);
Value finalState = b.create<iterators::InsertValueOp>(
initialState, b.getIndexAttr(0), finalUpstreamState);
finalState = b.create<iterators::InsertValueOp>(finalState, b.getIndexAttr(1),
constTrue);
Value nextElement = ifOp->getResult(1);

return {finalState, hasNext, nextElement};
}

/// Builds IR that closes the nested upstream iterator. Possible output:
///
/// %0 = iterators.extractvalue %arg0[0] :
/// !iterators.state<!upstream_state, i1> -> !upstream_state
/// %1 = call @iterators.upstream.close.0(%0) :
/// (!upstream_state) -> !upstream_state
/// %2 = iterators.insertvalue %arg0[0] (%1 : !upstream_state) :
/// !iterators.state<!upstream_state, i1>
static Value buildCloseBody(AccumulateOp op, OpBuilder &builder,
Value initialState,
ArrayRef<IteratorInfo> upstreamInfos) {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);

Type upstreamStateType = upstreamInfos[0].stateType;

// Extract upstream state.
Value initialUpstreamState = b.create<iterators::ExtractValueOp>(
upstreamStateType, initialState, b.getIndexAttr(0));

// Call Close on upstream.
SymbolRefAttr closeFunc = upstreamInfos[0].closeFunc;
auto closeCallOp = b.create<func::CallOp>(closeFunc, upstreamStateType,
initialUpstreamState);

// Update upstream state.
Value updatedUpstreamState = closeCallOp->getResult(0);
return b
.create<iterators::InsertValueOp>(initialState, b.getIndexAttr(0),
updatedUpstreamState)
.getResult();
}

/// Builds IR that initializes the iterator state with the state of the upstream
/// iterator. Possible output:
///
/// %0 = ...
/// %1 = iterators.undefstate : <!upstream_state, i1>
/// %2 = iterators.insertvalue %1[0] (%0 : !upstream_state) :
/// !iterators.state<!upstream_state, i1>
static Value buildStateCreation(AccumulateOp op, AccumulateOp::Adaptor adaptor,
OpBuilder &builder, StateType stateType) {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);
Value undefState = b.create<UndefStateOp>(loc, stateType);
Value upstreamState = adaptor.input();
return b.create<iterators::InsertValueOp>(undefState, b.getIndexAttr(0),
upstreamState);
}

//===----------------------------------------------------------------------===//
// ConstantStreamOp.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1298,6 +1545,7 @@ static Value buildOpenBody(Operation *op, OpBuilder &builder,
return llvm::TypeSwitch<Operation *, Value>(op)
.Case<
// clang-format off
AccumulateOp,
ConstantStreamOp,
FilterOp,
MapOp,
Expand All @@ -1317,6 +1565,7 @@ buildNextBody(Operation *op, OpBuilder &builder, Value initialState,
return llvm::TypeSwitch<Operation *, llvm::SmallVector<Value, 4>>(op)
.Case<
// clang-format off
AccumulateOp,
ConstantStreamOp,
FilterOp,
MapOp,
Expand All @@ -1337,6 +1586,7 @@ static Value buildCloseBody(Operation *op, OpBuilder &builder,
return llvm::TypeSwitch<Operation *, Value>(op)
.Case<
// clang-format off
AccumulateOp,
ConstantStreamOp,
FilterOp,
MapOp,
Expand All @@ -1355,6 +1605,7 @@ static Value buildStateCreation(IteratorOpInterface op, OpBuilder &builder,
return llvm::TypeSwitch<Operation *, Value>(op)
.Case<
// clang-format off
AccumulateOp,
ConstantStreamOp,
FilterOp,
MapOp,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// RUN: mlir-proto-opt %s -convert-iterators-to-llvm \
// RUN: | FileCheck --enable-var-scope %s

!element_type = !llvm.struct<(i32)>

func.func private @zero_struct() -> !element_type {
%zero = arith.constant 0 : i32
%undef = llvm.mlir.undef : !element_type
%result = llvm.insertvalue %zero, %undef[0 : index] : !element_type
return %result : !element_type
}

func.func private @sum_struct(%lhs : !element_type, %rhs : !element_type) -> !element_type {
%lhsi = llvm.extractvalue %lhs[0 : index] : !element_type
%rhsi = llvm.extractvalue %rhs[0 : index] : !element_type
%i = arith.addi %lhsi, %rhsi : i32
%result = llvm.insertvalue %i, %lhs[0 : index] : !element_type
return %result : !element_type
}

// CHECK-LABEL: func.func private @iterators.accumulate.next.{{[0-9]+}}(%{{.*}}: !iterators.state<!iterators.state<i32>, i1>) -> (!iterators.state<!iterators.state<i32>, i1>, i1, !llvm.struct<(i32)>) {
// CHECK-NEXT: %[[V1:.*]] = iterators.extractvalue %[[arg0:.*]][0] : !iterators.state<!iterators.state<i32>, i1>
// CHECK-NEXT: %[[V2:.*]] = iterators.extractvalue %[[arg0]][1] : !iterators.state<!iterators.state<i32>, i1>
// CHECK-NEXT: %[[V3:.*]]:2 = scf.if %[[V2]] -> (!iterators.state<i32>, !llvm.struct<(i32)>) {
// CHECK-NEXT: %[[V4:.*]] = llvm.mlir.undef : !llvm.struct<(i32)>
// CHECK-NEXT: scf.yield %[[V1]], %[[V4]] : !iterators.state<i32>, !llvm.struct<(i32)>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[V4:.*]] = func.call @zero_struct() : () -> !llvm.struct<(i32)>
// CHECK-NEXT: %[[V5:.*]]:3 = scf.while (%[[arg1:.*]] = %[[V1]], %[[arg2:.*]] = %[[V4]]) : (!iterators.state<i32>, !llvm.struct<(i32)>) -> (!iterators.state<i32>, !llvm.struct<(i32)>, !llvm.struct<(i32)>) {
// CHECK-NEXT: %[[V6:.*]]:3 = func.call @iterators.constantstream.next.0(%[[arg1]]) : (!iterators.state<i32>) -> (!iterators.state<i32>, i1, !llvm.struct<(i32)>)
// CHECK-NEXT: scf.condition(%[[V6]]#1) %[[V6]]#0, %[[arg2]], %[[V6]]#2 : !iterators.state<i32>, !llvm.struct<(i32)>, !llvm.struct<(i32)>
// CHECK-NEXT: } do {
// CHECK-NEXT: ^bb0(%[[arg1:.*]]: !iterators.state<i32>, %[[arg2:.*]]: !llvm.struct<(i32)>, %[[arg3:.*]]: !llvm.struct<(i32)>):
// CHECK-NEXT: %[[V7:.*]] = func.call @sum_struct(%[[arg2]], %[[arg3]]) : (!llvm.struct<(i32)>, !llvm.struct<(i32)>) -> !llvm.struct<(i32)>
// CHECK-NEXT: scf.yield %[[arg1]], %[[V7]] : !iterators.state<i32>, !llvm.struct<(i32)>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[V5]]#0, %[[V5]]#1 : !iterators.state<i32>, !llvm.struct<(i32)>
// CHECK-NEXT: }
// CHECK-NEXT: %[[V8:.*]] = arith.constant true
// CHECK-NEXT: %[[V9:.*]] = arith.xori %[[V8]], %[[V2]] : i1
// CHECK-NEXT: %[[Va:.*]] = iterators.insertvalue %[[V3]]#0 into %[[arg0]][0] : !iterators.state<!iterators.state<i32>, i1>
// CHECK-NEXT: %[[Vb:.*]] = iterators.insertvalue %[[V8]] into %[[Va]][1] : !iterators.state<!iterators.state<i32>, i1>
// CHECK-NEXT: return %[[Vb]], %[[V9]], %[[V3]]#1 : !iterators.state<!iterators.state<i32>, i1>, i1, !llvm.struct<(i32)>
// CHECK-NEXT: }

// CHECK-LABEL: func.func private @iterators.accumulate.open.{{[0-9]+}}(%{{.*}}: !iterators.state<!iterators.state<i32>, i1>) -> !iterators.state<!iterators.state<i32>, i1> {
// CHECK-NEXT: %[[V1:.*]] = iterators.extractvalue %[[arg0:.*]][0] : !iterators.state<!iterators.state<i32>, i1>
// CHECK-NEXT: %[[V2:.*]] = call @iterators.constantstream.open.0(%[[V1]]) : (!iterators.state<i32>) -> !iterators.state<i32>
// CHECK-NEXT: %[[V3:.*]] = iterators.insertvalue %[[V2]] into %[[arg0]][0] : !iterators.state<!iterators.state<i32>, i1>
// CHECK-NEXT: %[[V4:.*]] = arith.constant false
// CHECK-NEXT: %[[V5:.*]] = iterators.insertvalue %[[V4]] into %[[V3]][1] : !iterators.state<!iterators.state<i32>, i1>
// CHECK-NEXT: return %[[V5]] : !iterators.state<!iterators.state<i32>, i1>
// CHECK-NEXT: }

func.func @main() {
// CHECK-LABEL: func.func @main()
%input = "iterators.constantstream"() { value = [] } : () -> (!iterators.stream<!element_type>)
%accumulated = iterators.accumulate(%input, @zero_struct, @sum_struct)
: (!iterators.stream<!element_type>) -> !iterators.stream<!element_type>
// CHECK: %[[V1:.*]] = iterators.undefstate : !iterators.state<!iterators.state<i32>, i1>
// CHECK-NEXT: %[[V2:.*]] = iterators.insertvalue %[[V0:.*]] into %[[V1]][0] : !iterators.state<!iterators.state<i32>, i1>
return
// CHECK-NEXT: return
}
Loading

0 comments on commit fca969e

Please sign in to comment.