Skip to content

Commit

Permalink
Fix issue with prim::Print() and torch::deploy (pytorch#74513)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#74513

Reviewed By: d4l3k, houseroad

Differential Revision: D35035089

fbshipit-source-id: d67b98600c74e2ed16b4d80f52148cd64b9e6ca0
(cherry picked from commit 16caf86)
  • Loading branch information
khabinov authored and pytorchmergebot committed Mar 25, 2022
1 parent b347b8c commit 5079321
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 4 deletions.
36 changes: 36 additions & 0 deletions torch/csrc/deploy/test_deploy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,42 @@ TEST(TorchpyTest, TestPyYAML) {
}
#endif

TEST(TorchpyTest, PrintInstruction) {
const auto jit_script_with_print = R"JIT(
def forward(self, a):
print(a)
return a + a
)JIT";

auto input = torch::autograd::make_variable(at::randn({2, 3}));
auto expected_forward = input + input;

auto module = std::make_shared<torch::jit::Module>(
"Module", std::make_shared<at::CompilationUnit>());
module->define(jit_script_with_print);

std::vector<at::IValue> inputs{at::IValue(input)};

// Checking that a module containing prim::Print() works fine.
auto result1 = (*module)(inputs);
EXPECT_TRUE(result1.toTensor().equal(expected_forward));

{
auto interpreterManager =
std::make_shared<torch::deploy::InterpreterManager>(1);

// Checking that a module containing prim::Print() still works fine
// after Python environment was created.
auto result2 = (*module)(inputs);
EXPECT_TRUE(result2.toTensor().equal(expected_forward));
}

// Checking that a module containing prim::Print() still works fine
// after Python environment was created and then destroyed.
auto result3 = (*module)(inputs);
EXPECT_TRUE(result3.toTensor().equal(expected_forward));
}

int main(int argc, char* argv[]) {
::testing::InitGoogleTest(&argc, argv);
int rc = RUN_ALL_TESTS();
Expand Down
6 changes: 6 additions & 0 deletions torch/csrc/jit/python/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1576,6 +1576,12 @@ void initJITBindings(PyObject* module) {
throw std::runtime_error(e.what());
}
});

// On exit we need to reset the print handler to default one,
// because otherwise prim::Print() instruction won't work for JIT modules.
auto atexit = py::module_::import("atexit");
atexit.attr("register")(
py::cpp_function([]() { setPrintHandler(getDefaultPrintHandler()); }));
}
} // namespace jit
} // namespace torch
12 changes: 9 additions & 3 deletions torch/csrc/jit/runtime/print_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,15 @@
namespace torch {
namespace jit {

std::atomic<PrintHandler> print_handler([](const std::string& str) {
std::cout << str;
});
namespace {

std::atomic<PrintHandler> print_handler(getDefaultPrintHandler());

} // namespace

PrintHandler getDefaultPrintHandler() {
return [](const std::string& s) { std::cout << s; };
}

PrintHandler getPrintHandler() {
return print_handler.load();
Expand Down
3 changes: 2 additions & 1 deletion torch/csrc/jit/runtime/print_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ namespace jit {

using PrintHandler = void (*)(const std::string&);

TORCH_API void setPrintHandler(PrintHandler ph);
TORCH_API PrintHandler getDefaultPrintHandler();
TORCH_API PrintHandler getPrintHandler();
TORCH_API void setPrintHandler(PrintHandler ph);

} // namespace jit
} // namespace torch

0 comments on commit 5079321

Please sign in to comment.