Skip to content

Commit

Permalink
Add SegmentMax-16 reference implementation (#29047)
Browse files Browse the repository at this point in the history
### Details:
- The original PR
(#28788) has been
mistakenly force-merged due to a mistake in merge queue settings. It was
later reverted, so this is the "new" Ref PR.
 - Add reference implementation
 - Add tests

### Related PRs:
 - #28103
 - #28698
 - #28979
 - #28999

### Tickets:
 - CVS-158917

---------

Signed-off-by: p-wysocki <[email protected]>
Co-authored-by: Roman Kazantsev <[email protected]>
Co-authored-by: Pawel Raasz <[email protected]>
Co-authored-by: Katarzyna Mitrus <[email protected]>
  • Loading branch information
4 people authored Feb 18, 2025
1 parent e58f38f commit 18e97b6
Show file tree
Hide file tree
Showing 11 changed files with 373 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/core/include/openvino/op/util/attr_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ enum class PadMode { CONSTANT = 0, EDGE, REFLECT, SYMMETRIC };
OPENVINO_API
std::ostream& operator<<(std::ostream& s, const PadMode& type);

/// \brief Fill modes for the `SegmentMax` operator.
/// \brief Fill modes to set default value for operators like `SegmentMax`.
enum class FillMode { ZERO = 0, LOWEST };

OPENVINO_API
Expand Down
1 change: 1 addition & 0 deletions src/core/include/openvino/opsets/opset16_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ _OPENVINO_OP_REG(ShapeOf, ov::op::v3)
// New operations added in opset16
_OPENVINO_OP_REG(Identity, ov::op::v16)
_OPENVINO_OP_REG(ISTFT, ov::op::v16)
_OPENVINO_OP_REG(SegmentMax, ov::op::v16)
55 changes: 55 additions & 0 deletions src/core/reference/include/openvino/reference/segment_max.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <algorithm>
#include <limits>
#include <vector>

#include "openvino/core/shape.hpp"

namespace ov::reference {

template <typename T, typename T_idx, std::enable_if_t<std::is_same<std::decay_t<T_idx>, int64_t>::value>* = nullptr>
void segment_max(const T* data,
const Shape& data_shape,
const T_idx* segment_ids,
T* out,
const Shape& output_shape,
const T empty_segment_value) {
const T_idx num_segments = output_shape[0];
const auto inner_dim_size = shape_size(data_shape.begin() + 1, data_shape.end());

// Initialize output with empty_segment_value
std::fill(out, out + num_segments * inner_dim_size, empty_segment_value);

// Iterate over each element in the first dimension
for (size_t i = 0; i < data_shape[0]; ++i) {
const T_idx segment_id = segment_ids[i];
if (segment_id >= num_segments) {
continue;
}
// Iterate over each element in the inner dimensions
for (size_t j = 0; j < inner_dim_size; ++j) {
const size_t index = i * inner_dim_size + j;
const size_t out_index = segment_id * inner_dim_size + j;
// Update the maximum value for the current segment and inner dimension
out[out_index] = std::max(out[out_index], data[index]);
}
}
}

template <typename T, typename T_idx, std::enable_if_t<!std::is_same<std::decay_t<T_idx>, int64_t>::value>* = nullptr>
void segment_max(const T* data,
const Shape& data_shape,
const T_idx* segment_ids,
T* out,
const Shape& output_shape,
const T empty_segment_value) {
std::vector<int64_t> segment_ids_int64(segment_ids, segment_ids + data_shape[0]);
segment_max(data, data_shape, segment_ids_int64.data(), out, output_shape, empty_segment_value);
}

} // namespace ov::reference
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,11 @@ std::vector<TRShape> shape_infer(const SegmentMax* op,

// validate num_segments input
const auto num_segments_available = op->inputs().size() == 3;
const auto num_segments = num_segments_available ? get_input_const_data_as_shape<TRShape>(op, 2, tensor_accessor)
: ov::optional<TRShape>{};
ov::optional<TRShape> num_segments;
if (num_segments_available) {
num_segments = get_input_const_data_as_shape<TRShape>(op, 2, tensor_accessor);
}

if (num_segments_available) {
const auto& num_segments_shape = input_shapes[2];
NODE_SHAPE_INFER_CHECK(op,
Expand Down
2 changes: 1 addition & 1 deletion src/core/tests/opset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ INSTANTIATE_TEST_SUITE_P(opset,
OpsetTestParams{ov::get_opset13, 186},
OpsetTestParams{ov::get_opset14, 188},
OpsetTestParams{ov::get_opset15, 199},
OpsetTestParams{ov::get_opset16, 5}),
OpsetTestParams{ov::get_opset16, 6}),
OpsetTestNameGenerator{});

class MyOpOld : public ov::op::Op {
Expand Down
20 changes: 10 additions & 10 deletions src/core/tests/type_prop/segment_max.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

namespace ov::test {
using op::v0::Constant, op::v0::Parameter, op::v1::Add, op::v1::ReduceMax, op::v1::StridedSlice, op::v3::ShapeOf;
using testing::HasSubstr;

class TypePropSegmentMaxTest : public TypePropOpTest<op::v16::SegmentMax> {};

Expand Down Expand Up @@ -69,45 +70,44 @@ TEST_F(TypePropSegmentMaxTest, incorrect_inputs) {
const auto num_segments_f32 = std::make_shared<Parameter>(element::f32, PartialShape{});
OV_EXPECT_THROW(std::ignore = make_op(data, segment_ids, num_segments_f32, op::FillMode::LOWEST),
ov::NodeValidationFailure,
testing::HasSubstr("The element type of the num_segments input be i32 or i64."));
HasSubstr("The element type of the num_segments input be i32 or i64."));
}
{
const auto segment_ids_f32 = std::make_shared<Parameter>(element::f32, PartialShape{3});
OV_EXPECT_THROW(std::ignore = make_op(data, segment_ids_f32, num_segments, op::FillMode::LOWEST),
ov::NodeValidationFailure,
testing::HasSubstr("The element type of the segment_ids input be i32 or i64."));
HasSubstr("The element type of the segment_ids input be i32 or i64."));
}
{
const auto segment_ids_nd = std::make_shared<Parameter>(element::i32, PartialShape{2, 3});
OV_EXPECT_THROW(std::ignore = make_op(data, segment_ids_nd, num_segments, op::FillMode::LOWEST),
ov::NodeValidationFailure,
testing::HasSubstr("segment_ids must be a 1D input."));
HasSubstr("segment_ids must be a 1D input."));
}
{
const auto num_segments_nd = std::make_shared<Parameter>(element::i32, PartialShape{1});
OV_EXPECT_THROW(std::ignore = make_op(data, segment_ids, num_segments_nd, op::FillMode::LOWEST),
ov::NodeValidationFailure,
testing::HasSubstr("num_segments must be a scalar input."));
HasSubstr("num_segments must be a scalar input."));
}
{
const auto segment_ids_unsorted =
std::make_shared<Constant>(element::i32, Shape{3}, std::vector<int64_t>{1, 0, 1});
OV_EXPECT_THROW(std::ignore = make_op(data, segment_ids_unsorted, num_segments, op::FillMode::LOWEST),
ov::NodeValidationFailure,
testing::HasSubstr("segment_ids must be sorted."));
HasSubstr("segment_ids must be sorted."));
}
{
const auto data_scalar = std::make_shared<Parameter>(element::i32, PartialShape{});
OV_EXPECT_THROW(std::ignore = make_op(data_scalar, segment_ids, num_segments, op::FillMode::LOWEST),
ov::NodeValidationFailure,
testing::HasSubstr("The data input cannot be a scalar."));
HasSubstr("The data input cannot be a scalar."));
}
{
const auto segment_ids_short = std::make_shared<Constant>(element::i32, Shape{2}, std::vector<int64_t>{1, 0});
OV_EXPECT_THROW(
std::ignore = make_op(data, segment_ids_short, num_segments, op::FillMode::LOWEST),
ov::NodeValidationFailure,
testing::HasSubstr("The number of elements in segment_ids must match the first dimension of data."));
OV_EXPECT_THROW(std::ignore = make_op(data, segment_ids_short, num_segments, op::FillMode::LOWEST),
ov::NodeValidationFailure,
HasSubstr("The number of elements in segment_ids must match the first dimension of data."));
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/plugins/template/backend/ops/ops_evaluates.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,3 +558,7 @@ extern template bool evaluate_node<ov::op::v15::SearchSorted>(std::shared_ptr<ov
extern template bool evaluate_node<ov::op::v16::Identity>(std::shared_ptr<ov::Node> node,
ov::TensorVector& outputs,
const ov::TensorVector& inputs);

extern template bool evaluate_node<ov::op::v16::SegmentMax>(std::shared_ptr<ov::Node> node,
ov::TensorVector& outputs,
const ov::TensorVector& inputs);
80 changes: 80 additions & 0 deletions src/plugins/template/backend/ops/segment_max.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/reference/segment_max.hpp"

#include "element_visitor.hpp"
#include "evaluate_node.hpp"
#include "segment_max_shape_inference.hpp"

template <ov::element::Type_t ET_data, ov::element::Type_t ET_idx>
bool evaluate_index_type(const std::shared_ptr<ov::op::v16::SegmentMax>& op,
ov::TensorVector& outputs,
const ov::TensorVector& inputs) {
using T_data = typename ov::element_type_traits<ET_data>::value_type;
using T_idx = typename ov::element_type_traits<ET_idx>::value_type;
auto input_shapes = std::vector<ov::PartialShape>{op->get_input_shape(0), op->get_input_shape(1)};
if (op->inputs().size() == 3) {
input_shapes.emplace_back(op->get_input_shape(2));
}
const auto output_shape =
ov::op::v16::shape_infer(op.get(), input_shapes, make_tensor_accessor(inputs)).front().to_shape();
outputs.front().set_shape(output_shape);
const auto empty_segment_value =
op->get_fill_mode() == ov::op::FillMode::ZERO ? T_data(0) : std::numeric_limits<T_data>::lowest();
ov::reference::segment_max(inputs[0].data<const T_data>(),
inputs[0].get_shape(),
inputs[1].data<const T_idx>(),
outputs[0].data<T_data>(),
outputs[0].get_shape(),
empty_segment_value);
return true;
}

template <ov::element::Type_t ET_data>
bool evaluate_data_type(const std::shared_ptr<ov::op::v16::SegmentMax>& op,
ov::TensorVector& outputs,
const ov::TensorVector& inputs) {
const auto& index_type = op->get_input_element_type(1);
using ov::op::v16::SegmentMax;
using namespace ov::element;
switch (index_type) {
case i32:
return evaluate_index_type<ET_data, i32>(ov::as_type_ptr<SegmentMax>(op), outputs, inputs);
case i64:
return evaluate_index_type<ET_data, i64>(ov::as_type_ptr<SegmentMax>(op), outputs, inputs);
default:
OPENVINO_THROW("Unhandled index type ", index_type, " in evaluate_node()");
}
}

template <>
bool evaluate_node<ov::op::v16::SegmentMax>(std::shared_ptr<ov::Node> node,
ov::TensorVector& outputs,
const ov::TensorVector& inputs) {
const auto& element_type = node->get_output_element_type(0);

using ov::op::v16::SegmentMax;
using namespace ov::element;
switch (element_type) {
case i8:
return evaluate_data_type<i8>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
case i32:
return evaluate_data_type<i32>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
case i64:
return evaluate_data_type<i64>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
case u8:
return evaluate_data_type<u8>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
case u32:
return evaluate_data_type<u32>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
case u64:
return evaluate_data_type<u64>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
case f16:
return evaluate_data_type<f16>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
case f32:
return evaluate_data_type<f32>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
default:
OPENVINO_THROW("Unhandled data type ", element_type, " in evaluate_node()");
}
}
1 change: 1 addition & 0 deletions src/plugins/template/backend/opset_int_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ _OPENVINO_OP_REG(SearchSorted, ov::op::v15)

_OPENVINO_OP_REG(Identity, ov::op::v16)
_OPENVINO_OP_REG(ISTFT, ov::op::v16)
_OPENVINO_OP_REG(SegmentMax, ov::op::v16)

_OPENVINO_OP_REG(AUGRUCell, ov::op::internal)
_OPENVINO_OP_REG(AUGRUSequence, ov::op::internal)
Expand Down
Loading

0 comments on commit 18e97b6

Please sign in to comment.