Skip to content

Commit

Permalink
[cker] update forwarding function
Browse files Browse the repository at this point in the history
using existing forwarding function for averagePool2D.

ONE-DCO-1.0-Signed-off-by: JuYoung Lee [email protected]
  • Loading branch information
icodo98 committed Sep 28, 2024
1 parent bca4c85 commit 9536d4e
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ namespace cker
namespace train
{

inline void AvgPool2DGrad(const PoolParams &params, const Shape &incoming_shape,
const float *incoming_data, const Shape &grad_shape, float *grad_data)
inline void AveragePool2DGrad(const PoolParams &params, const Shape &incoming_shape,
const float *incoming_data, const Shape &grad_shape, float *grad_data)
{
assert(grad_shape.DimensionsCount() == 4);
assert(incoming_shape.DimensionsCount() == 4);
Expand Down
58 changes: 13 additions & 45 deletions compute/cker/src/train/AvgPool.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#include <cker/eigen/Utils.h>
#include <cker/operation/AveragePool.h>
#include <cker/train/operation/AvgPool.h>
#include <cker/train/operation/AveragePool.h>
#include <cker/Shape.h>

#include <gtest/gtest.h>
Expand Down Expand Up @@ -48,8 +48,8 @@ template <typename T> class AvgPoolOpVerifier
assert(expected_output.size() == _out_shape.FlatSize());

std::vector<T> cacluated_output(_out_shape.FlatSize());
nnfw::cker::AveragePool(_op_params, _in_shape, input.data(), _out_shape,
cacluated_output.data());
nnfw::cker::AveragePool<float>(_op_params, _in_shape, input.data(), _out_shape,
cacluated_output.data());

if (expect_eq)
EXPECT_EQ(expected_output, cacluated_output);
Expand All @@ -64,8 +64,8 @@ template <typename T> class AvgPoolOpVerifier
assert(expected_grad_data.size() == _in_shape.FlatSize());

std::vector<T> calcuated_grad(_in_shape.FlatSize());
nnfw::cker::train::AvgPool2DGrad(_op_params, _out_shape, incoming_data.data(), _in_shape,
calcuated_grad.data());
nnfw::cker::train::AveragePool2DGrad(_op_params, _out_shape, incoming_data.data(), _in_shape,
calcuated_grad.data());

if (expect_eq)
{
Expand Down Expand Up @@ -94,6 +94,8 @@ TEST(CKer_Operation, AvgPool2D)
op_param.filter_width = 2;
op_param.padding_values.height = 0;
op_param.padding_values.width = 0;
op_param.float_activation_max = std::numeric_limits<float>::max();
op_param.float_activation_min = std::numeric_limits<float>::lowest();
}
nnfw::cker::Shape in = {1, 3, 3, 1};
nnfw::cker::Shape out = {1, 2, 2, 1};
Expand Down Expand Up @@ -136,6 +138,8 @@ TEST(CKer_Operation, AvgPool2D)
op_param.filter_width = 3;
op_param.padding_values.height = 0;
op_param.padding_values.width = 0;
op_param.float_activation_max = std::numeric_limits<float>::max();
op_param.float_activation_min = std::numeric_limits<float>::lowest();
}
nnfw::cker::Shape in = {1, 3, 3, 2};
nnfw::cker::Shape out = {1, 1, 1, 2};
Expand Down Expand Up @@ -189,46 +193,6 @@ TEST(CKer_Operation, AvgPool2D)
/* depth1 */ 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04;
verifier.verifyBackward(output_deriv, expected_input_deriv);
}

// with padding case
{
nnfw::cker::PoolParams op_param;
{
op_param.stride_height = 2;
op_param.stride_width = 2;
op_param.filter_height = 2;
op_param.filter_width = 2;
op_param.padding_values.height = 2;
op_param.padding_values.width = 2;
}
nnfw::cker::Shape in = {1, 2, 2, 1};
nnfw::cker::Shape out = {1, 3, 3, 1};

AvgPoolOpVerifier<float> verifier(op_param, in, out);

/**
* input_with_padding: expected_output:
*
* 4 8 0 0 0
* 9 2 -(forward)-> 0 5.75 0
* 0 0 0
*/

std::vector<float> input = {4, 8, 9, 2};
std::vector<float> expected_output = {0, 0, 0, 0, 5.75, 0, 0, 0, 0};
verifier.verifyForward(input, expected_output);

/**
* output_deriv: input_deriv:
*
* 0.1 0.1 0.1 0.1 0.1
* 0.1 0.4 0.3 -(backward)-> 0.1 0.1
* 0.5 0.1 0.1
*/
std::vector<float> output_deriv = {0.1, 0.1, 0.1, 0.1, 0.4, 0.3, 0.5, 0.1, 0.1};
std::vector<float> expected_input_deriv = {0.1, 0.1, 0.1, 0.1};
verifier.verifyBackward(output_deriv, expected_input_deriv);
}
}

TEST(CKer_Operation, neg_AvgPoolInvalidExpectedValue)
Expand All @@ -243,6 +207,8 @@ TEST(CKer_Operation, neg_AvgPoolInvalidExpectedValue)
op_param.filter_width = 2;
op_param.padding_values.height = 0;
op_param.padding_values.width = 0;
op_param.float_activation_max = std::numeric_limits<float>::max();
op_param.float_activation_min = std::numeric_limits<float>::lowest();
}
nnfw::cker::Shape in = {1, 2, 2, 1};
nnfw::cker::Shape out = {1, 1, 1, 1};
Expand All @@ -265,6 +231,8 @@ TEST(CKer_Operation, neg_AvgPoolInvalidExpectedValue)
op_param.filter_width = 2;
op_param.padding_values.height = 1;
op_param.padding_values.width = 1;
op_param.float_activation_max = std::numeric_limits<float>::max();
op_param.float_activation_min = std::numeric_limits<float>::lowest();
}

nnfw::cker::Shape in = {1, 2, 2, 1};
Expand Down
31 changes: 16 additions & 15 deletions runtime/onert/backend/train/ops/PoolLayer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
#include "../Tensor.h"

#include <cker/Utils.h>
#include <cker/operation/AveragePool.h>
#include <cker/train/operation/MaxPool.h>
#include <cker/train/operation/AvgPool.h>
#include <cker/train/operation/AveragePool.h>
#include <cker/train/operation/ReLU.h>

namespace onert
Expand Down Expand Up @@ -109,7 +110,7 @@ class MaxPool2D final : public TrainingKernelRegistry
}
};

class AvgPool2D final : public TrainingKernelRegistry
class AveragePool2D final : public TrainingKernelRegistry
{
private:
const ir::Activation _activation;
Expand All @@ -120,10 +121,10 @@ class AvgPool2D final : public TrainingKernelRegistry
std::unique_ptr<Tensor> _arg_avg_index;

public:
AvgPool2D(const uint32_t paddingLeft, const uint32_t, const uint32_t paddingTop, const uint32_t,
const uint32_t strideWidth, const uint32_t strideHeight, const uint32_t kernelWidth,
const uint32_t kernelHeight, const ir::Activation activation,
const IPortableTensor *output)
AveragePool2D(const uint32_t paddingLeft, const uint32_t, const uint32_t paddingTop,
const uint32_t, const uint32_t strideWidth, const uint32_t strideHeight,
const uint32_t kernelWidth, const uint32_t kernelHeight,
const ir::Activation activation, const IPortableTensor *output)
: _activation(activation), _output(output)
{
{
Expand All @@ -144,7 +145,7 @@ class AvgPool2D final : public TrainingKernelRegistry
}
};

~AvgPool2D() {}
~AveragePool2D() {}

public:
void forward(const IPortableTensor *in, IPortableTensor *out)
Expand All @@ -153,8 +154,8 @@ class AvgPool2D final : public TrainingKernelRegistry
auto out_data = getBuffer<float>(out);

// avgpool forward
nnfw::cker::train::AvgPool2D(_op_params, getShape(in), getBuffer<float>(in), out_shape,
out_data);
nnfw::cker::AveragePool<float>(_op_params, getShape(in), getBuffer<float>(in), out_shape,
out_data);
}

void backward(const IPortableTensor *back_prop_out, IPortableTensor *back_prop_in)
Expand All @@ -172,9 +173,9 @@ class AvgPool2D final : public TrainingKernelRegistry
assert(back_prop_out != nullptr);

// averagepool baackward
nnfw::cker::train::AvgPool2DGrad(_op_params, getShape(back_prop_out),
getBuffer<float>(back_prop_out), getShape(back_prop_in),
getBuffer<float>(back_prop_in));
nnfw::cker::train::AveragePool2DGrad(_op_params, getShape(back_prop_out),
getBuffer<float>(back_prop_out), getShape(back_prop_in),
getBuffer<float>(back_prop_in));
}
};

Expand Down Expand Up @@ -211,9 +212,9 @@ void PoolLayer::configureBackward(const uint32_t paddingLeft, const uint32_t pad
activation, output);
break;
case PoolType::kAvg:
_kernel = std::make_unique<AvgPool2D>(paddingLeft, paddingRight, paddingTop, paddingBottom,
strideWidth, strideHeight, kernelWidth, kernelHeight,
activation, output);
_kernel = std::make_unique<AveragePool2D>(paddingLeft, paddingRight, paddingTop,
paddingBottom, strideWidth, strideHeight,
kernelWidth, kernelHeight, activation, output);
break;
default:
throw std::runtime_error("PoolLayer: Unsupported pool type");
Expand Down

0 comments on commit 9536d4e

Please sign in to comment.