From 40879a2623d121074ced1dee7f11db1e50649d48 Mon Sep 17 00:00:00 2001
From: Wanming Lin
Date: Wed, 19 Jun 2024 16:51:19 +0800
Subject: [PATCH] [WebNN EP] Enable Cast op for WebNN CPU backend (#20864)

WebNN TFLite backend supports `cast` op but doesn't support casting to
`uint64` data type.
---
 js/web/docs/webnn-operators.md                         |  2 +-
 onnxruntime/core/providers/webnn/builders/helper.h     |  2 +-
 .../providers/webnn/builders/impl/cast_op_builder.cc   | 12 +++++++++---
 3 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md
index f060109b31fab..4c6dab84fa973 100644
--- a/js/web/docs/webnn-operators.md
+++ b/js/web/docs/webnn-operators.md
@@ -17,7 +17,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
 | ArgMin | ai.onnx(7-10, 11, 12, 13+) | argMin | ✓ | ✓ | WebNN CPU backend only supports 'select_last_index' value is 0 |
 | AveragePool | ai.onnx(7-9, 10, 11, 12-18, 19+) | averagePool2d | ✓ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'count_include_pad' value is 0 |
 | BatchNormalization | ai.onnx(7-8, 9-13, 14, 15+) | batchNormalization | ✗ | ✓ | Only supports 'training_mode' value is 0, one output |
-| Cast | ai.onnx(7-8, 9-12, 13-18, 19-20, 21+) | cast | ✗ | ✓ | |
+| Cast | ai.onnx(7-8, 9-12, 13-18, 19-20, 21+) | cast | ✓ | ✓ | WebNN CPU backend doesn't support casting to uint64 data type |
 | Ceil | ai.onnx(7-12, 13+) | ceil | ✓ | ✓ | |
 | Clip | ai.onnx(7-10, 11, 12, 13+) | clamp | ✓ | ✓ | WebNN CPU backend only supports 3 specific ranges: [0.0, infinity], [-1.0, 1.0], [0.0, 6.0] (Chromium issue: https://issues.chromium.org/issues/326156496) |
 | Concat | ai.onnx(7-10, 11-12, 13+) | concat | ✓ | ✓ | |
diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h
index daecfcf457d70..7240fa37d9cc9 100644
--- a/onnxruntime/core/providers/webnn/builders/helper.h
+++ b/onnxruntime/core/providers/webnn/builders/helper.h
@@ -161,7 +161,7 @@ static const InlinedHashMap op_map = {
     {"ArgMin", {"argMin", true}},
     {"AveragePool", {"averagePool2d", true}},
     {"BatchNormalization", {"batchNormalization", false}},
-    {"Cast", {"cast", false}},
+    {"Cast", {"cast", true}},
     {"Ceil", {"ceil", true}},
     {"Clip", {"clamp", true}},
     {"Concat", {"concat", true}},
diff --git a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc
index f7d3d308d2f1d..a97d71b90de55 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc
@@ -22,7 +22,7 @@ class CastOpBuilder : public BaseOpBuilder {
   // Operator support related.
  private:
   bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
-                         const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
+                         const WebnnDeviceType device_type, const logging::Logger& logger) const override;
 };

 // Add operator related.
@@ -80,13 +80,19 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,

 bool CastOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
                                       const Node& node,
-                                      const WebnnDeviceType /* device_type */,
+                                      const WebnnDeviceType device_type,
                                       const logging::Logger& logger) const {
   NodeAttrHelper helper(node);
   // Check cast output type.
   const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED);
+
+  // WebNN CPU backend doesn't support casting to uint64 data type.
+  if (device_type == WebnnDeviceType::CPU && to_type == ONNX_NAMESPACE::TensorProto_DataType_UINT64) {
+    LOGS(logger, VERBOSE) << "Cast to uint64 is not supported for WebNN CPU backend.";
+    return false;
+  }
   if (!IsSupportedDataType(to_type, webnn_supported_data_types)) {
-    LOGS(logger, VERBOSE) << "Invalid cast to type " << to_type << ".";
+    LOGS(logger, VERBOSE) << "WebNN doesn't support casting to type " << to_type << ".";
     return false;
   }
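A minimal, self-contained sketch of the gating logic this patch adds to CastOpBuilder::IsOpSupportedImpl, using simplified stand-in enums (DeviceType, DataType) in place of the real WebnnDeviceType and ONNX TensorProto_DataType; it shows only the device-conditional rejection of uint64 cast targets, not ONNX Runtime's actual API.

// Standalone illustration of the support check added above: Cast is supported
// in general, but casting to uint64 is rejected on the WebNN CPU (TFLite)
// backend. DeviceType and DataType are hypothetical stand-ins for the real
// WebnnDeviceType and ONNX TensorProto_DataType enums.
#include <iostream>

enum class DeviceType { CPU, GPU };              // stand-in for WebnnDeviceType
enum class DataType { Float32, Int64, Uint64 };  // stand-in for the cast target ("to") type

bool IsCastSupported(DeviceType device, DataType to_type) {
  // Mirrors the new check: uint64 targets are unsupported on the CPU backend.
  if (device == DeviceType::CPU && to_type == DataType::Uint64) {
    return false;
  }
  return true;
}

int main() {
  std::cout << std::boolalpha
            << IsCastSupported(DeviceType::CPU, DataType::Uint64) << "\n"    // false
            << IsCastSupported(DeviceType::GPU, DataType::Uint64) << "\n"    // true
            << IsCastSupported(DeviceType::CPU, DataType::Float32) << "\n";  // true
  return 0;
}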