From 5ea8ba100c58e8951e8e380248a6f9466f63b0ab Mon Sep 17 00:00:00 2001
From: Smit Hinsu <hinsu@google.com>
Date: Fri, 13 May 2022 16:26:21 -0700
Subject: [PATCH] Remove use of op builders with result type as an argument

We would like to use mhlo op builders that compute the result type based on
the arguments so that, in the future, bounds are appropriately propagated.

PiperOrigin-RevId: 448600328
---
 .../mlir/xla/transforms/legalize_tf.cc        | 106 +++++++-----------
 1 file changed, 41 insertions(+), 65 deletions(-)

diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 4fd119f8816fbe..72d461f38f6c4d 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -267,10 +267,8 @@ static Value DynamicSliceInMinorDims(Location loc, Value v,
   auto slice_sizes = llvm::to_vector<4>(type.getShape());
   std::copy(minor_sizes.begin(), minor_sizes.end(),
             slice_sizes.begin() + major_dims);
-  auto slice_type = RankedTensorType::get(slice_sizes, type.getElementType());
   return builder->create<mhlo::DynamicSliceOp>(
-      loc, slice_type, v, slice_starts,
-      GetI64ElementsAttr(slice_sizes, builder));
+      loc, v, slice_starts, GetI64ElementsAttr(slice_sizes, builder));
 }
 
 // Creates an mhlo::DynamicUpdateSliceOp where the major dimensions have zero
@@ -484,8 +482,6 @@ static Value ApplyReduction(Location loc, Value input,
 static mhlo::RngUniformOp CreateRngUniform32(Location loc, int num_elements,
                                              int lower_limit, int upper_limit,
                                              OpBuilder *builder) {
-  auto i32_type = builder->getIntegerType(32);
-  auto key_type = RankedTensorType::get({num_elements}, i32_type);
   auto shape_tensor = builder->create<mhlo::ConstantOp>(
       loc, GetI64ElementsAttr({num_elements}, builder));
@@ -494,8 +490,7 @@ static mhlo::RngUniformOp CreateRngUniform32(Location loc, int num_elements,
   auto upper = builder->create<mhlo::ConstantOp>(
       loc, builder->getI32IntegerAttr(upper_limit));
 
-  return builder->create<mhlo::RngUniformOp>(loc, key_type, lower, upper,
-                                             shape_tensor);
+  return builder->create<mhlo::RngUniformOp>(loc, lower, upper, shape_tensor);
 }
 
 using WhileBodyFnType = llvm::function_ref<void(
@@ ... @@
     llvm::SmallVector<int64_t, 2> transposed_shape = {2, input_rank};
     auto transpose_attr = GetI64ElementsAttr({1, 0}, &rewriter);
-    Value transposed_paddings = rewriter.create<mhlo::TransposeOp>(
-        loc, RankedTensorType::get(transposed_shape, paddings_elem_ty),
-        paddings, transpose_attr);
+    Value transposed_paddings =
+        rewriter.create<mhlo::TransposeOp>(loc, paddings, transpose_attr);
     Value reshaped_paddings = rewriter.create<mhlo::ReshapeOp>(
         loc, RankedTensorType::get({input_rank * 2}, paddings_elem_ty),
         transposed_paddings);
@@ -1708,21 +1702,20 @@ class ConvertLeakyReluOp : public OpRewritePattern<TF::LeakyReluOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value features = op.features();
-    auto featureType = features.getType();
 
     // Use ConstantLike for `alpha` to match the shape of feature.
     auto alphaVal = chlo::getConstantLike(
         rewriter, loc, op.alpha().convertToFloat(), features);
     Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features);
 
-    Value leakyActivationVal = rewriter.create<mhlo::MulOp>(
-        loc, features.getType(), features, alphaVal);
+    Value leakyActivationVal =
+        rewriter.create<mhlo::MulOp>(loc, features, alphaVal);
 
     Value compareGtZero = rewriter.create<mhlo::CompareOp>(
         loc, features, zeroVal, ComparisonDirection::GT);
 
-    rewriter.replaceOpWithNewOp<mhlo::SelectOp>(op, featureType, compareGtZero,
-                                                features, leakyActivationVal);
+    rewriter.replaceOpWithNewOp<mhlo::SelectOp>(op, compareGtZero, features,
+                                                leakyActivationVal);
     return success();
   }
 };
@@ -1745,8 +1738,8 @@ class ConvertLeakyReluGradOp : public OpRewritePattern<TF::LeakyReluGradOp> {
         rewriter, loc, op.alpha().convertToFloat(), features);
     Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features);
 
-    Value leakyGradientVal = rewriter.create<mhlo::MulOp>(
-        loc, features.getType(), gradients, alphaVal);
+    Value leakyGradientVal =
+        rewriter.create<mhlo::MulOp>(loc, gradients, alphaVal);
 
     Value compareGtZero = rewriter.create<mhlo::CompareOp>(
         loc, features, zeroVal, ComparisonDirection::GT);
@@ -1872,7 +1865,6 @@ class ConvertMatrixDiagPartV3Op
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     ShapedType input_type = op.input().getType().dyn_cast<ShapedType>();
-    auto element_type = input_type.getElementType();
 
     // Align is a string specifying how superdiagonals and subdiagonals should
     // be aligned/padded for diagonals that are shorter than max_diag_len. The
@@ -2031,7 +2023,6 @@ class ConvertMatrixDiagPartV3Op
     }
     output_shape.push_back(num_diags);
     output_shape.push_back(max_diag_len);
-    auto output_type = RankedTensorType::get(output_shape, element_type);
 
     // A slice is the shape of what GatherOp copies per lookup. So the last
     // two dimensions (M, N in the matrix-diag-part docs) are where we go
@@ -2058,7 +2049,7 @@ class ConvertMatrixDiagPartV3Op
         /*collapsed_slice_dims=*/collapsed_dims, start_index_map,
         /*index_vector_dim=*/0);
     Value gather = rewriter.create<mhlo::GatherOp>(
-        loc, output_type, op.input(), start_indices, dims_attr,
+        loc, op.input(), start_indices, dims_attr,
         GetI64ElementsAttr(slice_sizes, &rewriter));
 
     // We now need to broadcast the "in_bounds" boolean expression, as well as
@@ -2071,12 +2062,11 @@ class ConvertMatrixDiagPartV3Op
         loc, RankedTensorType::get(output_shape, rewriter.getIntegerType(1)),
         in_bounds, GetI64ElementsAttr(broadcast_bounds, &rewriter));
     Value b_padding = rewriter.create<mhlo::BroadcastOp>(
-        loc, output_type, op.padding_value(),
-        GetI64ElementsAttr(output_shape, &rewriter));
+        loc, op.padding_value(), GetI64ElementsAttr(output_shape, &rewriter));
 
     // Replace all out-of-bounds values in the result with padding_value.
-    Value result = rewriter.create<mhlo::SelectOp>(loc, output_type, b_in_bounds,
-                                                   gather, b_padding);
+    Value result =
+        rewriter.create<mhlo::SelectOp>(loc, b_in_bounds, gather, b_padding);
 
     if (num_diags == 1) {
       // matrix_diag_part folds away the 1-sized band dimension if we only
@@ -5870,8 +5860,6 @@ class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
       Value swaps = old_values[0];
       Value indices = old_values[1];
 
-      auto vec1_i32_type =
-          RankedTensorType::get({1}, builder->getIntegerType(32));
       auto scalar_i32_type =
           RankedTensorType::get({}, builder->getIntegerType(32));
       auto scalar_i64_type =
@@ -5882,14 +5870,13 @@ class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
 
       // We need to swap the indices[i] with indices[swaps[i]]. First get
       // these index values.
-      Value source_index = builder->create<mhlo::DynamicSliceOp>(
-          loc, vec1_i32_type, indices, i, scalar_one);
+      Value source_index =
+          builder->create<mhlo::DynamicSliceOp>(loc, indices, i, scalar_one);
       Value swap_index = builder->create<mhlo::ReshapeOp>(
           loc, scalar_i32_type,
-          builder->create<mhlo::DynamicSliceOp>(loc, vec1_i32_type, swaps, i,
-                                                scalar_one));
+          builder->create<mhlo::DynamicSliceOp>(loc, swaps, i, scalar_one));
       Value target_index = builder->create<mhlo::DynamicSliceOp>(
-          loc, vec1_i32_type, indices, swap_index, scalar_one);
+          loc, indices, swap_index, scalar_one);
 
       // Then perform the swap.
       // indices[i] <- indices[swaps[i]]
@@ -6254,8 +6241,7 @@ class ConvertCumOp : public OpRewritePattern<OpT> {
     if (op.reverse()) {
       llvm::SmallVector<int64_t, 4> dims_to_reverse({axis});
       input = rewriter.create<mhlo::ReverseOp>(
-          op.getLoc(), op.getType(), input,
-          GetI64ElementsAttr(dims_to_reverse, &rewriter));
+          op.getLoc(), input, GetI64ElementsAttr(dims_to_reverse, &rewriter));
     }
 
     // Convert if we need to enlarge the element type's bitwidth to avoid
@@ -6303,8 +6289,7 @@ class ConvertCumOp : public OpRewritePattern<OpT> {
       low_padding[axis] = 1;
      high_padding[axis] = -1;
       result = rewriter.create<mhlo::PadOp>(
-          op.getLoc(), result.getType(), result, init,
-          GetI64ElementsAttr(low_padding, &rewriter),
+          op.getLoc(), result, init, GetI64ElementsAttr(low_padding, &rewriter),
           GetI64ElementsAttr(high_padding, &rewriter),
           GetI64ElementsAttr(interior_padding, &rewriter));
     }
@@ -6316,8 +6301,7 @@ class ConvertCumOp : public OpRewritePattern<OpT> {
     if (op.reverse()) {
       llvm::SmallVector<int64_t, 4> dims_to_reverse({axis});
       result = rewriter.create<mhlo::ReverseOp>(
-          op.getLoc(), op.getType(), result,
-          GetI64ElementsAttr(dims_to_reverse, &rewriter));
+          op.getLoc(), result, GetI64ElementsAttr(dims_to_reverse, &rewriter));
     }
 
     rewriter.replaceOp(op, result);
@@ -6500,9 +6484,9 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
         rewriter.create<mhlo::ConvertOp>(op.getLoc(), compare, type.getElementType());
     auto q_shape = llvm::to_vector<4>(type.getShape());
     q_shape.back() = m;
-    Value q = rewriter.create<mhlo::BroadcastOp>(
-        op.getLoc(), RankedTensorType::get(q_shape, type.getElementType()),
-        identity_matrix, GetI64ElementsAttr(batch_dims, &rewriter));
+    Value q =
+        rewriter.create<mhlo::BroadcastOp>(op.getLoc(), identity_matrix,
+                                           GetI64ElementsAttr(batch_dims, &rewriter));
     auto precision_config = rewriter.getArrayAttr(
         {PrecisionAttr::get(rewriter.getContext(), Precision::HIGHEST),
          PrecisionAttr::get(rewriter.getContext(), Precision::HIGHEST)});
@@ -6625,24 +6609,21 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
         loc, alpha, zero, GetI64ElementsAttr({}, builder),
         ComparisonDirection::LT);
     auto batch_size_one = builder->create<mhlo::BroadcastOp>(
-        loc, alpha.getType(), one, GetI64ElementsAttr(batch_dims, builder));
+        loc, one, GetI64ElementsAttr(batch_dims, builder));
     Value signed_mu = builder->create<chlo::BroadcastMulOp>(
         loc,
-        builder->create<mhlo::SelectOp>(loc, mu.getType(), alpha_is_negative,
-                                        batch_size_one,
+        builder->create<mhlo::SelectOp>(loc, alpha_is_negative, batch_size_one,
                                         builder->create<mhlo::NegOp>(loc, batch_size_one)),
         mu, GetI64ElementsAttr({}, builder));
-    *beta = builder->create<mhlo::SelectOp>(loc, alpha.getType(), sigma_is_zero,
-                                            alpha, signed_mu);
+    *beta = builder->create<mhlo::SelectOp>(loc, sigma_is_zero, alpha, signed_mu);
     *tau = builder->create<mhlo::DivOp>(
         loc, builder->create<mhlo::SubtractOp>(loc, *beta, alpha), *beta);
     Value zero_tau = builder->create<mhlo::BroadcastOp>(
-        loc, alpha.getType(), zero, GetI64ElementsAttr(batch_dims, builder));
-    *tau = builder->create<mhlo::SelectOp>(loc, alpha.getType(), sigma_is_zero,
-                                           zero_tau, *tau);
+        loc, zero, GetI64ElementsAttr(batch_dims, builder));
+    *tau = builder->create<mhlo::SelectOp>(loc, sigma_is_zero, zero_tau, *tau);
     Value divisor = builder->create<mhlo::SubtractOp>(loc, alpha, *beta);
-    divisor = builder->create<mhlo::SelectOp>(loc, divisor.getType(), sigma_is_zero,
-                                              batch_size_one, divisor);
+    divisor =
+        builder->create<mhlo::SelectOp>(loc, sigma_is_zero, batch_size_one, divisor);
     Value eqk = builder->create<chlo::BroadcastCompareOp>(
         loc, iota, k, GetI64ElementsAttr({}, builder),
         ComparisonDirection::EQ);
@@ -6650,7 +6631,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
     llvm::SmallVector<int64_t, 4> e_k_shape(batch_dims.size(), 1);
     e_k_shape.push_back(m);
     auto e_k = builder->create<mhlo::BroadcastOp>(
-        loc, RankedTensorType::get(e_k_shape, x_type.getElementType()), eqk,
+        loc, eqk,
         GetI64ElementsAttr(llvm::SmallVector<int64_t, 4>(batch_dims.size(), 1),
                            builder));
 
@@ -6758,11 +6739,8 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
         loc, iota, j, GetI64ElementsAttr({}, builder),
         ComparisonDirection::EQ);
     mask = builder->create<mhlo::ConvertOp>(loc, mask, a_type.getElementType());
-    llvm::SmallVector<int64_t, 4> broadcast_mask_shape(a_type.getRank(), 1);
-    broadcast_mask_shape[a_type.getRank() - 2] = m;
     mask = builder->create<mhlo::BroadcastOp>(
         loc,
-        RankedTensorType::get(broadcast_mask_shape, a_type.getElementType()),
         mask,
         GetI64ElementsAttr(llvm::SmallVector<int64_t, 4>(num_batch_dims, 1),
                            builder));
@@ -6787,7 +6765,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
       Value xa_mask = builder->create<chlo::BroadcastCompareOp>(
          loc, iota_mn, j, GetI64ElementsAttr({}, builder),
          ComparisonDirection::EQ);
-      a = builder->create<mhlo::SelectOp>(loc, a_type, xa_mask, new_x, a);
+      a = builder->create<mhlo::SelectOp>(loc, xa_mask, new_x, a);
 
       // vs[:, j] = v
       llvm::SmallVector<int64_t, 4> vs_broadcast_dims(num_batch_dims + 1);
@@ -6795,11 +6773,11 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
      Value vs_zeros =
          GetScalarConstOfType(a_type.getElementType(), loc, 0, builder);
       vs_zeros = builder->create<mhlo::BroadcastOp>(
-          loc, vs.getType(), vs_zeros,
+          loc, vs_zeros,
           GetI64ElementsAttr(vs.getType().cast<RankedTensorType>().getShape(), builder));
       auto vs_update = builder->create<mhlo::SelectOp>(
-          loc, vs.getType(), xa_mask,
+          loc, xa_mask,
           StaticBinaryBroadcast(
               loc, vs_zeros, v, GetI64ElementsAttr(vs_broadcast_dims, builder),
              *builder),
          vs);
@@ -6818,14 +6796,14 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
       Value taus_zeros =
          GetScalarConstOfType(a_type.getElementType(), loc, 0, builder);
       taus_zeros = builder->create<mhlo::BroadcastOp>(
-          loc, taus.getType(), taus_zeros,
+          loc, taus_zeros,
           GetI64ElementsAttr(taus.getType().cast<RankedTensorType>().getShape(), builder));
       Value taus_mask = builder->create<chlo::BroadcastCompareOp>(
           loc, iota_n, j, GetI64ElementsAttr({}, builder),
           ComparisonDirection::EQ);
       auto taus_update = builder->create<mhlo::SelectOp>(
-          loc, taus.getType(), taus_mask,
+          loc, taus_mask,
           StaticBinaryBroadcast(
               loc, taus_zeros, tau,
               GetI64ElementsAttr(tau_broadcast_dims, builder), *builder),
           taus);
@@ -6837,12 +6815,11 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
     Value zero =
         GetScalarConstOfType(a_type.getElementType(), loc, 0, rewriter);
     *vs = rewriter->create<mhlo::BroadcastOp>(
-        loc, a_type, zero, GetI64ElementsAttr(a_type.getShape(), rewriter));
+        loc, zero, GetI64ElementsAttr(a_type.getShape(), rewriter));
     auto taus_shape = llvm::to_vector<4>(batch_dims);
     taus_shape.push_back(n);
     *taus = rewriter->create<mhlo::BroadcastOp>(
-        loc, RankedTensorType::get(taus_shape, a_type.getElementType()), zero,
-        GetI64ElementsAttr(taus_shape, rewriter));
+        loc, zero, GetI64ElementsAttr(taus_shape, rewriter));
 
     SmallVector<Value, 4> while_output;
     CreateWhile32(loc, std::min(m, n), qr_body_fn, {a, *vs, *taus},
@@ -6903,13 +6880,13 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
       Value zero = GetScalarConstOfType(getElementTypeOrSelf(vs.getType()),
                                         loc, 0, builder);
       zero = builder->create<mhlo::BroadcastOp>(
-          loc, vs.getType(), zero,
+          loc, zero,
           GetI64ElementsAttr(vs.getType().cast<RankedTensorType>().getShape(),
                              builder));
       auto compare = builder->create<chlo::BroadcastCompareOp>(
          loc, iota_mn, j, GetI64ElementsAttr({}, builder),
          ComparisonDirection::GE);
-      auto y = builder->create<mhlo::SelectOp>(loc, vs.getType(), compare, zero, vs);
+      auto y = builder->create<mhlo::SelectOp>(loc, compare, zero, vs);
 
       // yv has shape [..., n, 1]
       auto precision = builder->getArrayAttr(
@@ -6939,7 +6916,6 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
     auto w_shape = llvm::to_vector<4>(batch_dims);
     w_shape.append({m, n});
     w = rewriter->create<mhlo::BroadcastOp>(loc,
-                                            RankedTensorType::get(w_shape, data_type),
                                             w, GetI64ElementsAttr(w_shape, rewriter));
     auto v = SliceInMinorDims(loc, vs, {0}, {1}, rewriter);
     auto beta = SliceInMinorDims(loc, taus, {0}, {1}, rewriter);
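
For reference, a minimal sketch of the builder-form change applied in every
hunk above, shown with the mhlo::DynamicSliceOp case from the first hunk. This
is not code from the patch itself: `elem_ty` is a hypothetical stand-in for
`type.getElementType()`, and the other names are as in DynamicSliceInMinorDims.

  // Old form: the call site computes the result type by hand and passes it to
  // the builder. Any refinement carried by the operand types (for example,
  // bounds on dynamic dimensions) is dropped unless every caller rebuilds it.
  Value sliced_explicit = builder->create<mhlo::DynamicSliceOp>(
      loc, RankedTensorType::get(slice_sizes, elem_ty), v, slice_starts,
      GetI64ElementsAttr(slice_sizes, builder));

  // New form: the generated builder derives the result type from the operands
  // and attributes, so refinements such as bounds can propagate automatically.
  Value sliced_inferred = builder->create<mhlo::DynamicSliceOp>(
      loc, v, slice_starts, GetI64ElementsAttr(slice_sizes, builder));

Both forms construct the same op here; the difference is only which side
computes the result type, which is why the patch is a pure deletion of
explicitly constructed types at each call site.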