Skip to content

Commit

Permalink
Update fftn_complex_replacer.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
siddhant-0707 committed Jan 18, 2024
1 parent 75efbda commit 020b06f
Showing 1 changed file with 9 additions and 17 deletions.
26 changes: 9 additions & 17 deletions src/frontends/pytorch/src/transforms/fftn_complex_replacer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
#include "openvino/op/equal.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/rdft.hpp"
#include "openvino/op/dft.hpp"
#include "openvino/op/reduce_prod.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/split.hpp"
#include "openvino/op/sqrt.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/subtract.hpp"
Expand Down Expand Up @@ -102,22 +103,22 @@ FFTNComplexReplacer::FFTNComplexReplacer() {
return false;
}

auto rdft = std::make_shared<v9::RDFT>(input, dim, s);
auto dft = std::make_shared<v7::DFT>(input, dim, s);

// Apply normalizations
auto n_int = std::make_shared<v1::ReduceProd>(s, const_0);
auto n = std::make_shared<v1::ConvertLike>(n_int, rdft);
auto n = std::make_shared<v1::ConvertLike>(n_int, dft);
std::shared_ptr<ov::Node> normalized_fftn;
if (norm == "forward") {
// Normalize by 1/n
normalized_fftn = std::make_shared<v1::Divide>(rdft, n);
normalized_fftn = std::make_shared<v1::Divide>(dft, n);
} else if (norm == "backward") {
// No normalization
normalized_fftn = rdft;
normalized_fftn = dft;
} else if (norm == "ortho") {
// Normalize by 1/sqrt(n)
auto sqrt_n = std::make_shared<v0::Sqrt>(n);
normalized_fftn = std::make_shared<v1::Divide>(rdft, sqrt_n);
normalized_fftn = std::make_shared<v1::Divide>(dft, sqrt_n);
} else {
add_exception_to_fw_node(
fftn_op,
Expand All @@ -127,21 +128,12 @@ FFTNComplexReplacer::FFTNComplexReplacer() {

// Replace outputs that are either torch operators aten::real or aten::imag. Apply squeeze to remove last
// dimension used to concatenate.
auto normalized_rfftn_splitted = std::make_shared<v1::Split>(normalized_fftn, const_neg_1, 2);
auto fftn_outs = fftn_op->get_users();
bool rval = false;
for (auto& out : fftn_outs) {
if (auto real_op = cast_fw_node(out, "aten::real")) {
// Check if the last dimension of the tensor is 1
auto last_dim = normalized_fftn->output(0).get_shape().back();
if (last_dim != 1) {
// Reshape the tensor to make the last dimension 1
auto new_shape = normalized_fftn->output(0).get_shape();
new_shape.back() = 1;
auto new_shape_tensor = std::make_shared<v0::Constant>(element::i64, Shape{new_shape.size()}, new_shape);
auto reshaped = std::make_shared<v1::Reshape>(normalized_fftn->output(0), new_shape_tensor, false);
normalized_fftn = reshaped;
}
auto squeezed = std::make_shared<v0::Squeeze>(normalized_fftn->output(0), const_neg_1);
auto squeezed = std::make_shared<v0::Squeeze>(normalized_rfftn_splitted->output(0), const_neg_1);
copy_runtime_info({fftn_op, real_op}, squeezed);
squeezed->set_friendly_name(real_op->get_friendly_name());
replace_node(real_op, squeezed);
Expand Down

0 comments on commit 020b06f

Please sign in to comment.