Remove use of op builders with result type as an argument
We would like to use mhlo op builders that compute the result type based on the arguments so that in the future, bounds are appropriately propagated.

PiperOrigin-RevId: 448600328
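
For context, a minimal sketch of the builder-call pattern this commit applies throughout legalize_tf.cc. It mirrors the DynamicSliceInMinorDims change in the diff below; the variables (loc, v, slice_starts, slice_sizes, type, builder) come from that function, and the result variable names are invented for the illustration only:

// Old style: the result type is computed by hand and passed to the builder.
auto slice_type = RankedTensorType::get(slice_sizes, type.getElementType());
Value sliced = builder->create<mhlo::DynamicSliceOp>(
    loc, slice_type, v, slice_starts, GetI64ElementsAttr(slice_sizes, builder));

// New style: the builder overload without an explicit result type infers it
// from the operands and attributes, so bound information carried by the
// operands can be propagated into the inferred result type.
Value sliced_inferred = builder->create<mhlo::DynamicSliceOp>(
    loc, v, slice_starts, GetI64ElementsAttr(slice_sizes, builder));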
smit-hinsu authored and tensorflower-gardener committed May 13, 2022
1 parent cadf655 commit 5ea8ba1
Showing 1 changed file with 41 additions and 65 deletions.
106 changes: 41 additions & 65 deletions tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -267,10 +267,8 @@ static Value DynamicSliceInMinorDims(Location loc, Value v,
auto slice_sizes = llvm::to_vector<4>(type.getShape());
std::copy(minor_sizes.begin(), minor_sizes.end(),
slice_sizes.begin() + major_dims);
- auto slice_type = RankedTensorType::get(slice_sizes, type.getElementType());
return builder->create<mhlo::DynamicSliceOp>(
- loc, slice_type, v, slice_starts,
- GetI64ElementsAttr(slice_sizes, builder));
+ loc, v, slice_starts, GetI64ElementsAttr(slice_sizes, builder));
}

// Creates an mhlo::DynamicUpdateSliceOp where the major dimensions have zero
@@ -484,8 +482,6 @@ static Value ApplyReduction(Location loc, Value input,
static mhlo::RngUniformOp CreateRngUniform32(Location loc, int num_elements,
int lower_limit, int upper_limit,
OpBuilder *builder) {
- auto i32_type = builder->getIntegerType(32);
- auto key_type = RankedTensorType::get({num_elements}, i32_type);
auto shape_tensor = builder->create<mhlo::ConstOp>(
loc, GetI64ElementsAttr({num_elements}, builder));

@@ -494,8 +490,7 @@ static mhlo::RngUniformOp CreateRngUniform32(Location loc, int num_elements,
auto upper = builder->create<mhlo::ConstOp>(
loc, builder->getI32IntegerAttr(upper_limit));

- return builder->create<mhlo::RngUniformOp>(loc, key_type, lower, upper,
- shape_tensor);
+ return builder->create<mhlo::RngUniformOp>(loc, lower, upper, shape_tensor);
}

using WhileBodyFnType = llvm::function_ref<void(
@@ -1448,9 +1443,8 @@ class ConvertPadOpDynamic : public OpRewritePattern<TF::PadV2Op> {
}
llvm::SmallVector<int64_t, 2> transposed_shape = {2, input_rank};
auto transpose_attr = GetI64ElementsAttr({1, 0}, &rewriter);
- Value transposed_paddings = rewriter.create<mhlo::TransposeOp>(
- loc, RankedTensorType::get(transposed_shape, paddings_elem_ty),
- paddings, transpose_attr);
+ Value transposed_paddings =
+ rewriter.create<mhlo::TransposeOp>(loc, paddings, transpose_attr);
Value reshaped_paddings = rewriter.create<mhlo::ReshapeOp>(
loc, RankedTensorType::get({input_rank * 2}, paddings_elem_ty),
transposed_paddings);
@@ -1708,21 +1702,20 @@ class ConvertLeakyReluOp : public OpRewritePattern<TF::LeakyReluOp> {
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value features = op.features();
- auto featureType = features.getType();

// Use ConstantLike for `alpha` to match the shape of feature.
auto alphaVal = chlo::getConstantLike(
rewriter, loc, op.alpha().convertToFloat(), features);
Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features);

- Value leakyActivationVal = rewriter.create<mhlo::MulOp>(
- loc, features.getType(), features, alphaVal);
+ Value leakyActivationVal =
+ rewriter.create<mhlo::MulOp>(loc, features, alphaVal);

Value compareGtZero = rewriter.create<mhlo::CompareOp>(
loc, features, zeroVal, ComparisonDirection::GT);

- rewriter.replaceOpWithNewOp<SelectOp>(op, featureType, compareGtZero,
- features, leakyActivationVal);
+ rewriter.replaceOpWithNewOp<SelectOp>(op, compareGtZero, features,
+ leakyActivationVal);
return success();
}
};
@@ -1745,8 +1738,8 @@ class ConvertLeakyReluGradOp : public OpRewritePattern<TF::LeakyReluGradOp> {
rewriter, loc, op.alpha().convertToFloat(), features);
Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features);

- Value leakyGradientVal = rewriter.create<mhlo::MulOp>(
- loc, features.getType(), gradients, alphaVal);
+ Value leakyGradientVal =
+ rewriter.create<mhlo::MulOp>(loc, gradients, alphaVal);

Value compareGtZero = rewriter.create<mhlo::CompareOp>(
loc, features, zeroVal, ComparisonDirection::GT);
@@ -1872,7 +1865,6 @@ class ConvertMatrixDiagPartV3Op
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
ShapedType input_type = op.input().getType().dyn_cast<ShapedType>();
- auto element_type = input_type.getElementType();

// Align is a string specifying how superdiagonals and subdiagonals should
// be aligned/padded for diagonals that are shorter than max_diag_len. The
@@ -2031,7 +2023,6 @@ class ConvertMatrixDiagPartV3Op
}
output_shape.push_back(num_diags);
output_shape.push_back(max_diag_len);
- auto output_type = RankedTensorType::get(output_shape, element_type);

// A slice is the shape of what GatherOp copies per lookup. So the last
// two dimensions (M, N in the matrix-diag-part docs) are where we go
@@ -2058,7 +2049,7 @@
/*collapsed_slice_dims=*/collapsed_dims, start_index_map,
/*index_vector_dim=*/0);
Value gather = rewriter.create<mhlo::GatherOp>(
- loc, output_type, op.input(), start_indices, dims_attr,
+ loc, op.input(), start_indices, dims_attr,
GetI64ElementsAttr(slice_sizes, &rewriter));

// We now need to broadcast the "in_bounds" boolean expression, as well as
@@ -2071,12 +2062,11 @@
loc, RankedTensorType::get(output_shape, rewriter.getIntegerType(1)),
in_bounds, GetI64ElementsAttr(broadcast_bounds, &rewriter));
Value b_padding = rewriter.create<BroadcastOp>(
- loc, output_type, op.padding_value(),
- GetI64ElementsAttr(output_shape, &rewriter));
+ loc, op.padding_value(), GetI64ElementsAttr(output_shape, &rewriter));

// Replace all out-of-bounds values in the result with padding_value.
- Value result = rewriter.create<SelectOp>(loc, output_type, b_in_bounds,
- gather, b_padding);
+ Value result =
+ rewriter.create<SelectOp>(loc, b_in_bounds, gather, b_padding);

if (num_diags == 1) {
// matrix_diag_part folds away the 1-sized band dimension if we only
@@ -5870,8 +5860,6 @@ class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
Value swaps = old_values[0];
Value indices = old_values[1];

- auto vec1_i32_type =
- RankedTensorType::get({1}, builder->getIntegerType(32));
auto scalar_i32_type =
RankedTensorType::get({}, builder->getIntegerType(32));
auto scalar_i64_type =
@@ -5882,14 +5870,13 @@

// We need to swap the indices[i] with indices[swaps[i]]. First get
// these index values.
- Value source_index = builder->create<mhlo::DynamicSliceOp>(
- loc, vec1_i32_type, indices, i, scalar_one);
+ Value source_index =
+ builder->create<mhlo::DynamicSliceOp>(loc, indices, i, scalar_one);
Value swap_index = builder->create<mhlo::ReshapeOp>(
loc, scalar_i32_type,
- builder->create<mhlo::DynamicSliceOp>(loc, vec1_i32_type, swaps, i,
- scalar_one));
+ builder->create<mhlo::DynamicSliceOp>(loc, swaps, i, scalar_one));
Value target_index = builder->create<mhlo::DynamicSliceOp>(
- loc, vec1_i32_type, indices, swap_index, scalar_one);
+ loc, indices, swap_index, scalar_one);

// Then perform the swap.
// indices[i] <- indices[swaps[i]]
@@ -6254,8 +6241,7 @@ class ConvertCumOp : public OpRewritePattern<OpT> {
if (op.reverse()) {
llvm::SmallVector<int64_t, 4> dims_to_reverse({axis});
input = rewriter.create<ReverseOp>(
- op.getLoc(), op.getType(), input,
- GetI64ElementsAttr(dims_to_reverse, &rewriter));
+ op.getLoc(), input, GetI64ElementsAttr(dims_to_reverse, &rewriter));
}

// Convert if we need to enlarge the element type's bitwidth to avoid
@@ -6303,8 +6289,7 @@ class ConvertCumOp : public OpRewritePattern<OpT> {
low_padding[axis] = 1;
high_padding[axis] = -1;
result = rewriter.create<PadOp>(
- op.getLoc(), result.getType(), result, init,
- GetI64ElementsAttr(low_padding, &rewriter),
+ op.getLoc(), result, init, GetI64ElementsAttr(low_padding, &rewriter),
GetI64ElementsAttr(high_padding, &rewriter),
GetI64ElementsAttr(interior_padding, &rewriter));
}
@@ -6316,8 +6301,7 @@
if (op.reverse()) {
llvm::SmallVector<int64_t, 4> dims_to_reverse({axis});
result = rewriter.create<ReverseOp>(
- op.getLoc(), op.getType(), result,
- GetI64ElementsAttr(dims_to_reverse, &rewriter));
+ op.getLoc(), result, GetI64ElementsAttr(dims_to_reverse, &rewriter));
}

rewriter.replaceOp(op, result);
@@ -6500,9 +6484,9 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
rewriter.create<ConvertOp>(op.getLoc(), compare, type.getElementType());
auto q_shape = llvm::to_vector<4>(type.getShape());
q_shape.back() = m;
- Value q = rewriter.create<BroadcastOp>(
- op.getLoc(), RankedTensorType::get(q_shape, type.getElementType()),
- identity_matrix, GetI64ElementsAttr(batch_dims, &rewriter));
+ Value q =
+ rewriter.create<BroadcastOp>(op.getLoc(), identity_matrix,
+ GetI64ElementsAttr(batch_dims, &rewriter));
auto precision_config = rewriter.getArrayAttr(
{PrecisionAttr::get(rewriter.getContext(), Precision::HIGHEST),
PrecisionAttr::get(rewriter.getContext(), Precision::HIGHEST)});
@@ -6625,32 +6609,29 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
loc, alpha, zero, GetI64ElementsAttr({}, builder),
ComparisonDirection::LT);
auto batch_size_one = builder->create<BroadcastOp>(
- loc, alpha.getType(), one, GetI64ElementsAttr(batch_dims, builder));
+ loc, one, GetI64ElementsAttr(batch_dims, builder));
Value signed_mu = builder->create<chlo::BroadcastMulOp>(
loc,
- builder->create<SelectOp>(loc, mu.getType(), alpha_is_negative,
- batch_size_one,
+ builder->create<SelectOp>(loc, alpha_is_negative, batch_size_one,
builder->create<NegOp>(loc, batch_size_one)),
mu, GetI64ElementsAttr({}, builder));
- *beta = builder->create<SelectOp>(loc, alpha.getType(), sigma_is_zero,
- alpha, signed_mu);
+ *beta = builder->create<SelectOp>(loc, sigma_is_zero, alpha, signed_mu);
*tau = builder->create<DivOp>(
loc, builder->create<SubOp>(loc, *beta, alpha), *beta);
Value zero_tau = builder->create<BroadcastOp>(
- loc, alpha.getType(), zero, GetI64ElementsAttr(batch_dims, builder));
- *tau = builder->create<SelectOp>(loc, alpha.getType(), sigma_is_zero,
- zero_tau, *tau);
+ loc, zero, GetI64ElementsAttr(batch_dims, builder));
+ *tau = builder->create<SelectOp>(loc, sigma_is_zero, zero_tau, *tau);
Value divisor = builder->create<SubOp>(loc, alpha, *beta);
- divisor = builder->create<SelectOp>(loc, divisor.getType(), sigma_is_zero,
- batch_size_one, divisor);
+ divisor =
+ builder->create<SelectOp>(loc, sigma_is_zero, batch_size_one, divisor);

Value eqk = builder->create<chlo::BroadcastCompareOp>(
loc, iota, k, GetI64ElementsAttr({}, builder), ComparisonDirection::EQ);
eqk = builder->create<ConvertOp>(loc, eqk, x_type.getElementType());
llvm::SmallVector<int64_t, 4> e_k_shape(batch_dims.size(), 1);
e_k_shape.push_back(m);
auto e_k = builder->create<BroadcastOp>(
- loc, RankedTensorType::get(e_k_shape, x_type.getElementType()), eqk,
+ loc, eqk,
GetI64ElementsAttr(llvm::SmallVector<int64_t, 4>(batch_dims.size(), 1),
builder));

@@ -6758,11 +6739,8 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
loc, iota, j, GetI64ElementsAttr({}, builder),
ComparisonDirection::EQ);
mask = builder->create<ConvertOp>(loc, mask, a_type.getElementType());
- llvm::SmallVector<int64_t, 4> broadcast_mask_shape(a_type.getRank(), 1);
- broadcast_mask_shape[a_type.getRank() - 2] = m;
mask = builder->create<BroadcastOp>(
loc,
- RankedTensorType::get(broadcast_mask_shape, a_type.getElementType()),
mask,
GetI64ElementsAttr(llvm::SmallVector<int64_t, 4>(num_batch_dims, 1),
builder));
@@ -6787,19 +6765,19 @@
Value xa_mask = builder->create<chlo::BroadcastCompareOp>(
loc, iota_mn, j, GetI64ElementsAttr({}, builder),
ComparisonDirection::EQ);
- a = builder->create<SelectOp>(loc, a_type, xa_mask, new_x, a);
+ a = builder->create<SelectOp>(loc, xa_mask, new_x, a);

// vs[:, j] = v
llvm::SmallVector<int64_t, 4> vs_broadcast_dims(num_batch_dims + 1);
std::iota(vs_broadcast_dims.begin(), vs_broadcast_dims.end(), 0);
Value vs_zeros =
GetScalarConstOfType(a_type.getElementType(), loc, 0, builder);
vs_zeros = builder->create<BroadcastOp>(
- loc, vs.getType(), vs_zeros,
+ loc, vs_zeros,
GetI64ElementsAttr(vs.getType().cast<RankedTensorType>().getShape(),
builder));
auto vs_update = builder->create<SelectOp>(
- loc, vs.getType(), xa_mask,
+ loc, xa_mask,
StaticBinaryBroadcast<AddOp>(
loc, vs_zeros, v, GetI64ElementsAttr(vs_broadcast_dims, builder),
*builder),
@@ -6818,14 +6796,14 @@
Value taus_zeros =
GetScalarConstOfType(a_type.getElementType(), loc, 0, builder);
taus_zeros = builder->create<BroadcastOp>(
- loc, taus.getType(), taus_zeros,
+ loc, taus_zeros,
GetI64ElementsAttr(taus.getType().cast<RankedTensorType>().getShape(),
builder));
Value taus_mask = builder->create<chlo::BroadcastCompareOp>(
loc, iota_n, j, GetI64ElementsAttr({}, builder),
ComparisonDirection::EQ);
auto taus_update = builder->create<SelectOp>(
- loc, taus.getType(), taus_mask,
+ loc, taus_mask,
StaticBinaryBroadcast<AddOp>(
loc, taus_zeros, tau,
GetI64ElementsAttr(tau_broadcast_dims, builder), *builder),
@@ -6837,12 +6815,11 @@
Value zero =
GetScalarConstOfType(a_type.getElementType(), loc, 0, rewriter);
*vs = rewriter->create<BroadcastOp>(
- loc, a_type, zero, GetI64ElementsAttr(a_type.getShape(), rewriter));
+ loc, zero, GetI64ElementsAttr(a_type.getShape(), rewriter));
auto taus_shape = llvm::to_vector<4>(batch_dims);
taus_shape.push_back(n);
*taus = rewriter->create<BroadcastOp>(
- loc, RankedTensorType::get(taus_shape, a_type.getElementType()), zero,
- GetI64ElementsAttr(taus_shape, rewriter));
+ loc, zero, GetI64ElementsAttr(taus_shape, rewriter));

SmallVector<Value, 4> while_output;
CreateWhile32(loc, std::min(m, n), qr_body_fn, {a, *vs, *taus},
@@ -6903,13 +6880,13 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
Value zero = GetScalarConstOfType(getElementTypeOrSelf(vs.getType()), loc,
0, builder);
zero = builder->create<BroadcastOp>(
- loc, vs.getType(), zero,
+ loc, zero,
GetI64ElementsAttr(vs.getType().cast<RankedTensorType>().getShape(),
builder));
auto compare = builder->create<chlo::BroadcastCompareOp>(
loc, iota_mn, j, GetI64ElementsAttr({}, builder),
ComparisonDirection::GE);
- auto y = builder->create<SelectOp>(loc, vs.getType(), compare, zero, vs);
+ auto y = builder->create<SelectOp>(loc, compare, zero, vs);

// yv has shape [..., n, 1]
auto precision = builder->getArrayAttr(
@@ -6939,7 +6916,6 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
auto w_shape = llvm::to_vector<4>(batch_dims);
w_shape.append({m, n});
w = rewriter->create<BroadcastOp>(loc,
- RankedTensorType::get(w_shape, data_type),
w, GetI64ElementsAttr(w_shape, rewriter));
auto v = SliceInMinorDims(loc, vs, {0}, {1}, rewriter);
auto beta = SliceInMinorDims(loc, taus, {0}, {1}, rewriter);