[mhlo] Verifier for mhlo.ConvOp
PiperOrigin-RevId: 448608547
sdasgup3 authored and tensorflower-gardener committed May 14, 2022
1 parent 5ea8ba1 commit 9cf7bec
Showing 11 changed files with 1,250 additions and 98 deletions.
@@ -1535,6 +1535,7 @@ def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]> {

let results = (outs HLO_Tensor);
let hasCustomHLOConverter = 1;
let hasVerifier = 1;

code extraClassDeclaration = [{
bool hasWindowReversal() {
305 changes: 305 additions & 0 deletions tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -1791,6 +1791,311 @@ LogicalResult CollectivePermuteOp::verify() {
*this, source_target_pairs());
}

//===----------------------------------------------------------------------===//
// ConvOp
//===----------------------------------------------------------------------===//

namespace {
// Checks:
// P1. Same sizes for input, kernel and output spatial_dims.
// P2. Spatial and non-spatial dimensions (for input, kernel, & output) should
//     be unique and in range [0, num_dims), where num_dims = rank of input
//     (lhs/rhs) tensors.
//
// Note that the spatial + non-spatial dimensions may not cover all the
// dimensions in the range [0, num_dims) because of the presence of 'unknown'
// dimensions (ref. cl/415132294).
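//
// As a purely illustrative (hypothetical) example, a rank-4 NHWC/HWIO/NHWC
// convolution would use the dimension numbers
//   input : batch = 0, feature = 3, spatial = [1, 2]
//   kernel: input-feature = 2, output-feature = 3, spatial = [0, 1]
//   output: batch = 0, feature = 3, spatial = [1, 2]
// Each set has two spatial dimensions (P1) and covers {0, 1, 2, 3} exactly
// once within [0, 4) (P2).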
LogicalResult isSpatialDimensionsValid(ConvOp op) {
auto input_spatial_dimensions =
op.dimension_numbers().getInputSpatialDimensions();
auto kernel_spatial_dimensions =
op.dimension_numbers().getKernelSpatialDimensions();
auto output_spatial_dimensions =
op.dimension_numbers().getOutputSpatialDimensions();

// P1.
if ((input_spatial_dimensions.size() != kernel_spatial_dimensions.size()) ||
(input_spatial_dimensions.size() != output_spatial_dimensions.size()))
return op.emitOpError() << "expects the same size for input, kernel and "
"output spatial-dimensions, but got "
<< input_spatial_dimensions.size() << ", "
<< kernel_spatial_dimensions.size() << ", and "
<< output_spatial_dimensions.size() << " resp.";

// P2.
SmallVector<int64_t> input_dnums(input_spatial_dimensions.size() + 2);
input_dnums[0] = op.dimension_numbers().getInputBatchDimension();
input_dnums[1] = op.dimension_numbers().getInputFeatureDimension();
std::copy(input_spatial_dimensions.begin(), input_spatial_dimensions.end(),
input_dnums.begin() + 2);

SmallVector<int64_t> window_dnums(kernel_spatial_dimensions.size() + 2);
window_dnums[0] = op.dimension_numbers().getKernelInputFeatureDimension();
window_dnums[1] = op.dimension_numbers().getKernelOutputFeatureDimension();
std::copy(kernel_spatial_dimensions.begin(), kernel_spatial_dimensions.end(),
window_dnums.begin() + 2);

SmallVector<int64_t> output_dnums(output_spatial_dimensions.size() + 2);
output_dnums[0] = op.dimension_numbers().getOutputBatchDimension();
output_dnums[1] = op.dimension_numbers().getOutputFeatureDimension();
std::copy(output_spatial_dimensions.begin(), output_spatial_dimensions.end(),
output_dnums.begin() + 2);

auto num_dims = op.lhs().getType().cast<RankedTensorType>().getRank();
const auto in_range = [num_dims](int64_t i) {
return 0 <= i && i < num_dims;
};

if (!llvm::all_of(input_dnums, in_range) ||
!llvm::all_of(window_dnums, in_range) ||
!llvm::all_of(output_dnums, in_range))
return op.emitOpError() << "expects input, kernel, and output "
"dimension-numbers to be in-range [0, "
<< num_dims << ").";

const auto has_duplicates = [](SmallVector<int64_t>& dnums) {
std::sort(dnums.begin(), dnums.end());
auto last = std::unique(dnums.begin(), dnums.end());
return last != dnums.end();
};

if (has_duplicates(input_dnums))
return op.emitOpError()
<< "expects input dimension-numbers to be unique, got {"
<< input_dnums << "}.";

if (has_duplicates(window_dnums))
return op.emitOpError()
<< "expects kernel dimension-numbers to be unique, got {"
<< window_dnums << "}.";

if (has_duplicates(output_dnums))
return op.emitOpError()
<< "expects output dimension-numbers to be unique, got {"
<< output_dnums << "}.";

return success();
}

// Verifies the following properties:
// P1. The input, kernel, and output spatial-dimensions are valid.
// P2. Given,
// input-dimensions: b * input-spatial-dims * f
// kernel-dimensions: kernel-spatial-dims * i * o
// output-dimensions: b' * out-spatial-dims * f'
// where b = input-batch-dims
// where f = input-feature-dims
// where i = kernel-input-feature-dims
// where o = kernel-output-feature-dims
// where b' = output-batch-dims
// where f' = output-feature-dims
// Check the following properties w.r.t feature_group_count (fgc) and
// batch_group_count (bgc).
//   fgc > 0, bgc > 0 and !(fgc > 1 && bgc > 1)
//   b % bgc == 0
//   f % fgc == 0 and i = f / fgc
//   o (or f') % bgc == 0 and o (or f') % fgc == 0
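//
// A hypothetical example that satisfies these checks (shapes illustrative,
// assuming an NHWC input and HWIO kernel):
//   input  : tensor<4x8x8x32xf32>   => b = 4, f = 32
//   kernel : tensor<3x3x16x64xf32>  => i = 16, o = 64
//   feature_group_count (fgc) = 2, batch_group_count (bgc) = 1
//   => f % fgc == 0 (32 % 2), i == f / fgc (16 == 32 / 2),
//      o % fgc == 0 (64 % 2), o % bgc == 0 (64 % 1), b % bgc == 0 (4 % 1).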
LogicalResult verifyConvolutionAttributes(ConvOp op) {
// P1.
if (failed(isSpatialDimensionsValid(op))) return failure();

// P2.
const int64_t feature_group_count = op.feature_group_count();
const int64_t batch_group_count = op.batch_group_count();

if (feature_group_count <= 0)
return op.emitOpError()
<< "expects feature_group_count to be a positive number, got "
<< feature_group_count << ".";

if (batch_group_count <= 0)
return op.emitOpError()
<< "expects batch_group_count to be a positive number, got "
<< batch_group_count << ".";

if (batch_group_count > 1 && feature_group_count > 1)
return op.emitOpError()
<< "expects batch_group_count and feature_group_count not to be "
"both greater than 1. Got "
<< batch_group_count << " and " << feature_group_count << " resp.";

auto lhs_type = op.lhs().getType().cast<RankedTensorType>();
const int64_t input_features =
lhs_type.getShape()[op.dimension_numbers().getInputFeatureDimension()];
const int64_t input_batch =
lhs_type.getShape()[op.dimension_numbers().getInputBatchDimension()];

auto rhs_type = op.rhs().getType().cast<RankedTensorType>();
const int64_t kernel_input_features =
rhs_type
.getShape()[op.dimension_numbers().getKernelInputFeatureDimension()];
const int64_t kernel_output_features =
rhs_type
.getShape()[op.dimension_numbers().getKernelOutputFeatureDimension()];

if (!isDynamicDimSize(kernel_output_features)) {
if (kernel_output_features % batch_group_count != 0)
return op.emitOpError() << "expects output feature dimension size ("
<< kernel_output_features
<< ") to be a multiple of "
"batch_group_count. Got batch_group_count = "
<< batch_group_count << ".";

if (kernel_output_features % feature_group_count != 0)
return op.emitOpError()
<< "expects kernel output feature dimension ("
<< kernel_output_features
<< ") to be divisible by "
"feature_group_count. For feature_group_count = "
<< feature_group_count << ".";
}

if (!isDynamicDimSize(input_features)) {
if (input_features % feature_group_count != 0)
return op.emitOpError()
<< "expects input feature dimension (" << input_features
<< ") to be a multiple of "
"feature_group_count. Got feature_group_count = "
<< feature_group_count << ".";

if (!isDynamicDimSize(kernel_input_features) &&
input_features / feature_group_count != kernel_input_features)
return op.emitOpError()
<< "expects input feature dimension (" << input_features
<< ") / "
"feature_group_count = kernel input feature dimension ("
<< kernel_input_features
<< "). Got feature_group_count = " << feature_group_count << ".";
}

if (!isDynamicDimSize(input_batch) && input_batch % batch_group_count != 0)
return op.emitOpError() << "expects input batch dimension (" << input_batch
<< ") to be divisible by "
"batch_group_count. Got batch_group_count = "
<< batch_group_count << ".";

return success();
}

// Infer the return-shape of ConvOp.
// Precondition:
// 1. Input args to ConvOp 'op' are RankedTensorType.
// 2. rank-of(input-type) == rank-of(output-type)
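//
// As a rough sketch (assuming the usual XLA window arithmetic; see
// inferWindowOutputShape), each output spatial dimension is derived from the
// corresponding input spatial dimension roughly as
//   dilated_input  = (input_size - 1) * lhs_dilation + 1
//   dilated_window = (window_size - 1) * rhs_dilation + 1
//   output_size    = floordiv(dilated_input + pad_low + pad_high
//                             - dilated_window, stride) + 1
// with dynamic input dimensions propagated as dynamic.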
SmallVector<int64_t> inferConvOpReturnShape(
ConvOp op, const ArrayRef<WindowDimension> window) {
// We keep the 'unknown' dimensions (cl/415132294) as is in the
// output-shape. To do that we initialize the output dimensions with the shape
// of the return-type and update only the spatial + non-spatial dimensions.
// Precondition 2 ensures that size of output-shape == size of input-shape.
SmallVector<int64_t> output_dimensions =
to_vector(op.getResult().getType().cast<ShapedType>().getShape());

// Infer the output spatial dimensions.
auto lhs_type = op.lhs().getType().cast<RankedTensorType>();
auto input_spatial_dims = op.dimension_numbers().getInputSpatialDimensions();
auto num_spatial_dims = input_spatial_dims.size();
SmallVector<int64_t> input_spatial_dim_vals(num_spatial_dims);
for (int i = 0; i < num_spatial_dims; ++i)
input_spatial_dim_vals[i] = lhs_type.getShape()[input_spatial_dims[i]];

auto window_output_shape =
inferWindowOutputShape(input_spatial_dim_vals, window);

for (int i = 0; i < window.size(); ++i)
output_dimensions[op.dimension_numbers().getOutputSpatialDimensions()[i]] =
window_output_shape[i];

// Infer the output-batch-dimension and output-feature-dimension.
auto rhs_type = op.rhs().getType().cast<RankedTensorType>();
const int64_t input_batch =
lhs_type.getShape()[op.dimension_numbers().getInputBatchDimension()];
const int64_t kernel_output_features =
rhs_type
.getShape()[op.dimension_numbers().getKernelOutputFeatureDimension()];

output_dimensions[op.dimension_numbers().getOutputBatchDimension()] =
isDynamicDimSize(input_batch) ? ShapedType::kDynamicSize
: input_batch / op.batch_group_count();
output_dimensions[op.dimension_numbers().getOutputFeatureDimension()] =
kernel_output_features;

return output_dimensions;
}
} // namespace

/*
* We intend to verify the following properties
* P1. Verify the input, kernel types.
* P2. Verify the convolution attributes.
* P3. Verify and collect the window attributes.
* P4. Verify the return shape.
* TODO(b/232574102): Verify the element-type of return-value.
*/
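//
// A hypothetical configuration that would satisfy P1-P4 (illustrative only):
//   lhs    : tensor<1x8x8x32xf32>   (NHWC)
//   rhs    : tensor<3x3x32x64xf32>  (HWIO)
//   window : stride = [1, 1], padding = [[0, 0], [0, 0]], no dilation
//   result : tensor<1x6x6x64xf32>   (NHWC), since 8 - 3 + 1 = 6 per spatial
//            dim and the output feature dimension equals o = 64.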
LogicalResult ConvOp::verify() {
auto lhs_type = lhs().getType().dyn_cast<RankedTensorType>();
auto rhs_type = rhs().getType().dyn_cast<RankedTensorType>();

if (!lhs_type || !rhs_type) return success();

// P1.
int num_dims = lhs_type.getRank();
if (num_dims != rhs_type.getRank())
return emitOpError()
<< "expects convolution arguments to have same number of "
"dimensions. Got: "
<< lhs_type << " and " << rhs_type << ".";

if (num_dims < 2)
return emitOpError()
<< "expects convolution arguments to have >= 2 dimensions. "
"Got: "
<< lhs_type << " and " << rhs_type << ".";

// P2.
if (failed(verifyConvolutionAttributes(*this))) return failure();

// P3.
auto kernel_spatial_dimensions =
dimension_numbers().getKernelSpatialDimensions();
SmallVector<int64_t> window_dimensions(kernel_spatial_dimensions.size());
for (size_t i = 0; i < window_dimensions.size(); i++)
window_dimensions[i] = rhs_type.getShape()[kernel_spatial_dimensions[i]];

auto padding_or_err = convertNx2Attribute(this->padding(), getLoc());
if (failed(padding_or_err)) return failure();
SmallVector<std::pair<int64_t, int64_t>> padding = *padding_or_err;

auto window_or_err = verifyWindowAttributesAndInferWindowDimensions(
window_dimensions, convertDenseIntAttr(window_strides()), padding,
convertDenseIntAttr(lhs_dilation()), convertDenseIntAttr(rhs_dilation()),
getLoc());
if (failed(window_or_err)) return failure();

// P4.
auto actual_return_type = getResult().getType().cast<TensorType>();
auto actual_return_element_type = actual_return_type.getElementType();
if (!actual_return_type.hasRank()) return success();

auto actual_return_ranked_type = actual_return_type.cast<RankedTensorType>();
if (num_dims != actual_return_ranked_type.getRank())
return emitOpError() << "expects rank of convolution return-type to be "
"equal to input-ranks ("
<< num_dims << "), but got "
<< actual_return_ranked_type.getRank() << ".";

auto expected_return_shape = inferConvOpReturnShape(*this, *window_or_err);
auto expected_return_type =
RankedTensorType::get(expected_return_shape, actual_return_element_type);
if (failed(verifyCompatibleShape(expected_return_type,
actual_return_ranked_type)))
return emitOpError()
<< "has shape mismatch between the expected return-type ("
<< expected_return_type << ") and actual return-type ("
<< actual_return_ranked_type << ").";

return success();
}

//===----------------------------------------------------------------------===//
// ConvertOp
//===----------------------------------------------------------------------===//