diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 58014bdddde3f0..d9b27a6858354f 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -1535,6 +1535,7 @@ def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]> {
   let results = (outs HLO_Tensor);
 
   let hasCustomHLOConverter = 1;
+  let hasVerifier = 1;
 
   code extraClassDeclaration = [{
     bool hasWindowReversal() {
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index e561d7df5d70ca..4959759dfb0ae3 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -1791,6 +1791,311 @@ LogicalResult CollectivePermuteOp::verify() {
                                                  *this, source_target_pairs());
 }
 
+//===----------------------------------------------------------------------===//
+// ConvOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Checks:
+//  P1. Same sizes for input, kernel and output spatial_dims.
+//  P2. Spatial and non-spatial dimensions (for input, kernel, and output)
+//      should be unique and in range [0, num_dims), where num_dims = rank of
+//      input (lhs/rhs) tensors.
+//
+// Note that the spatial + non-spatial dimensions may not cover all the
+// dimensions in the range [0, num_dims) because of the presence of 'unknown'
+// dimensions (ref. cl/415132294).
+LogicalResult isSpatialDimensionsValid(ConvOp op) {
+  auto input_spatial_dimensions =
+      op.dimension_numbers().getInputSpatialDimensions();
+  auto kernel_spatial_dimensions =
+      op.dimension_numbers().getKernelSpatialDimensions();
+  auto output_spatial_dimensions =
+      op.dimension_numbers().getOutputSpatialDimensions();
+
+  // P1.
+  if ((input_spatial_dimensions.size() != kernel_spatial_dimensions.size()) ||
+      (input_spatial_dimensions.size() != output_spatial_dimensions.size()))
+    return op.emitOpError() << "expects the same size for input, kernel and "
+                               "output spatial-dimensions, but got "
+                            << input_spatial_dimensions.size() << ", "
+                            << kernel_spatial_dimensions.size() << ", and "
+                            << output_spatial_dimensions.size() << " resp.";
+
+  // P2.
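+  // For example, with the (hypothetical) dimension numbers
+  // [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] on 4-D operands, the vectors
+  // built below are input_dnums = {0, 3, 1, 2}, window_dnums = {2, 3, 0, 1},
+  // and output_dnums = {0, 3, 1, 2}; each is a permutation of [0, 4), so the
+  // in-range and uniqueness checks both pass.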
+  SmallVector<int64_t> input_dnums(input_spatial_dimensions.size() + 2);
+  input_dnums[0] = op.dimension_numbers().getInputBatchDimension();
+  input_dnums[1] = op.dimension_numbers().getInputFeatureDimension();
+  std::copy(input_spatial_dimensions.begin(), input_spatial_dimensions.end(),
+            input_dnums.begin() + 2);
+
+  SmallVector<int64_t> window_dnums(kernel_spatial_dimensions.size() + 2);
+  window_dnums[0] = op.dimension_numbers().getKernelInputFeatureDimension();
+  window_dnums[1] = op.dimension_numbers().getKernelOutputFeatureDimension();
+  std::copy(kernel_spatial_dimensions.begin(), kernel_spatial_dimensions.end(),
+            window_dnums.begin() + 2);
+
+  SmallVector<int64_t> output_dnums(output_spatial_dimensions.size() + 2);
+  output_dnums[0] = op.dimension_numbers().getOutputBatchDimension();
+  output_dnums[1] = op.dimension_numbers().getOutputFeatureDimension();
+  std::copy(output_spatial_dimensions.begin(), output_spatial_dimensions.end(),
+            output_dnums.begin() + 2);
+
+  auto num_dims = op.lhs().getType().cast<RankedTensorType>().getRank();
+  const auto in_range = [num_dims](int64_t i) {
+    return 0 <= i && i < num_dims;
+  };
+
+  if (!llvm::all_of(input_dnums, in_range) ||
+      !llvm::all_of(window_dnums, in_range) ||
+      !llvm::all_of(output_dnums, in_range))
+    return op.emitOpError() << "expects input, kernel, and output "
+                               "dimension-numbers to be in-range [0, "
+                            << num_dims << ").";
+
+  const auto has_duplicates = [](SmallVector<int64_t>& dnums) {
+    std::sort(dnums.begin(), dnums.end());
+    auto last = std::unique(dnums.begin(), dnums.end());
+    return last != dnums.end();
+  };
+
+  if (has_duplicates(input_dnums))
+    return op.emitOpError()
+           << "expects input dimension-numbers to be unique, got {"
+           << input_dnums << "}.";
+
+  if (has_duplicates(window_dnums))
+    return op.emitOpError()
+           << "expects kernel dimension-numbers to be unique, got {"
+           << window_dnums << "}.";
+
+  if (has_duplicates(output_dnums))
+    return op.emitOpError()
+           << "expects output dimension-numbers to be unique, got {"
+           << output_dnums << "}.";
+
+  return success();
+}
+
+// Verifies the following properties:
+//  P1. The input, kernel, and output spatial-dimensions are valid.
+//  P2. Given,
+//          input-dimensions: b * input-spatial-dims * f
+//          kernel-dimensions: kernel-spatial-dims * i * o
+//          output-dimensions: b' * out-spatial-dims * f'
+//            where b = input-batch-dims
+//            where f = input-feature-dims
+//            where i = kernel-input-feature-dims
+//            where o = kernel-output-feature-dims
+//            where b' = output-batch-dims
+//            where f' = output-feature-dims
+//      Check the following properties w.r.t feature_group_count (fgc) and
+//      batch_group_count (bgc):
+//        fgc > 0, bgc > 0 and !(fgc > 1 && bgc > 1)
+//        b % bgc == 0
+//        f % fgc == 0 and i = f / fgc
+//        o (or f') % bgc == 0 and o (or f') % fgc == 0
+LogicalResult verifyConvolutionAttributes(ConvOp op) {
+  // P1.
+  if (failed(isSpatialDimensionsValid(op))) return failure();
+
+  // P2.
+  const int64_t feature_group_count = op.feature_group_count();
+  const int64_t batch_group_count = op.batch_group_count();
+
+  if (feature_group_count <= 0)
+    return op.emitOpError()
+           << "expects feature_group_count to be a positive number, got "
+           << feature_group_count << ".";
+
+  if (batch_group_count <= 0)
+    return op.emitOpError()
+           << "expects batch_group_count to be a positive number, got "
+           << batch_group_count << ".";
+
+  if (batch_group_count > 1 && feature_group_count > 1)
+    return op.emitOpError()
+           << "expects batch_group_count and feature_group_count not to be "
+              "both greater than 1. Got "
+           << batch_group_count << " and " << feature_group_count << " resp.";
+
+  auto lhs_type = op.lhs().getType().cast<RankedTensorType>();
+  const int64_t input_features =
+      lhs_type.getShape()[op.dimension_numbers().getInputFeatureDimension()];
+  const int64_t input_batch =
+      lhs_type.getShape()[op.dimension_numbers().getInputBatchDimension()];
+
+  auto rhs_type = op.rhs().getType().cast<RankedTensorType>();
+  const int64_t kernel_input_features =
+      rhs_type
+          .getShape()[op.dimension_numbers().getKernelInputFeatureDimension()];
+  const int64_t kernel_output_features =
+      rhs_type
+          .getShape()[op.dimension_numbers().getKernelOutputFeatureDimension()];
+
+  if (!isDynamicDimSize(kernel_output_features)) {
+    if (kernel_output_features % batch_group_count != 0)
+      return op.emitOpError() << "expects output feature dimension size ("
+                              << kernel_output_features
+                              << ") to be a multiple of "
+                                 "batch_group_count. Got batch_group_count = "
+                              << batch_group_count << ".";
+
+    if (kernel_output_features % feature_group_count != 0)
+      return op.emitOpError()
+             << "expects kernel output feature dimension ("
+             << kernel_output_features
+             << ") to be divisible by "
+                "feature_group_count. For feature_group_count = "
+             << feature_group_count << ".";
+  }
+
+  if (!isDynamicDimSize(input_features)) {
+    if (input_features % feature_group_count != 0)
+      return op.emitOpError()
+             << "expects input feature dimension (" << input_features
+             << ") to be a multiple of "
+                "feature_group_count. Got feature_group_count = "
+             << feature_group_count << ".";
+
+    if (!isDynamicDimSize(kernel_input_features) &&
+        input_features / feature_group_count != kernel_input_features)
+      return op.emitOpError()
+             << "expects input feature dimension (" << input_features
+             << ") / "
+                "feature_group_count = kernel input feature dimension ("
+             << kernel_input_features
+             << "). Got feature_group_count = " << feature_group_count << ".";
+  }
+
+  if (!isDynamicDimSize(input_batch) && input_batch % batch_group_count != 0)
+    return op.emitOpError() << "expects input batch dimension (" << input_batch
+                            << ") to be divisible by "
+                               "batch_group_count. Got batch_group_count = "
+                            << batch_group_count << ".";
+
+  return success();
+}
+
+// Infers the return-shape of ConvOp.
+// Precondition:
+//  1. Input args to ConvOp 'op' are RankedTypes.
+//  2. rank-of(input-type) == rank-of(output-type)
+SmallVector<int64_t> inferConvOpReturnShape(
+    ConvOp op, const ArrayRef<WindowDimension> window) {
+  // We keep the 'unknown' dimensions (cl/415132294) as-is in the
+  // output-shape. To do that we initialize the output dimensions with the
+  // shape of the return-type and update only the spatial + non-spatial
+  // dimensions. Precondition 2 ensures that size of output-shape == size of
+  // input-shape.
+  SmallVector<int64_t> output_dimensions =
+      to_vector(op.getResult().getType().cast<ShapedType>().getShape());
+
+  // Infer the output spatial dimensions.
+  auto lhs_type = op.lhs().getType().cast<RankedTensorType>();
+  auto input_spatial_dims = op.dimension_numbers().getInputSpatialDimensions();
+  auto num_spatial_dims = input_spatial_dims.size();
+  SmallVector<int64_t> input_spatial_dim_vals(num_spatial_dims);
+  for (int i = 0; i < num_spatial_dims; ++i)
+    input_spatial_dim_vals[i] = lhs_type.getShape()[input_spatial_dims[i]];
+
+  auto window_output_shape =
+      inferWindowOutputShape(input_spatial_dim_vals, window);
+
+  for (int i = 0; i < window.size(); ++i)
+    output_dimensions[op.dimension_numbers().getOutputSpatialDimensions()[i]] =
+        window_output_shape[i];
+
+  // Infer the output-batch-dimension and output-feature-dimension.
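+  // For instance, with a (hypothetical) input_batch of 100 and
+  // batch_group_count of 2, the output-batch-dimension set below is
+  // 100 / 2 = 50, while the output-feature-dimension is copied verbatim from
+  // the kernel's output-feature size.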
+  auto rhs_type = op.rhs().getType().cast<RankedTensorType>();
+  const int64_t input_batch =
+      lhs_type.getShape()[op.dimension_numbers().getInputBatchDimension()];
+  const int64_t kernel_output_features =
+      rhs_type
+          .getShape()[op.dimension_numbers().getKernelOutputFeatureDimension()];
+
+  output_dimensions[op.dimension_numbers().getOutputBatchDimension()] =
+      isDynamicDimSize(input_batch) ? ShapedType::kDynamicSize
+                                    : input_batch / op.batch_group_count();
+  output_dimensions[op.dimension_numbers().getOutputFeatureDimension()] =
+      kernel_output_features;
+
+  return output_dimensions;
+}
+}  // namespace
+
+/*
+ * We intend to verify the following properties:
+ * P1. Verify the input, kernel types.
+ * P2. Verify the convolution attributes.
+ * P3. Verify and collect the window attributes.
+ * P4. Verify the return shape.
+ * TODO(b/232574102): Verify the element-type of return-value.
+ */
+LogicalResult ConvOp::verify() {
+  auto lhs_type = lhs().getType().dyn_cast<RankedTensorType>();
+  auto rhs_type = rhs().getType().dyn_cast<RankedTensorType>();
+
+  if (!lhs_type || !rhs_type) return success();
+
+  // P1.
+  int num_dims = lhs_type.getRank();
+  if (num_dims != rhs_type.getRank())
+    return emitOpError()
+           << "expects convolution arguments to have same number of "
+              "dimensions. Got: "
+           << lhs_type << " and " << rhs_type << ".";
+
+  if (num_dims < 2)
+    return emitOpError()
+           << "expects convolution arguments to have >= 2 dimensions. "
+              "Got: "
+           << lhs_type << " and " << rhs_type << ".";
+
+  // P2.
+  if (failed(verifyConvolutionAttributes(*this))) return failure();
+
+  // P3.
+  auto kernel_spatial_dimensions =
+      dimension_numbers().getKernelSpatialDimensions();
+  SmallVector<int64_t> window_dimensions(kernel_spatial_dimensions.size());
+  for (size_t i = 0; i < window_dimensions.size(); i++)
+    window_dimensions[i] = rhs_type.getShape()[kernel_spatial_dimensions[i]];
+
+  auto padding_or_err = convertNx2Attribute(this->padding(), getLoc());
+  if (failed(padding_or_err)) return failure();
+  SmallVector<std::pair<int64_t, int64_t>> padding = *padding_or_err;
+
+  auto window_or_err = verifyWindowAttributesAndInferWindowDimensions(
+      window_dimensions, convertDenseIntAttr(window_strides()), padding,
+      convertDenseIntAttr(lhs_dilation()), convertDenseIntAttr(rhs_dilation()),
+      getLoc());
+  if (failed(window_or_err)) return failure();
+
+  // P4.
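+  // The return-type is only checked when it is ranked: the expected shape is
+  // inferred from the operands and the window (e.g. if the inferred shape is
+  // 1x8x8x16, a declared return-type of tensor<2x8x8x16xf32> is rejected),
+  // and the comparison via verifyCompatibleShape tolerates dynamic
+  // dimensions on either side.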
+  auto actual_return_type = getResult().getType().cast<TensorType>();
+  auto actual_return_element_type = actual_return_type.getElementType();
+  if (!actual_return_type.hasRank()) return success();
+
+  auto actual_return_ranked_type = actual_return_type.cast<RankedTensorType>();
+  if (num_dims != actual_return_ranked_type.getRank())
+    return emitOpError() << "expects rank of convolution return-type to be "
+                            "equal to input-ranks ("
+                         << num_dims << "), but got "
+                         << actual_return_ranked_type.getRank() << ".";
+
+  auto expected_return_shape = inferConvOpReturnShape(*this, *window_or_err);
+  auto expected_return_type =
+      RankedTensorType::get(expected_return_shape, actual_return_element_type);
+  if (failed(verifyCompatibleShape(expected_return_type,
+                                   actual_return_ranked_type)))
+    return emitOpError()
+           << "has shape mismatch between the expected return-type ("
+           << expected_return_type << ") and actual return-type ("
+           << actual_return_ranked_type << ").";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ConvertOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 0b25655d9617e3..c2182e9a594af8 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -2055,28 +2055,26 @@ struct DepthwiseConvOpConversion : public OpConversionPattern<mhlo::ConvOp> {
           op, "non-one lhs- dialation unsupported yet");
     }
 
-    if (const mhlo::ConvDimensionNumbersAttr& dimension_numbers =
-            op.dimension_numbers()) {
-      // Make sure that this is 2-D convolution.
-      const auto spatial_rank =
-          llvm::size(dimension_numbers.getInputSpatialDimensions());
-      if (spatial_rank != 2) {
-        return rewriter.notifyMatchFailure(op,
-                                           "only support 2-D cases for now");
-      }
+    const mhlo::ConvDimensionNumbersAttr& dimension_numbers =
+        op.dimension_numbers();
+    // Make sure that this is 2-D convolution.
+    const auto spatial_rank =
+        llvm::size(dimension_numbers.getInputSpatialDimensions());
+    if (spatial_rank != 2) {
+      return rewriter.notifyMatchFailure(op, "only support 2-D cases for now");
+    }
 
-      // Make sure that this is depthwise convolution.
-      int64_t input_feature_dim = dimension_numbers.getInputFeatureDimension();
-      int64_t input_feature_count =
-          op.lhs().getType().cast<RankedTensorType>().getDimSize(input_feature_dim);
-      if (op.feature_group_count() != input_feature_count) {
-        return rewriter.notifyMatchFailure(op, "not depth-wise convolution");
-      }
+    // Make sure that this is depthwise convolution.
+    int64_t input_feature_dim = dimension_numbers.getInputFeatureDimension();
+    int64_t input_feature_count =
+        op.lhs().getType().cast<RankedTensorType>().getDimSize(input_feature_dim);
+    if (op.feature_group_count() != input_feature_count) {
+      return rewriter.notifyMatchFailure(op, "not depth-wise convolution");
+    }
 
-      // Make sure that this convolution has a canonical form.
-      if (!HasCanonicalDimensionNumbers(dimension_numbers)) {
-        return rewriter.notifyMatchFailure(op, "does not have canonical form");
-      }
+    // Make sure that this convolution has a canonical form.
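+    // (A canonical form here presumably means NHWC-ordered input/output and
+    // HWIO-ordered kernel dimension numbers, which the linalg depthwise
+    // convolution ops emitted below expect.)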
+    if (!HasCanonicalDimensionNumbers(dimension_numbers)) {
+      return rewriter.notifyMatchFailure(op, "does not have canonical form");
     }
 
     DenseIntElementsAttr window_strides;
@@ -2127,10 +2125,38 @@ struct DepthwiseConvOpConversion : public OpConversionPattern<mhlo::ConvOp> {
       return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
     };
 
-    if (filter_dims[2] * filter_dims[3] != op.feature_group_count()) {
+    int64_t kernel_input_feature_dimension =
+        dimension_numbers.getKernelInputFeatureDimension();
+    int64_t kernel_output_feature_dimension =
+        dimension_numbers.getKernelOutputFeatureDimension();
+    if (filter_dims[kernel_input_feature_dimension] *
+            filter_dims[kernel_output_feature_dimension] !=
+        op.feature_group_count()) {
       // For cases where channel multiplier != 1
+
+      // Reshape the filter shape from
+      //   [filter_height, filter_width, 1, kernel-output-feature]
+      // to
+      //   [filter_height, filter_width, feature_group_count,
+      //    kernel-output-feature / feature_group_count].
+      SmallVector<int64_t> reshaped_filter_dims;
+      reshaped_filter_dims.assign(filter_dims.begin(), filter_dims.end());
+      auto reshaped_filter = filter;
+      if (filter_dims[kernel_input_feature_dimension] == 1) {
+        reshaped_filter_dims[kernel_input_feature_dimension] =
+            op.feature_group_count();
+        reshaped_filter_dims[kernel_output_feature_dimension] /=
+            op.feature_group_count();
+        auto reshaped_filter_type = RankedTensorType::get(
+            reshaped_filter_dims,
+            op.rhs().getType().cast<RankedTensorType>().getElementType());
+
+        reshaped_filter =
+            rewriter.create<mhlo::ReshapeOp>(loc, reshaped_filter_type, filter);
+      }
+
       auto output_dims = result_type.getShape();
-      auto channel_multiplier = filter_dims[3];
+      auto channel_multiplier = reshaped_filter_dims[3];
       SmallVector<int64_t> reshaped_output_dims;
       reshaped_output_dims.assign(output_dims.begin(), output_dims.end());
       reshaped_output_dims.push_back(channel_multiplier);
@@ -2143,7 +2169,7 @@ struct DepthwiseConvOpConversion : public OpConversionPattern<mhlo::ConvOp> {
       auto reshaped_output_type = RankedTensorType::get(
           reshaped_output_dims, result_type.getElementType());
       auto conv = rewriter.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
-          op.getLoc(), reshaped_output_type, ValueRange{input, filter},
+          loc, reshaped_output_type, ValueRange{input, reshaped_filter},
           ValueRange{zero_tensor}, window_strides, rhs_dilation,
           PruneAttributeList(op));
 
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/conv_op_verifier.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/conv_op_verifier.mlir
new file mode 100644
index 00000000000000..c4a2eaf1c9a9cd
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/conv_op_verifier.mlir
@@ -0,0 +1,809 @@
+// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | FileCheck %s
+
+// -----
+
+// Valid: Generic convolution
+
+func.func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) ->
+  tensor<100x28x28x1xf32> {
+  %result = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = #mhlo.conv,
+    feature_group_count = 1 : i64,
+    lhs_dilation = dense<1> : tensor<2xi64>,
+    padding = dense<2> : tensor<2x2xi64>,
+    rhs_dilation = dense<1> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) ->
+    tensor<100x28x28x1xf32>
+  func.return %result : tensor<100x28x28x1xf32>
+}
+
+// Valid: Test convolution i8xi8 -> i32.
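+// The verifier checks shapes only; the result element type may legally
+// differ from the operand element type (element-type verification is still
+// TODO, b/232574102), so this i8 x i8 -> i32 convolution is accepted.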
+
+func.func @convolution_upcast(%arg0 : tensor<100x26x26x32xi8>,
+    %arg1 : tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32> {
+  %result = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = #mhlo.conv,
+    feature_group_count = 1 : i64,
+    lhs_dilation = dense<1> : tensor<2xi64>,
+    padding = dense<2> : tensor<2x2xi64>,
+    rhs_dilation = dense<1> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<100x26x26x32xi8>, tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32>
+  func.return %result : tensor<100x28x28x1xi32>
+}
+
+// Valid: Empty spatial dimensions
+
+// CHECK: func @conv_empty_spatial_dimensions
+// CHECK: mhlo.convolution
+// CHECK-SAME: dim_numbers = [b, f]x[i, o]->[b, f]
+// CHECK-SAME: window = {stride = [], pad = [], lhs_dilate = [],
+// CHECK-SAME: rhs_dilate = [], reverse = []}
+func.func @conv_empty_spatial_dimensions(%arg0: tensor<3x2xf16>,
+    %arg1: tensor<2x2xf16>) -> tuple<tensor<3x2xf16>> {
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, f]x[i, o]->[b, f],
+         window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [],
+           reverse = []}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         }
+       : (tensor<3x2xf16>, tensor<2x2xf16>) -> tensor<3x2xf16>
+  %1 = "mhlo.tuple"(%0) : (tensor<3x2xf16>) -> tuple<tensor<3x2xf16>>
+  func.return %1 : tuple<tensor<3x2xf16>>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<2x4x5x2xf32>,
+    %arg1: tensor<2x2x1x6xf32>) -> tensor<2x3x4x6xf32> {
+  // expected-error@+1 {{expects input dimension-numbers to be unique, got {0, 0}.}}
+  %1 = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = #mhlo.conv,
+    feature_group_count = 2 : i64,
+    someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x1x6xf32>) ->
+    tensor<2x3x4x6xf32>
+  func.return %1 : tensor<2x3x4x6xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects convolution arguments to have same number of dimensions. Got: 'tensor<1x8x8x207xf32>' and 'tensor<3x3x207xf32>'.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1xf32>, %arg1: tensor<3xf32>)
+    -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects convolution arguments to have >= 2 dimensions. Got: 'tensor<1xf32>' and 'tensor<3xf32>'.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1xf32>, tensor<3xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects the same size for input, kernel and output spatial-dimensions, but got 3, 2, and 2 resp.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, 2, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects the same size for input, kernel and output spatial-dimensions, but got 2, 3, and 2 resp.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, 2, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects the same size for input, kernel and output spatial-dimensions, but got 2, 2, and 3 resp.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, 2, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>,
+    %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> {
+  // expected-error@+1 {{expects input, kernel, and output dimension-numbers to be in-range [0, 4).}}
+  %result = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = #mhlo.conv,
+    feature_group_count = 1 : i64,
+    lhs_dilation = dense<1> : tensor<2xi64>,
+    padding = dense<2> : tensor<2x2xi64>,
+    rhs_dilation = dense<1> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) ->
+    tensor<100x28x28x1xf32>
+  func.return %result : tensor<100x28x28x1xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>,
+    %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> {
+  // expected-error@+1 {{expects kernel dimension-numbers to be unique, got {0, 2, 3, 3}.}}
+  %result = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = #mhlo.conv,
+    feature_group_count = 1 : i64,
+    lhs_dilation = dense<1> : tensor<2xi64>,
+    padding = dense<2> : tensor<2x2xi64>,
+    rhs_dilation = dense<1> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) ->
+    tensor<100x28x28x1xf32>
+  func.return %result : tensor<100x28x28x1xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>,
+    %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> {
+  // expected-error@+1 {{expects output dimension-numbers to be unique, got {0, 3, 3, 3}.}}
+  %result = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = #mhlo.conv,
+    feature_group_count = 1 : i64,
+    lhs_dilation = dense<1> : tensor<2xi64>,
+    padding = dense<2> : tensor<2x2xi64>,
+    rhs_dilation = dense<1> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) ->
+    tensor<100x28x28x1xf32>
+  func.return %result : tensor<100x28x28x1xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{op expects batch_group_count to be a positive number, got 0.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 0 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{op expects feature_group_count to be a positive number, got 0.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 0 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects batch_group_count and feature_group_count not to be both greater than 1. Got 2 and 2 resp.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 2 : i64,
+           feature_group_count = 2 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects output feature dimension size (16) to be a multiple of batch_group_count. Got batch_group_count = 3.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 3 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects input feature dimension (207) to be a multiple of feature_group_count. Got feature_group_count = 2.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 2 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects input feature dimension (207) / feature_group_count = kernel input feature dimension (20). Got feature_group_count = 1.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x69x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects kernel output feature dimension (16) to be divisible by feature_group_count. For feature_group_count = 3.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 3 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x69x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<5x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects input batch dimension (5) to be divisible by batch_group_count. Got batch_group_count = 2.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 2 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<5x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects window-strides to have same dimension-size as size of window dimensions (2), but got: 1.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects base-dilation factors to have same dimension-size as size of window dimensions (2), but got: 1.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects window-dilation factors to have same dimension-size as size of window dimensions (2), but got: 1.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects padding-entries to have same dimension-size as size of window dimensions (2), but got: 1.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         // expected-error@+1 {{Expected array with 2 elements, got 4 elements instead}}
+         window = {stride = [1, 1], pad = [[1, 1, 1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) ->
+  tensor<100x28x28x1xf32> {
+  // expected-error@+1 {{expects padding-entries to have same dimension-size as size of window dimensions (2), but got: 3.}}
+  %result = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = #mhlo.conv,
+    feature_group_count = 1 : i64,
+    lhs_dilation = dense<1> : tensor<2xi64>,
+    padding = dense<2> : tensor<6xi64>,
+    rhs_dilation = dense<1> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) ->
+    tensor<100x28x28x1xf32>
+  func.return %result : tensor<100x28x28x1xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) ->
+  tensor<100x28x28x1xf32> {
+  // expected-error@+1 {{expects the padding-entries to have even number of elements, but got 5 elements.}}
+  %result = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = #mhlo.conv,
+    feature_group_count = 1 : i64,
+    lhs_dilation = dense<1> : tensor<2xi64>,
+    padding = dense<2> : tensor<5xi64>,
+    rhs_dilation = dense<1> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) ->
+    tensor<100x28x28x1xf32>
+  func.return %result : tensor<100x28x28x1xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<0x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects window to have positive value for 0-th window dimension, but got 0.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} :
+       (tensor<1x8x8x207xf32>, tensor<0x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects window to have positive stride for 1-th window dimension, but got 0.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 0], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects window to have positive base dilation factor for 0-th window dimension, but got 0.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [0, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects window to have positive window dilation factor for 0-th window dimension, but got 0.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [0, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+// Invalid rank of output-type.
+
+func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x16xf32> {
+  // expected-error @+1 {{expects rank of convolution return-type to be equal to input-ranks (4), but got 3.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x16xf32>
+  func.return %0 : tensor<1x8x16xf32>
+}
+
+// -----
+
+// Invalid batch dimension in output-type. Should be equal to
+// input-batch-dimension / batch_group_count.
+
+func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<2x8x8x16xf32> {
+  // expected-error@+1 {{'mhlo.convolution' op has shape mismatch between the expected return-type ('tensor<1x8x8x16xf32>') and actual return-type ('tensor<2x8x8x16xf32>').}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<2x8x8x16xf32>
+  func.return %0 : tensor<2x8x8x16xf32>
+}
+
+// -----
+
+// Invalid feature dimension in output-type. Should be equal to
+// kernel_output_feature_dimension.
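+// Here the kernel's output-feature size is 16, so the inferred return-type
+// is tensor<1x8x8x16xf32>, not the declared tensor<1x8x8x32xf32>.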
+
+func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x32xf32> {
+  // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x8x16xf32>') and actual return-type ('tensor<1x8x8x32xf32>').}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x32xf32>
+  func.return %0 : tensor<1x8x8x32xf32>
+}
+
+// -----
+
+// The following tests check the inferred output-type of ConvOp. We
+// deliberately put an invalid output-type in these tests so that the
+// inferred type can be highlighted in the error message.
+
+// Dynamic input-batch-dimension
+func.func @invalid_conv_dynamic_shapes(%arg0: tensor<?x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> {
+  // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<?x8x8x16xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<?x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32>
+  func.return %0 : tensor<1x1x1x1xf32>
+}
+
+// -----
+
+// Dynamic input-feature-dimension: No effect on output dimensions.
+func.func @invalid_conv_dynamic_shapes(%arg0: tensor<1x8x8x?xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> {
+  // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x8x16xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x?xf32>, tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32>
+  func.return %0 : tensor<1x1x1x1xf32>
+}
+
+// -----
+
+// Dynamic input-spatial-dimension
+func.func @invalid_conv_dynamic_shapes(%arg0: tensor<1x?x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> {
+  // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x?x8x16xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x?x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32>
+  func.return %0 : tensor<1x1x1x1xf32>
+}
+
+// -----
+
+// Dynamic kernel-input-feature-dimension: No effect on output dimensions.
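+// Only the kernel's output-feature size feeds the inferred output shape, so
+// the expected return-type stays tensor<1x8x8x16xf32>.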
+func.func @invalid_conv_dynamic_shapes(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x?x16xf32>) -> tensor<1x1x1x1xf32> {
+  // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x8x16xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x?x16xf32>) -> tensor<1x1x1x1xf32>
+  func.return %0 : tensor<1x1x1x1xf32>
+}
+
+// -----
+
+// Dynamic kernel-output-feature-dimension
+func.func @check_inferred_type_with_dynamic_input_dims(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x?xf32>) -> tensor<1x1x1x1xf32> {
+  // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x8x?xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x?xf32>) -> tensor<1x1x1x1xf32>
+  func.return %0 : tensor<1x1x1x1xf32>
+}
+
+// -----
+
+// Dynamic kernel-spatial-dimension
+func.func @check_inferred_type_with_dynamic_input_dims(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x?x207x16xf32>) -> tensor<1x1x1x1xf32> {
+  // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x?x16xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x?x207x16xf32>) -> tensor<1x1x1x1xf32>
+  func.return %0 : tensor<1x1x1x1xf32>
+}
+
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-lhlo.mlir
index 08bb2f364e6bab..db142cd8e40452 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-lhlo.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-lhlo.mlir
@@ -473,10 +473,10 @@ func.func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
 // -----
 
 // CHECK-LABEL: func @conv
-func.func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
-  -> tensor<3x5x5x4xf32> {
+func.func @conv(%input: tensor<3x2x4x3xf32>, %filter : tensor<2x2x3x4xf32>)
+  -> tensor<2x1x2x3xf32> {
   %c0 = arith.constant 0 : index
-  // CHECK: %[[OUT:.*]] = memref.alloc() : memref<3x5x5x4xf32>
+  // CHECK: %[[OUT:.*]] = memref.alloc() : memref<2x1x2x3xf32>
   // CHECK: lmhlo.convolution(%{{.+}}, %{{.+}}, %[[OUT]])
   // CHECK-SAME{LITERAL}: window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]}
   %out = "mhlo.convolution"(%filter, %input) {
@@ -496,8 +496,8 @@ func.func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
     padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
     rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
     window_strides = dense<[2, 1]> : tensor<2xi64>
-  } : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32>
-  func.return %out : tensor<3x5x5x4xf32>
+  } : (tensor<2x2x3x4xf32>, tensor<3x2x4x3xf32>) -> tensor<2x1x2x3xf32>
+  func.return %out : tensor<2x1x2x3xf32>
 }
 
 // -----
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
index b150f5fb76a1b6..e7e353c8ea279e 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
@@ -2863,7 +2863,7 @@ func.func @linalg.conv_1d_nwc(%arg0: tensor, %arg1: tensor<2x?x?xf32>)
         output_spatial_dimensions = [1]
       >,
       feature_group_count = 1 : i64,
-      padding = dense<[[0], [0]]> : tensor<2x1xi64>,
+      padding = dense<[[0, 0]]> : tensor<1x2xi64>,
       rhs_dilation = dense<1> : tensor<1xi64>,
       window_strides = dense<1> : tensor<1xi64>,
       someattr
@@ -2890,7 +2890,7 @@ func.func @linalg.conv_1d_nwc(%arg0: tensor, %arg1: tensor<2x?x?xf32>)
 // -----
 
 func.func @conv_2d_nhwc_hwcf(%arg0: tensor<?x4x5x?xf32>, %arg1: tensor<3x2x?x?xf32>)
-  -> tensor<?x2x3x?xf32> {
+  -> tensor<?x2x4x?xf32> {
   %0 = "mhlo.convolution"(%arg0, %arg1) {
     batch_group_count = 1 : i64,
     dimension_numbers = #mhlo.conv,
     feature_group_count = 1 : i64,
     padding = dense<[[0, 0], [0, 0]]> : tensor<2x2xi64>,
     rhs_dilation = dense<1> : tensor<2xi64>,
     window_strides = dense<1> : tensor<2xi64>
-  } : (tensor<?x4x5x?xf32>, tensor<3x2x?x?xf32>) -> tensor<?x2x3x?xf32>
-  func.return %0 : tensor<?x2x3x?xf32>
+  } : (tensor<?x4x5x?xf32>, tensor<3x2x?x?xf32>) -> tensor<?x2x4x?xf32>
+  func.return %0 : tensor<?x2x4x?xf32>
 }
 // CHECK-LABEL: func @conv_2d_nhwc_hwcf
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
@@ -2918,14 +2918,14 @@ func.func @conv_2d_nhwc_hwcf(%arg0: tensor<?x4x5x?xf32>, %arg1: tensor<3x2x?x?xf32>)
 // CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x4x5x?xf32>
 // CHECK: %[[C3:.+]] = arith.constant 3 : index
 // CHECK: %[[DIM3:.+]] = tensor.dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32>
-// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 2, 3, %[[DIM3]]]
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 2, 4, %[[DIM3]]]
 // CHECK: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]]
 // CHECK: linalg.conv_2d_nhwc
 // CHECK-SAME: {dilations = dense<1> : tensor<2xi64>
 // CHECK-SAME: strides = dense<1> : tensor<2xi64>}
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x4x5x?xf32>, tensor<3x2x?x?xf32>)
-// CHECK-SAME: outs(%[[FILL]] : tensor<?x2x3x?xf32>) -> tensor<?x2x3x?xf32>
+// CHECK-SAME: outs(%[[FILL]] : tensor<?x2x4x?xf32>) -> tensor<?x2x4x?xf32>
 
 // -----
 
@@ -2945,7 +2945,7 @@ func.func @conv_3d_ndhwc_dhwcf(%arg0: tensor, %arg1: tensor<2x2x2x?x?xf32>)
        output_spatial_dimensions = [1, 2, 3]
      >,
      feature_group_count = 1 : i64,
-     padding = dense<[[0, 0, 0], [0, 0, 0]]> : tensor<2x3xi64>,
+     padding = dense<[[0, 0], [0, 0], [0, 0]]> : tensor<3x2xi64>,
      rhs_dilation = dense<1> : tensor<3xi64>,
      window_strides = dense<1> : tensor<3xi64>
    } : (tensor, tensor<2x2x2x?x?xf32>) -> tensor
@@ -3028,25 +3028,25 @@ func.func @linalg.conv_2D_padding_test1(%arg0: tensor<1x33x1x1xf16>, %arg1: tensor<400x1024x1024x1xf16>)
 // CHECK-LABEL: func @linalg.conv_2D_padding_test2
 // CHECK-SAME: (%[[FILTER:.*]]: tensor<1x33x1x1xf16>, %[[INPUT:.*]]: tensor<400x1024x1024x1xf16>)
 func.func @linalg.conv_2D_padding_test2(%arg0: tensor<1x33x1x1xf16>, %arg1: tensor<400x1024x1024x1xf16>)
-  -> tensor<400x1024x1024x1xf16> {
-  %0 = mhlo.convolution(%arg1, %arg0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[8, 8], [16, 16]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<400x1024x1024x1xf16>, tensor<1x33x1x1xf16>) -> (tensor<400x1024x1024x1xf16>)
-  func.return %0 : tensor<400x1024x1024x1xf16>
+  -> tensor<400x1040x1024x1xf16> {
+  %0 = mhlo.convolution(%arg1, %arg0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[8, 8], [16, 16]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<400x1024x1024x1xf16>, tensor<1x33x1x1xf16>) -> (tensor<400x1040x1024x1xf16>)
+  return %0 : tensor<400x1040x1024x1xf16>
 }
-// CHECK-NEXT: %[[INIT:.*]] = linalg.init_tensor [400, 1024, 1024, 1] : tensor<400x1024x1024x1xf16>
+// CHECK-NEXT: %[[INIT:.*]] = linalg.init_tensor [400, 1040, 1024, 1] : tensor<400x1040x1024x1xf16>
 // CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f16
-// CHECK-NEXT: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]] : f16) outs(%[[INIT]] : tensor<400x1024x1024x1xf16>) -> tensor<400x1024x1024x1xf16>
+// CHECK-NEXT: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]] : f16) outs(%[[INIT]] : tensor<400x1040x1024x1xf16>) -> tensor<400x1040x1024x1xf16>
 // CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f16
 // CHECK-NEXT: %[[PAD:.*]] = tensor.pad %[[INPUT]] low[0, 8, 16, 0] high[0, 8, 16, 0] {
 // CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
 // CHECK-NEXT:   tensor.yield %[[ZERO]] : f16
 // CHECK-NEXT: } : tensor<400x1024x1024x1xf16> to tensor<400x1040x1056x1xf16>
-// CHECK-NEXT: %[[RESULT:.*]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%2, %arg0 : tensor<400x1040x1056x1xf16>, tensor<1x33x1x1xf16>) outs(%1 : tensor<400x1024x1024x1xf16>) -> tensor<400x1024x1024x1xf16>
-// CHECK-NEXT: return %[[RESULT]] : tensor<400x1024x1024x1xf16>
+// CHECK-NEXT: %[[RESULT:.*]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%2, %arg0 : tensor<400x1040x1056x1xf16>, tensor<1x33x1x1xf16>) outs(%1 : tensor<400x1040x1024x1xf16>) -> tensor<400x1040x1024x1xf16>
+// CHECK-NEXT: return %[[RESULT]] : tensor<400x1040x1024x1xf16>
 
 // -----
 
 func.func @depthwise_conv(%arg0: tensor<2x4x5x2xf32>,
-                          %arg1: tensor<2x2x2x3xf32>) -> tensor<2x3x4x6xf32> {
+                          %arg1: tensor<2x2x1x6xf32>) -> tensor<2x3x4x6xf32> {
   %0 = "mhlo.convolution"(%arg0, %arg1) {
     batch_group_count = 1 : i64,
     dimension_numbers = #mhlo.conv,
     feature_group_count = 2 : i64,
     padding = dense<0> : tensor<2x2xi64>,
     rhs_dilation = dense<1> : tensor<2xi64>,
     window_strides = dense<1> : tensor<2xi64>,
-    someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<2x3x4x6xf32>
+    someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x1x6xf32>) -> tensor<2x3x4x6xf32>
   func.return %0 : tensor<2x3x4x6xf32>
 }
 // CHECK:      func @depthwise_conv
 // CHECK-SAME:   %[[IN:[a-zA-Z0-9_]*]]
 // CHECK-SAME:   %[[FILTER:[a-zA-Z0-9_]*]]
-// CHECK:      %[[INIT:.+]] = linalg.init_tensor [2, 3, 4, 2, 3] : tensor<2x3x4x2x3xf32>
-// CHECK:      %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK:      %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32>
-// CHECK:      %[[OUT:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm
+
+// CHECK:      %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2, 3]] : tensor<2x2x1x6xf32> into tensor<24xf32>
+// CHECK:      %[[CAST:.+]] = tensor.cast %[[COLLAPSE]] : tensor<24xf32> to tensor<24xf32>
+// CHECK:      %[[EXPAND:.+]] = tensor.expand_shape %[[CAST]] {{\[}}[0, 1, 2, 3]] : tensor<24xf32> into tensor<2x2x2x3xf32>
+// CHECK:      %[[INIT:.+]] = linalg.init_tensor [2, 3, 4, 2, 3] : tensor<2x3x4x2x3xf32>
+// CHECK:      %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:      %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32>
+// CHECK:      %[[OUT:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm
 // CHECK-SAME:   {dilations = dense<1> : tensor<2xi64>, someattr, strides = dense<1> : tensor<2xi64>}
-// CHECK-SAME:   ins(%[[IN]], %[[FILTER]] : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
+// CHECK-SAME:   ins(%[[IN]], %[[EXPAND]] : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
 // CHECK-SAME:   outs(%[[FILL]] : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32>
-// CHECK:      %{{.+}} = tensor.collapse_shape %[[OUT]]
+// CHECK:      %{{.+}} = tensor.collapse_shape %[[OUT]]
 // CHECK-SAME:   [0], [1], [2], [3, 4]
 // CHECK-SAME:   : tensor<2x3x4x2x3xf32> into tensor<2x3x4x6xf32>
 
@@ -3085,7 +3089,7 @@ func.func @depthwise_conv(%arg0: tensor<2x4x5x2xf32>,
 
 func.func @depthwise_conv_with_padding(
     %arg0: tensor<2x4x5x2xf32>,
-    %arg1: tensor<2x2x2x3xf32>) -> tensor<2x3x6x6xf32> {
+    %arg1: tensor<2x2x1x4xf32>) -> tensor<2x3x6x4xf32> {
   %0 = "mhlo.convolution"(%arg0, %arg1) {
     batch_group_count = 1 : i64,
     dimension_numbers = #mhlo.conv,
     feature_group_count = 2 : i64,
    padding = dense<[[0, 0], [1, 1]]> : tensor<2x2xi64>,
     rhs_dilation = dense<1> : tensor<2xi64>,
     window_strides = dense<1> : tensor<2xi64>,
-    someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<2x3x6x6xf32>
-  func.return %0 : tensor<2x3x6x6xf32>
+    someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x1x4xf32>) -> tensor<2x3x6x4xf32>
+  func.return %0 : tensor<2x3x6x4xf32>
 }
 // CHECK:      func @depthwise_conv_with_padding
 // CHECK-SAME:   %[[IN:[a-zA-Z0-9_]*]]
 // CHECK-SAME:   %[[FILTER:[a-zA-Z0-9_]*]]
@@ -3113,17 +3117,24 @@ func.func @depthwise_conv_with_padding(
 // CHECK:      %[[PAD:.*]] = tensor.pad %[[IN]] low[0, 0, 1, 0] high[0, 0, 1, 0] {
 // CHECK:      ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
 // CHECK:        tensor.yield %[[ZERO]] : f32
-// CHECK       } : tensor<2x4x5x2xf32> to tensor<2x4x7x1xf32>
-// CHECK:      %[[INIT:.+]] = linalg.init_tensor [2, 3, 6, 2, 3] : tensor<2x3x6x2x3xf32>
+// CHECK       } : tensor<2x4x5x2xf32> to tensor<2x4x7x2xf32>
+// CHECK:      %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FILTER]]
+// CHECK-SAME:   [0, 1, 2, 3]
+// CHECK-SAME:   : tensor<2x2x1x4xf32> into tensor<16xf32>
+// CHECK:      %[[CAST:.+]] = tensor.cast %[[COLLAPSE]] : tensor<16xf32> to tensor<16xf32>
+// CHECK:      %[[EXPAND:.+]] = tensor.expand_shape %[[CAST]]
+// CHECK-SAME:   [0, 1, 2, 3]
+// CHECK-SAME:   tensor<16xf32> into tensor<2x2x2x2xf32>
+// CHECK:      %[[INIT:.+]] = linalg.init_tensor [2, 3, 6, 2, 2] : tensor<2x3x6x2x2xf32>
 // CHECK:      %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK:      %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<2x3x6x2x3xf32>) -> tensor<2x3x6x2x3xf32>
+// CHECK:      %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<2x3x6x2x2xf32>) -> tensor<2x3x6x2x2xf32>
 // CHECK:      %[[OUT:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm
 // CHECK-SAME:   {dilations = dense<1> : tensor<2xi64>, someattr, strides = dense<1> : tensor<2xi64>}
-// CHECK-SAME:   ins(%[[PAD]], %[[FILTER]] : tensor<2x4x7x2xf32>, tensor<2x2x2x3xf32>)
-// CHECK-SAME:   outs(%[[FILL]] : tensor<2x3x6x2x3xf32>) -> tensor<2x3x6x2x3xf32>
+// CHECK-SAME:   ins(%[[PAD]], %[[EXPAND]] : tensor<2x4x7x2xf32>, tensor<2x2x2x2xf32>)
+// CHECK-SAME:   outs(%[[FILL]] : tensor<2x3x6x2x2xf32>) -> tensor<2x3x6x2x2xf32>
 // CHECK:      %{{.+}} = tensor.collapse_shape %[[OUT]]
 // CHECK-SAME:   [0], [1], [2], [3, 4]
-// CHECK-SAME:   : tensor<2x3x6x2x3xf32> into tensor<2x3x6x6xf32>
@@ -3085,7 +3089,7 @@ func.func @depthwise_conv(%arg0: tensor<2x4x5x2xf32>,

func.func @depthwise_conv_with_padding(
    %arg0: tensor<2x4x5x2xf32>,
-    %arg1: tensor<2x2x2x3xf32>) -> tensor<2x3x6x6xf32> {
+    %arg1: tensor<2x2x1x4xf32>) -> tensor<2x3x6x4xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<raw
      input_batch_dimension = 0,
      input_feature_dimension = 3,
      input_spatial_dimensions = [1, 2],
      kernel_input_feature_dimension = 2,
      kernel_output_feature_dimension = 3,
      kernel_spatial_dimensions = [0, 1],
      output_batch_dimension = 0,
      output_feature_dimension = 3,
      output_spatial_dimensions = [1, 2]
    >,
    feature_group_count = 2 : i64,
    padding = dense<[[0, 0], [1, 1]]> : tensor<2x2xi64>,
    rhs_dilation = dense<1> : tensor<2xi64>,
    window_strides = dense<1> : tensor<2xi64>,
-    someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<2x3x6x6xf32>
-  func.return %0 : tensor<2x3x6x6xf32>
+    someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x1x4xf32>) -> tensor<2x3x6x4xf32>
+  func.return %0 : tensor<2x3x6x4xf32>
}
// CHECK: func @depthwise_conv_with_padding
// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]
@@ -3113,17 +3117,24 @@ func.func @depthwise_conv_with_padding(
// CHECK: %[[PAD:.*]] = tensor.pad %[[IN]] low[0, 0, 1, 0] high[0, 0, 1, 0] {
// CHECK: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
// CHECK: tensor.yield %[[ZERO]] : f32
-// CHECK } : tensor<2x4x5x2xf32> to tensor<2x4x7x1xf32>
-// CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 6, 2, 3] : tensor<2x3x6x2x3xf32>
+// CHECK } : tensor<2x4x5x2xf32> to tensor<2x4x7x2xf32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FILTER]]
+// CHECK-SAME: [0, 1, 2, 3]
+// CHECK-SAME: : tensor<2x2x1x4xf32> into tensor<16xf32>
+// CHECK: %[[CAST:.+]] = tensor.cast %[[COLLAPSE]] : tensor<16xf32> to tensor<16xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[CAST]]
+// CHECK-SAME: [0, 1, 2, 3]
+// CHECK-SAME: tensor<16xf32> into tensor<2x2x2x2xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 6, 2, 2] : tensor<2x3x6x2x2xf32>
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<2x3x6x2x3xf32>) -> tensor<2x3x6x2x3xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<2x3x6x2x2xf32>) -> tensor<2x3x6x2x2xf32>
// CHECK: %[[OUT:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm
// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, someattr, strides = dense<1> : tensor<2xi64>}
-// CHECK-SAME: ins(%[[PAD]], %[[FILTER]] : tensor<2x4x7x2xf32>, tensor<2x2x2x3xf32>)
-// CHECK-SAME: outs(%[[FILL]] : tensor<2x3x6x2x3xf32>) -> tensor<2x3x6x2x3xf32>
+// CHECK-SAME: ins(%[[PAD]], %[[EXPAND]] : tensor<2x4x7x2xf32>, tensor<2x2x2x2xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<2x3x6x2x2xf32>) -> tensor<2x3x6x2x2xf32>
// CHECK: %{{.+}} = tensor.collapse_shape %[[OUT]]
// CHECK-SAME: [0], [1], [2], [3, 4]
-// CHECK-SAME: : tensor<2x3x6x2x3xf32> into tensor<2x3x6x6xf32>
+// CHECK-SAME: : tensor<2x3x6x2x2xf32> into tensor<2x3x6x4xf32>

// -----

@@ -3166,7 +3177,7 @@ func.func @depthwise_conv_multiplier_1(%arg0: tensor<1x113x113x96xf32>,

func.func @depthwise_conv_multiplier_1_with_padding(
    %arg0: tensor<1x113x113x96xf32>,
-    %arg1: tensor<3x3x1x96xf32>) -> tensor<1x56x56x96xf32> {
+    %arg1: tensor<3x3x1x96xf32>) -> tensor<1x57x58x96xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<raw
      input_batch_dimension = 0,
      input_feature_dimension = 3,
      input_spatial_dimensions = [1, 2],
      kernel_input_feature_dimension = 2,
      kernel_output_feature_dimension = 3,
      kernel_spatial_dimensions = [0, 1],
      output_batch_dimension = 0,
      output_feature_dimension = 3,
      output_spatial_dimensions = [1, 2]
    >,
    feature_group_count = 96 : i64,
    padding = dense<[[1, 1], [2, 2]]> : tensor<2x2xi64>,
    rhs_dilation = dense<1> : tensor<2xi64>,
-    window_strides = dense<2> : tensor<2xi64>} : (tensor<1x113x113x96xf32>, tensor<3x3x1x96xf32>) -> tensor<1x56x56x96xf32>
-  func.return %0 : tensor<1x56x56x96xf32>
+    window_strides = dense<2> : tensor<2xi64>} : (tensor<1x113x113x96xf32>, tensor<3x3x1x96xf32>) -> tensor<1x57x58x96xf32>
+  func.return %0 : tensor<1x57x58x96xf32>
}
// CHECK: func @depthwise_conv_multiplier_1_with_padding
// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]
@@ -3194,16 +3205,16 @@ func.func @depthwise_conv_multiplier_1_with_padding(
// CHECK: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
// CHECK: tensor.yield %[[ZERO]] : f32
// CHECK } : tensor<1x113x113x96xf32> to tensor<1x115x117x96xf32>
-// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 56, 56, 96] : tensor<1x56x56x96xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 57, 58, 96] : tensor<1x57x58x96xf32>
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<1x57x58x96xf32>) -> tensor<1x57x58x96xf32>
// CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
// CHECK-SAME: [0], [1], [2, 3]
// CHECK-SAME: : tensor<3x3x1x96xf32> into tensor<3x3x96xf32>
// CHECK: %{{.+}} = linalg.depthwise_conv_2d_nhwc_hwc
// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
// CHECK-SAME: ins(%[[PAD]], %[[RESHAPED_FILTER]] : tensor<1x115x117x96xf32>, tensor<3x3x96xf32>)
-// CHECK-SAME: outs(%[[FILL]] : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
+// CHECK-SAME: outs(%[[FILL]] : tensor<1x57x58x96xf32>) -> tensor<1x57x58x96xf32>

// -----

diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir
index 1cf1109f46604c..5a1136163d638e 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir
@@ -2541,13 +2541,13 @@ module attributes { mhlo.conv = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
// CHECK-SAME{LITERAL}: window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]}
-func.func @convolution(%arg0: tensor<2x2x3x4xf32>, %arg1: tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> {
+func.func @convolution(%arg0: tensor<2x2x3x4xf32>, %arg1: tensor<3x2x4x3xf32>) -> tensor<2x1x1x3xf32> {
  %0 = mhlo.convolution(%arg0, %arg1)
         dim_numbers = [b, 1, 0, f]x[0, 1, i, o]->[b, 0, 1, f],
         window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]}
         { batch_group_count = 1 : i64, feature_group_count = 1 : i64}
-       : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32>
-  func.return %0 : tensor<3x5x5x4xf32>
+       : (tensor<2x2x3x4xf32>, tensor<3x2x4x3xf32>) -> tensor<2x1x1x3xf32>
+  func.return %0 : tensor<2x1x1x3xf32>
}

// -----
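For the updated @convolution test, the new operand and result types can be cross-checked with the same window formula, applied per spatial dimension of dim_numbers [b, 1, 0, f]x[0, 1, i, o]->[b, 0, 1, f]. A standalone C++ sketch (illustrative only, not from the patch):

    #include <cassert>
    #include <cstdint>

    int64_t OutDim(int64_t in, int64_t pad_lo, int64_t pad_hi, int64_t kernel,
                   int64_t rhs_dilation, int64_t stride) {
      return (in + pad_lo + pad_hi - ((kernel - 1) * rhs_dilation + 1)) / stride + 1;
    }

    int main() {
      // Spatial dim "0": input extent 3, kernel 3, pad [0, 1], stride 2.
      assert(OutDim(3, 0, 1, 3, 1, 2) == 1);
      // Spatial dim "1": input extent 2, kernel 2, pad [0, 1], rhs_dilate 2, stride 1.
      assert(OutDim(2, 0, 1, 2, 2, 1) == 1);
      // Batch carries through (2) and output features equal kernel o (3),
      // giving the corrected result type tensor<2x1x1x3xf32>.
    }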
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
index 4e3b81c2f984be..3f3a9496f28b8d 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
@@ -1707,13 +1707,13 @@ func.func @convert_conv2d_negative_explicit_padding(%arg0: tensor<128x7x9x64xf32
// CHECK-LABEL: func @convert_depthwise_conv2d(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x8x207xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3x1x3312xf32>) -> tensor<1x8x8x16xf32> {
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3x1x3312xf32>) -> tensor<1x8x8x3312xf32> {
// CHECK: %[[CST:.*]] = arith.constant dense<[3, 3, 207, 16]> : tensor<4xi64>
// CHECK: %[[VAL_2:.*]] = "tf.Reshape"(%[[VAL_1]], %[[CST]]) : (tensor<3x3x1x3312xf32>, tensor<4xi64>) -> tensor<3x3x207x16xf32>
-// CHECK: %[[VAL_3:.*]] = "tf.DepthwiseConv2dNative"(%[[VAL_0]], %[[VAL_2]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
-// CHECK: return %[[VAL_3]] : tensor<1x8x8x16xf32>
+// CHECK: %[[VAL_3:.*]] = "tf.DepthwiseConv2dNative"(%[[VAL_0]], %[[VAL_2]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x3312xf32>
+// CHECK: return %[[VAL_3]] : tensor<1x8x8x3312xf32>
// CHECK: }
-func.func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x1x3312xf32>) -> tensor<1x8x8x16xf32> {
+func.func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x1x3312xf32>) -> tensor<1x8x8x3312xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<raw
      input_batch_dimension = 0,
      input_feature_dimension = 3,
      input_spatial_dimensions = [1, 2],
      kernel_input_feature_dimension = 2,
      kernel_output_feature_dimension = 3,
      kernel_spatial_dimensions = [0, 1],
      output_batch_dimension = 0,
      output_feature_dimension = 3,
      output_spatial_dimensions = [1, 2]
    >, feature_group_count = 207 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
-      (tensor<1x8x8x207xf32>, tensor<3x3x1x3312xf32>) -> tensor<1x8x8x16xf32>
-  func.return %0 : tensor<1x8x8x16xf32>
+      (tensor<1x8x8x207xf32>, tensor<3x3x1x3312xf32>) -> tensor<1x8x8x3312xf32>
+  func.return %0 : tensor<1x8x8x3312xf32>
}
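The corrected result type follows from the depthwise relationship already visible in the CHECK lines: the 3x3x1x3312 HLO filter is reshaped to HWIM 3x3x207x16, and tf.DepthwiseConv2dNative produces I * M = 207 * 16 = 3312 output channels, not 16. A standalone C++ check of that arithmetic (not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      const int64_t in_channels = 207;      // feature_group_count
      const int64_t kernel_o = 3312;        // HLO kernel output features
      assert(kernel_o % in_channels == 0);  // verifier: o % fgc == 0
      const int64_t multiplier = kernel_o / in_channels;   // 16
      assert(in_channels * multiplier == 3312);            // TF output channels
    }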
// CHECK-LABEL: func @convert_conv2d_to_resize(

diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir
index 915e5fb1502974..9159aa8c96e56a 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir
@@ -4346,7 +4346,7 @@ func.func @conv3d_backprop_input(%filter: tensor<3x3x3x1x6xf32>, %out_backprop:
func.func @conv2d_backprop_filter(
    %input: tensor<100x28x28x1xf32>,
    %out_backprop: tensor<100x26x26x32xf32>
-  ) -> tensor<100x28x28x1xf32> {
+  ) -> tensor<3x3x1x32xf32> {
  // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1)
  // CHECK-SAME: dim_numbers = [f, 0, 1, b]x[i, 0, 1, o]->[0, 1, b, f]
  // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
@@ -4361,8 +4361,8 @@ func.func @conv2d_backprop_filter(
    padding = "VALID",
    strides = [1, 1, 1, 1],
    use_cudnn_on_gpu = true
-  } : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<100x28x28x1xf32>
-  func.return %result : tensor<100x28x28x1xf32>
+  } : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<3x3x1x32xf32>
+  func.return %result : tensor<3x3x1x32xf32>
}

// -----

@@ -4391,7 +4391,7 @@ func.func @conv2d_backprop_filter_grouped(

// CHECK-LABEL: @conv3d_backprop_filter
-func.func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> {
+func.func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32> {
  // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1)
  // CHECK-SAME: dim_numbers = [f, 0, 1, 2, b]x[i, 0, 1, 2, o]->[0, 1, 2, b, f]
  // CHECK-SAME{LITERAL}: window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]}
@@ -4399,8 +4399,8 @@ func.func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop:
  // CHECK-SAME: feature_group_count = 1 : i64
  // CHECK: return %[[RESULT]]
  %filter_sizes = "tf.Const"() {value = dense<[3, 3, 3, 1, 6]> : tensor<5xi32>} : () -> tensor<5xi32>
-  %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32>
-  func.return %result : tensor<2x8x8x8x1xf32>
+  %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32>
+  func.return %result : tensor<3x3x3x1x6xf32>
}

// -----

diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 536125ffcdf4bd..3bb549aac4aa5d 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -4460,7 +4460,7 @@ func.func @conv3d_backprop_input(%filter: tensor<3x3x3x1x6xf32>, %out_backprop:
func.func @conv2d_backprop_filter(
    %input: tensor<100x28x28x1xf32>,
    %out_backprop: tensor<100x26x26x32xf32>
-  ) -> tensor<100x28x28x1xf32> {
+  ) -> tensor<3x3x1x32xf32> {
  // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1)
  // CHECK-SAME: dim_numbers = [f, 0, 1, b]x[i, 0, 1, o]->[0, 1, b, f]
  // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
@@ -4475,8 +4475,8 @@ func.func @conv2d_backprop_filter(
    padding = "VALID",
    strides = [1, 1, 1, 1],
    use_cudnn_on_gpu = true
-  } : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<100x28x28x1xf32>
-  func.return %result : tensor<100x28x28x1xf32>
+  } : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<3x3x1x32xf32>
+  func.return %result : tensor<3x3x1x32xf32>
}

// -----
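Both backprop-filter fixes have the same cause: Conv2DBackpropFilter/Conv3DBackpropFilterV2 return the filter gradient, so the result type must match filter_sizes, not the forward input shape. A small C++ sanity check for the 2-D case above (illustrative only):

    #include <cassert>
    #include <cstdint>

    int main() {
      // VALID padding, stride 1: out_backprop = input - filter + 1.
      const int64_t input = 28, out_backprop = 26;
      const int64_t filter = input - out_backprop + 1;
      assert(filter == 3);
      // The gradient has the filter's own HWIO shape [3, 3, 1, 32],
      // not the forward input shape 100x28x28x1.
    }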
@@ -4505,7 +4505,7 @@ func.func @conv2d_backprop_filter_grouped(

// CHECK-LABEL: @conv3d_backprop_filter
-func.func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> {
+func.func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32> {
  // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1)
  // CHECK-SAME: dim_numbers = [f, 0, 1, 2, b]x[i, 0, 1, 2, o]->[0, 1, 2, b, f]
  // CHECK-SAME{LITERAL}: window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]}
@@ -4513,8 +4513,8 @@ func.func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop:
  // CHECK-SAME: feature_group_count = 1 : i64
  // CHECK: return %[[RESULT]]
  %filter_sizes = "tf.Const"() {value = dense<[3, 3, 3, 1, 6]> : tensor<5xi32>} : () -> tensor<5xi32>
-  %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32>
-  func.return %result : tensor<2x8x8x8x1xf32>
+  %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32>
+  func.return %result : tensor<3x3x3x1x6xf32>
}

// -----
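In the filter-gradient convolution above, the roles are swapped: the forward input acts as the lhs and out_backprop acts as the kernel, so each spatial extent of the result is the filter size. A worked check under that reading (standalone, not from the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      // Per spatial dim: in = 8, pad 1+1, "kernel" = out_backprop extent 8.
      const int64_t in = 8, pad = 1, kernel = 8;
      assert((in + 2 * pad - kernel) / 1 + 1 == 3);  // filter extent 3
      // The b position carries the forward input channels (1) and the f
      // position carries out_backprop's channels (6): result 3x3x3x1x6.
    }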
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
index 731da337cf0d73..13e70cee5c8571 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
@@ -256,15 +256,15 @@ add {
// TODO(b/129422361) Potentially update when copy, reshape, and conv have actual
// implementations with attributes, etc.
// CHECK-LABEL: func private @test_conv(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<256x32x32x6xf32>) -> tuple<tensor<256x30x30x16xf32>> {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<256x32x32x1xf32>) -> tuple<tensor<256x30x30x1xf32>> {
%test_conv {
-  %arg0.1 = f32[256,32,32,6]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
+  %arg0.1 = f32[256,32,32,1]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}

-  // CHECK-NEXT: %[[VAL_1:.*]] = "mhlo.copy"(%[[VAL_0]]) {xla_shape = "f32[256,32,32,6]{2,1,3,0}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
-  %copy.1 = f32[256,32,32,6]{2,1,3,0} copy(%arg0.1), metadata={op_name="HLO_Args"}
+  // CHECK-NEXT: %[[VAL_1:.*]] = "mhlo.copy"(%[[VAL_0]]) {xla_shape = "f32[256,32,32,1]{2,1,3,0}"} : (tensor<256x32x32x1xf32>) -> tensor<256x32x32x1xf32>
+  %copy.1 = f32[256,32,32,1]{2,1,3,0} copy(%arg0.1), metadata={op_name="HLO_Args"}

-  // CHECK-NEXT: %[[VAL_2:.*]] = "mhlo.reshape"(%[[VAL_1]]) {xla_shape = "f32[256,32,32,6]{2,1,3,0}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
-  %reshape.2 = f32[256,32,32,6]{2,1,3,0} reshape(%copy.1)
+  // CHECK-NEXT: %[[VAL_2:.*]] = "mhlo.reshape"(%[[VAL_1]]) {xla_shape = "f32[256,32,32,1]{2,1,3,0}"} : (tensor<256x32x32x1xf32>) -> tensor<256x32x32x1xf32>
+  %reshape.2 = f32[256,32,32,1]{2,1,3,0} reshape(%copy.1)

  // Note that double brackets "[[" have to be escaped as they denote variables
  // in FileCheck. The only way to do so is to drop into regex with "{{"
@@ -276,15 +276,15 @@ add {
  // CHECK-SAME{LITERAL}: window = {stride = [4, 5], pad = [[44, 45], [60, 60]], lhs_dilate = [1, 1], rhs_dilate = [2, 3]}
  // CHECK-SAME: feature_group_count = 1 : i64
  // CHECK-SAME: precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
-  // CHECK-SAME: (tensor<256x32x32x6xf32>, tensor<2x2x1x1xf32>) -> tensor<16x30x30x256xf32>
+  // CHECK-SAME: (tensor<256x32x32x1xf32>, tensor<2x2x1x1xf32>) -> tensor<1x30x30x256xf32>
-  %convolution.4 = f32[16,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
+  %convolution.4 = f32[1,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}

-  // CHECK-NEXT: %[[VAL_5:.*]] = "mhlo.reshape"(%[[VAL_4]]) : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32>
-  %reshape.5 = f32[256,30,30,16]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"}
+  // CHECK-NEXT: %[[VAL_5:.*]] = "mhlo.reshape"(%[[VAL_4]]) : (tensor<1x30x30x256xf32>) -> tensor<256x30x30x1xf32>
+  %reshape.5 = f32[256,30,30,1]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"}

-  // CHECK-NEXT: %[[VAL_6:.*]] = "mhlo.tuple"(%[[VAL_5]]) {xla_shape = {{.*}}} : (tensor<256x30x30x16xf32>) -> tuple<tensor<256x30x30x16xf32>>
-  ROOT %tuple.6 = (f32[256,30,30,16]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="HLO_Retvals"}
+  // CHECK-NEXT: %[[VAL_6:.*]] = "mhlo.tuple"(%[[VAL_5]]) {xla_shape = {{.*}}} : (tensor<256x30x30x1xf32>) -> tuple<tensor<256x30x30x1xf32>>
+  ROOT %tuple.6 = (f32[256,30,30,1]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="HLO_Retvals"}
}

// Test for padding attribute shape in convolution