Skip to content

Commit

Permalink
[CPU][ARM64] Implemented JIT Emitter for Eltwise Squared Difference O…
Browse files Browse the repository at this point in the history
…peration (#28989)

### Details:
- Implemented and added the jit_squared_difference_emitter derived class for
the element-wise squared difference operation
- Added Algorithm::EltwiseSquaredDifference to executors/aarch64
as one of the supported algorithms
- Added entries to get_supported_precisions and
create_eltwise_emitters in kernel/aarch64
- Added `utils::EltwiseTypes::SQUARED_DIFF` to the `jit` kernel check in the
tests

### Tests:
Passed local tests using `./bin/arm64/Release/ov_cpu_func_tests
--gtest_filter="*smoke*Eltwise*SqDiff*"`
<img width="487" alt="Screenshot 2025-02-14 at 2 04 25 PM"
src="https://github.com/user-attachments/assets/deaec2f6-0dcd-4764-86ec-c543ac4742d3"
/>

### Tickets
- Closes #27502 

CC:
@a-sidorova
  • Loading branch information
srinjoydutta03 authored Feb 14, 2025
1 parent eb44f8d commit 9d92d9c
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -2704,6 +2704,49 @@ std::set<std::vector<element::Type>> jit_sqrt_emitter::get_supported_precisions(
return {{element::f32}};
}

/// SQUARED DIFFERENCE ///
// Node-based constructor: the execution precision handed to the jit_emitter base
// is obtained from the node through get_arithmetic_binary_exec_precision().
jit_squared_difference_emitter::jit_squared_difference_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
                                                               dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
                                                               const std::shared_ptr<ov::Node>& node)
    : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {
    // Intentionally empty: everything is forwarded to the base class.
}

// Precision-based constructor: the caller supplies the execution precision
// directly and it is forwarded unchanged to the jit_emitter base.
jit_squared_difference_emitter::jit_squared_difference_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
                                                               dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
                                                               const ov::element::Type exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {
    // Intentionally empty: everything is forwarded to the base class.
}

// Squared difference is a binary operation, so the emitter consumes two inputs.
size_t jit_squared_difference_emitter::get_inputs_count() const {
    constexpr size_t binary_inputs = 2;
    return binary_inputs;
}

// Dispatches code generation to the ISA-specific implementation.
// Only the ASIMD ISA is handled; any other host ISA is reported as an error.
void jit_squared_difference_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
                                               const std::vector<size_t>& out_vec_idxs) const {
    // Reject unsupported ISAs up front so the happy path stays unindented.
    if (host_isa_ != dnnl::impl::cpu::aarch64::asimd) {
        OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel");
    }

    emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_squared_difference_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs,
const std::vector<size_t>& out_vec_idxs) const {
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string());

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
TReg src0 = TReg(in_vec_idxs[0]);
TReg src1 = TReg(in_vec_idxs[1]);
TReg dst = TReg(out_vec_idxs[0]);

h->fsub(dst.s, src0.s, src1.s);
h->fmul(dst.s, dst.s, dst.s);
}

// Returns the set of supported input precision combinations.
// The only combination is (f32, f32); the `node` argument is not consulted.
std::set<std::vector<element::Type>> jit_squared_difference_emitter::get_supported_precisions(
    const std::shared_ptr<ov::Node>& node) {
    const std::vector<element::Type> f32_pair{element::f32, element::f32};
    return {f32_pair};
}

/// SUBTRACT ///
jit_subtract_emitter::jit_subtract_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,28 @@ class jit_sqrt_emitter : public jit_emitter {
void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
};

/// JIT emitter for the element-wise squared difference operation.
class jit_squared_difference_emitter : public jit_emitter {
public:
    /// Creates the emitter for a graph node.
    jit_squared_difference_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
                                   dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
                                   const std::shared_ptr<ov::Node>& node);

    /// Creates the emitter with an explicit execution precision (f32 when omitted).
    jit_squared_difference_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
                                   dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
                                   const ov::element::Type exec_prc = ov::element::f32);

    /// Input precision combinations this emitter accepts; `node` may be omitted.
    static std::set<std::vector<element::Type>> get_supported_precisions(
        const std::shared_ptr<ov::Node>& node = nullptr);

    /// Number of inputs consumed by the emitter.
    size_t get_inputs_count() const override;

private:
    /// ISA-specific code generation, instantiated per supported `isa`.
    template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;

    /// Entry point called by the base class; dispatches to emit_isa().
    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
};

class jit_subtract_emitter : public jit_emitter {
public:
jit_subtract_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ bool JitEltwiseExecutor::isSupported(const Algorithm& algorithm,
Algorithm::EltwiseSigmoid,
Algorithm::EltwiseSoftSign,
Algorithm::EltwiseSqrt,
Algorithm::EltwiseSquaredDifference,
Algorithm::EltwiseSubtract,
Algorithm::EltwiseSwish,
Algorithm::EltwiseTanh);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,7 @@ std::shared_ptr<jit_emitter> jit_uni_eltwise_generic<isa>::create_eltwise_emitte
OV_CASE(Algorithm::EltwiseSigmoid, ov::intel_cpu::aarch64::jit_sigmoid_emitter),
OV_CASE(Algorithm::EltwiseSoftSign, ov::intel_cpu::aarch64::jit_soft_sign_emitter),
OV_CASE(Algorithm::EltwiseSqrt, ov::intel_cpu::aarch64::jit_sqrt_emitter),
OV_CASE(Algorithm::EltwiseSquaredDifference, ov::intel_cpu::aarch64::jit_squared_difference_emitter),
OV_CASE(Algorithm::EltwiseSubtract, ov::intel_cpu::aarch64::jit_subtract_emitter),
OV_CASE(Algorithm::EltwiseSwish, ov::intel_cpu::aarch64::jit_swish_emitter),
OV_CASE(Algorithm::EltwiseTanh, ov::intel_cpu::aarch64::jit_tanh_emitter));
Expand Down Expand Up @@ -836,6 +837,7 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
OV_CASE(Algorithm::EltwiseSigmoid, jit_sigmoid_emitter),
OV_CASE(Algorithm::EltwiseSoftSign, jit_soft_sign_emitter),
OV_CASE(Algorithm::EltwiseSqrt, jit_sqrt_emitter),
OV_CASE(Algorithm::EltwiseSquaredDifference, jit_squared_difference_emitter),
OV_CASE(Algorithm::EltwiseSubtract, jit_subtract_emitter),
OV_CASE(Algorithm::EltwiseSwish, jit_swish_emitter),
OV_CASE(Algorithm::EltwiseTanh, jit_tanh_emitter));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@ std::string EltwiseLayerCPUTest::getPrimitiveType(const utils::EltwiseTypes& elt
(eltwise_type == utils::EltwiseTypes::SUBTRACT) ||
(eltwise_type == utils::EltwiseTypes::DIVIDE) ||
(eltwise_type == utils::EltwiseTypes::FLOOR_MOD) ||
(eltwise_type == utils::EltwiseTypes::MOD)) {
(eltwise_type == utils::EltwiseTypes::MOD) ||
(eltwise_type == utils::EltwiseTypes::SQUARED_DIFF)) {
return "jit";
}
#endif
Expand Down

0 comments on commit 9d92d9c

Please sign in to comment.