Skip to content

Commit

Permalink
Print tensor dtypes as strings in shape inference (onnx#5856)
Browse files Browse the repository at this point in the history
Fix onnx#5782

Signed-off-by: isdanni <[email protected]>
  • Loading branch information
isdanni authored Jan 16, 2024
1 parent 54e704e commit 463ca88
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions onnx/defs/shape_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,11 @@ void UnionTypeInfo(const TypeProto& source_type, TypeProto& target_type) {

if (source_elem_type != target_elem_type) {
fail_type_inference(
"Mismatched tensor element type:", " inferred=", source_elem_type, " declared=", target_elem_type);
"Mismatched tensor element type:",
" inferred=",
Utils::DataTypeUtils::ToDataTypeString(source_elem_type),
" declared=",
Utils::DataTypeUtils::ToDataTypeString(target_elem_type));
}

UnionShapeInfo(source_type.tensor_type(), *target_type.mutable_tensor_type());
Expand All @@ -281,7 +285,11 @@ void UnionTypeInfo(const TypeProto& source_type, TypeProto& target_type) {
auto target_elem_type = target_type.sparse_tensor_type().elem_type();
if (source_elem_type != target_elem_type) {
fail_type_inference(
"Mismatched sparse tensor element type:", " inferred=", source_elem_type, " declared=", target_elem_type);
"Mismatched sparse tensor element type:",
" inferred=",
Utils::DataTypeUtils::ToDataTypeString(source_elem_type),
" declared=",
Utils::DataTypeUtils::ToDataTypeString(target_elem_type));
}
UnionShapeInfo(source_type.sparse_tensor_type(), *target_type.mutable_sparse_tensor_type());
} else if (target_case == TypeProto::ValueCase::kSequenceType) {
Expand Down

0 comments on commit 463ca88

Please sign in to comment.