Skip to content
This repository has been archived by the owner on Dec 12, 2024. It is now read-only.

Commit

Permalink
Merge branch 'main' into workflows
Browse files Browse the repository at this point in the history
  • Loading branch information
marbre authored Apr 4, 2024
2 parents 0218cfe + f896d3f commit 0fade03
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
18 changes: 9 additions & 9 deletions reference-implementation/include/emitc/tosa.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ inline Src minimum(Src x, Src y) {
}

template <typename Src, IsTensorOfType<Src, int32_t> = true>
inline Src mul(Src x, Src y, const int32_t shift) {
inline Src mul(Src x, Src y, const int8_t shift) {
// Adopted from
// https://git.mlplatform.org/tosa/reference_model.git/tree/reference_model/src/ops/ewise_binary.cc?id=df8626976df6c779bb30df9c5ceef689462109c0#n436
if (shift > 0) {
Expand Down Expand Up @@ -646,7 +646,7 @@ namespace {
// Common reduce function used by specialized TOSA reduce ops.
template <typename Dest, typename Src, typename Computation>
inline Dest reduce(Src operand, typename get_element_type<Src>::type initValue,
int64_t dimension, Computation computation) {
int32_t dimension, Computation computation) {
static_assert(is_tensor<Src>::value, "Expected tensor argument");
static_assert(is_tensor<Dest>::value, "Expected tensor result");

Expand Down Expand Up @@ -688,7 +688,7 @@ inline Dest reduce(Src operand, typename get_element_type<Src>::type initValue,

// ArgMaxOp
template <typename Dest, typename Src>
inline Dest argmax(Src operand, int64_t dimension) {
inline Dest argmax(Src operand, int32_t dimension) {
static_assert(is_tensor<Src>::value, "Expected tensor argument");
static_assert(is_tensor<Dest>::value, "Expected tensor result");

Expand Down Expand Up @@ -732,7 +732,7 @@ inline Dest argmax(Src operand, int64_t dimension) {

// ReduceAllOp
template <typename Dest, typename Src>
inline Dest reduce_all(Src input, int64_t dimension) {
inline Dest reduce_all(Src input, int32_t dimension) {
// ReduceAllOp takes only tensors with datatype bool according to the
// TOSA specifications.
using ET_Src = typename get_element_type<Src>::type;
Expand All @@ -750,7 +750,7 @@ inline Dest reduce_all(Src input, int64_t dimension) {

// ReduceAnyOp
template <typename Dest, typename Src>
inline Dest reduce_any(Src input, int64_t dimension) {
inline Dest reduce_any(Src input, int32_t dimension) {
// ReduceAnyOp takes only tensors with datatype bool according to the
// TOSA specifications.
using ET_Src = typename get_element_type<Src>::type;
Expand All @@ -768,7 +768,7 @@ inline Dest reduce_any(Src input, int64_t dimension) {

// ReduceMaxOp
template <typename Dest, typename Src>
inline Dest reduce_max(Src input, int64_t dimension) {
inline Dest reduce_max(Src input, int32_t dimension) {
using ET_Src = typename get_element_type<Src>::type;

auto f =
Expand All @@ -780,7 +780,7 @@ inline Dest reduce_max(Src input, int64_t dimension) {

// ReduceMinOp
template <typename Dest, typename Src>
inline Dest reduce_min(Src input, int64_t dimension) {
inline Dest reduce_min(Src input, int32_t dimension) {
using ET_Src = typename get_element_type<Src>::type;

auto f =
Expand All @@ -792,7 +792,7 @@ inline Dest reduce_min(Src input, int64_t dimension) {

// ReduceProdOp
template <typename Dest, typename Src>
inline Dest reduce_prod(Src input, int64_t dimension) {
inline Dest reduce_prod(Src input, int32_t dimension) {
using ET_Src = typename get_element_type<Src>::type;

return tosa::reduce<Dest, Src>(input, 1, dimension,
Expand All @@ -801,7 +801,7 @@ inline Dest reduce_prod(Src input, int64_t dimension) {

// ReduceSumOp
template <typename Dest, typename Src>
inline Dest reduce_sum(Src input, int64_t dimension) {
inline Dest reduce_sum(Src input, int32_t dimension) {
using ET_Src = typename get_element_type<Src>::type;

return tosa::reduce<Dest, Src>(input, 0, dimension, std::plus<ET_Src>{});
Expand Down
12 changes: 6 additions & 6 deletions reference-implementation/unittests/tosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ TEST(tosa, mul) {
Tensor2D<int32_t, 2, 2> t2{1, 2, 3, 4};

auto lambda_1d_int_shift = [&s2, &t2]() -> Tensor2D<int32_t, 2, 2> {
int32_t shift{2};
int8_t shift{2};
return tosa::mul(s2, t2, shift);
};

Expand Down Expand Up @@ -1136,7 +1136,7 @@ TEST(tosa, reduce_prod) {
TEST(tosa, reduce_sum) {
{
Tensor<int32_t, 2, 3> input{1, 2, 3, 4, 5, 6};
int64_t dimension = 0;
int32_t dimension = 0;
Tensor<int32_t, 3> expected_result{5, 7, 9};
Tensor<int32_t, 3> result =
tosa::reduce_sum<Tensor<int32_t, 3>>(input, dimension);
Expand All @@ -1145,7 +1145,7 @@ TEST(tosa, reduce_sum) {
}
{
Tensor<int32_t, 2, 3> input{1, 2, 3, 4, 5, 6};
int64_t dimension = 1;
int32_t dimension = 1;
Tensor<int32_t, 2> expected_result{6, 15};
Tensor<int32_t, 2> result =
tosa::reduce_sum<Tensor<int32_t, 2>>(input, dimension);
Expand All @@ -1155,7 +1155,7 @@ TEST(tosa, reduce_sum) {
{
Tensor<int32_t, 4, 2, 3> input{1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6};
int64_t dimension = 0;
int32_t dimension = 0;
Tensor<int32_t, 2, 3> expected_result{4, 8, 12, 16, 20, 24};
Tensor<int32_t, 2, 3> result =
tosa::reduce_sum<Tensor<int32_t, 2, 3>>(input, dimension);
Expand All @@ -1165,7 +1165,7 @@ TEST(tosa, reduce_sum) {
{
Tensor<int32_t, 4, 2, 3> input{1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6};
int64_t dimension = 1;
int32_t dimension = 1;
Tensor<int32_t, 4, 3> expected_result{5, 7, 9, 5, 7, 9, 5, 7, 9, 5, 7, 9};
Tensor<int32_t, 4, 3> result =
tosa::reduce_sum<Tensor<int32_t, 4, 3>>(input, dimension);
Expand All @@ -1175,7 +1175,7 @@ TEST(tosa, reduce_sum) {
{
Tensor<int32_t, 4, 2, 3> input{1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6};
int64_t dimension = 2;
int32_t dimension = 2;
Tensor<int32_t, 4, 2> expected_result{6, 15, 6, 15, 6, 15, 6, 15};
Tensor<int32_t, 4, 2> result =
tosa::reduce_sum<Tensor<int32_t, 4, 2>>(input, dimension);
Expand Down

0 comments on commit 0fade03

Please sign in to comment.