diff --git a/src/frontends/pytorch/src/transforms/fftn_complex_replacer.cpp b/src/frontends/pytorch/src/transforms/fftn_complex_replacer.cpp index d5fc5ce8f0f662..1ba4989e7a2d9a 100644 --- a/src/frontends/pytorch/src/transforms/fftn_complex_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/fftn_complex_replacer.cpp @@ -11,12 +11,13 @@ #include "openvino/op/equal.hpp" #include "openvino/op/gather.hpp" #include "openvino/op/range.hpp" -#include "openvino/op/rdft.hpp" +#include "openvino/op/dft.hpp" #include "openvino/op/reduce_prod.hpp" #include "openvino/op/reshape.hpp" #include "openvino/op/select.hpp" #include "openvino/op/shape_of.hpp" #include "openvino/op/slice.hpp" +#include "openvino/op/split.hpp" #include "openvino/op/sqrt.hpp" #include "openvino/op/squeeze.hpp" #include "openvino/op/subtract.hpp" @@ -102,22 +103,22 @@ FFTNComplexReplacer::FFTNComplexReplacer() { return false; } - auto rdft = std::make_shared(input, dim, s); + auto dft = std::make_shared(input, dim, s); // Apply normalizations auto n_int = std::make_shared(s, const_0); - auto n = std::make_shared(n_int, rdft); + auto n = std::make_shared(n_int, dft); std::shared_ptr normalized_fftn; if (norm == "forward") { // Normalize by 1/n - normalized_fftn = std::make_shared(rdft, n); + normalized_fftn = std::make_shared(dft, n); } else if (norm == "backward") { // No normalization - normalized_fftn = rdft; + normalized_fftn = dft; } else if (norm == "ortho") { // Normalize by 1/sqrt(n) auto sqrt_n = std::make_shared(n); - normalized_fftn = std::make_shared(rdft, sqrt_n); + normalized_fftn = std::make_shared(dft, sqrt_n); } else { add_exception_to_fw_node( fftn_op, @@ -127,21 +128,12 @@ FFTNComplexReplacer::FFTNComplexReplacer() { // Replace outputs that are either torch operators aten::real or aten::imag. Apply squeeze to remove last // dimension used to concatenate. + auto normalized_rfftn_splitted = std::make_shared(normalized_fftn, const_neg_1, 2); auto fftn_outs = fftn_op->get_users(); bool rval = false; for (auto& out : fftn_outs) { if (auto real_op = cast_fw_node(out, "aten::real")) { - // Check if the last dimension of the tensor is 1 - auto last_dim = normalized_fftn->output(0).get_shape().back(); - if (last_dim != 1) { - // Reshape the tensor to make the last dimension 1 - auto new_shape = normalized_fftn->output(0).get_shape(); - new_shape.back() = 1; - auto new_shape_tensor = std::make_shared(element::i64, Shape{new_shape.size()}, new_shape); - auto reshaped = std::make_shared(normalized_fftn->output(0), new_shape_tensor, false); - normalized_fftn = reshaped; - } - auto squeezed = std::make_shared(normalized_fftn->output(0), const_neg_1); + auto squeezed = std::make_shared(normalized_rfftn_splitted->output(0), const_neg_1); copy_runtime_info({fftn_op, real_op}, squeezed); squeezed->set_friendly_name(real_op->get_friendly_name()); replace_node(real_op, squeezed);