// test_interpreter.cpp — forked from pytorch/pytorch (31 lines, 25 loc, ~1 KB).
#include "test/cpp/jit/test_base.h"
#include "test/cpp/jit/test_utils.h"
namespace torch {
namespace jit {
void testInterp() {
constexpr int batch_size = 4;
constexpr int input_size = 256;
constexpr int seq_len = 32;
int hidden_size = 2 * input_size;
auto input = at::randn({seq_len, batch_size, input_size}, at::kCUDA);
auto hx = at::randn({batch_size, hidden_size}, at::kCUDA);
auto cx = at::randn({batch_size, hidden_size}, at::kCUDA);
auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA));
auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA));
auto lstm_g = build_lstm();
Code lstm_function(lstm_g, "");
InterpreterState lstm_interp(lstm_function);
auto outputs = run(lstm_interp, {input[0], hx, cx, w_ih, w_hh});
std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
// std::cout << almostEqual(outputs[0],hx) << "\n";
ASSERT_TRUE(exactlyEqual(outputs[0], hx));
ASSERT_TRUE(exactlyEqual(outputs[1], cx));
}
} // namespace jit
} // namespace torch