Skip to content

Commit

Permalink
[WebNN EP] Enable Cast op for WebNN CPU backend (microsoft#20864)
Browse files Browse the repository at this point in the history
WebNN TFLite backend supports `cast` op but doesn't support casting to
`uint64` data type.
  • Loading branch information
Honry authored Jun 19, 2024
1 parent 35c430a commit 40879a2
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion js/web/docs/webnn-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| ArgMin | ai.onnx(7-10, 11, 12, 13+) | argMin ||| WebNN CPU backend only supports 'select_last_index' value is 0 |
| AveragePool | ai.onnx(7-9, 10, 11, 12-18, 19+) | averagePool2d ||| Only supports 4-D input, 2-D 'kernel_shape', 'count_include_pad' value is 0 |
| BatchNormalization | ai.onnx(7-8, 9-13, 14, 15+) | batchNormalization ||| Only supports 'training_mode' value is 0, one output |
| Cast | ai.onnx(7-8, 9-12, 13-18, 19-20, 21+) | cast | || |
| Cast | ai.onnx(7-8, 9-12, 13-18, 19-20, 21+) | cast | || WebNN CPU backend doesn't support casting to uint64 data type |
| Ceil | ai.onnx(7-12, 13+) | ceil ||| |
| Clip | ai.onnx(7-10, 11, 12, 13+) | clamp ||| WebNN CPU backend only supports 3 specific ranges: [0.0, infinity], [-1.0, 1.0], [0.0, 6.0] (Chromium issue: https://issues.chromium.org/issues/326156496) |
| Concat | ai.onnx(7-10, 11-12, 13+) | concat ||| |
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
{"ArgMin", {"argMin", true}},
{"AveragePool", {"averagePool2d", true}},
{"BatchNormalization", {"batchNormalization", false}},
{"Cast", {"cast", false}},
{"Cast", {"cast", true}},
{"Ceil", {"ceil", true}},
{"Clip", {"clamp", true}},
{"Concat", {"concat", true}},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class CastOpBuilder : public BaseOpBuilder {
// Operator support related.
private:
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
};

// Add operator related.
Expand Down Expand Up @@ -80,13 +80,19 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,

bool CastOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
const Node& node,
const WebnnDeviceType /* device_type */,
const WebnnDeviceType device_type,
const logging::Logger& logger) const {
NodeAttrHelper helper(node);
// Check cast output type.
const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED);

// WebNN CPU backend doesn't support casting to uint64 data type.
if (device_type == WebnnDeviceType::CPU && to_type == ONNX_NAMESPACE::TensorProto_DataType_UINT64) {
LOGS(logger, VERBOSE) << "Cast to uint64 is not supported for WebNN CPU backend.";
return false;
}
if (!IsSupportedDataType(to_type, webnn_supported_data_types)) {
LOGS(logger, VERBOSE) << "Invalid cast to type " << to_type << ".";
LOGS(logger, VERBOSE) << "WebNN doesn't support casting to type " << to_type << ".";
return false;
}

Expand Down

0 comments on commit 40879a2

Please sign in to comment.