From 5079321b715bb1a48c48cf78835639c0784ce2c9 Mon Sep 17 00:00:00 2001 From: Oleg Khabinov Date: Thu, 24 Mar 2022 20:09:32 -0700 Subject: [PATCH] Fix issue with prim::Print() and torch::deploy (#74513) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/74513 Reviewed By: d4l3k, houseroad Differential Revision: D35035089 fbshipit-source-id: d67b98600c74e2ed16b4d80f52148cd64b9e6ca0 (cherry picked from commit 16caf865077e28be31b805f015b9a61962632c8f) --- torch/csrc/deploy/test_deploy.cpp | 36 ++++++++++++++++++++++++ torch/csrc/jit/python/init.cpp | 6 ++++ torch/csrc/jit/runtime/print_handler.cpp | 12 ++++++-- torch/csrc/jit/runtime/print_handler.h | 3 +- 4 files changed, 53 insertions(+), 4 deletions(-) diff --git a/torch/csrc/deploy/test_deploy.cpp b/torch/csrc/deploy/test_deploy.cpp index 08948fbc6c57a..973fbff0fa4f2 100644 --- a/torch/csrc/deploy/test_deploy.cpp +++ b/torch/csrc/deploy/test_deploy.cpp @@ -482,6 +482,42 @@ TEST(TorchpyTest, TestPyYAML) { } #endif +TEST(TorchpyTest, PrintInstruction) { + const auto jit_script_with_print = R"JIT( + def forward(self, a): + print(a) + return a + a + )JIT"; + + auto input = torch::autograd::make_variable(at::randn({2, 3})); + auto expected_forward = input + input; + + auto module = std::make_shared( + "Module", std::make_shared()); + module->define(jit_script_with_print); + + std::vector inputs{at::IValue(input)}; + + // Checking that a module containing prim::Print() works fine. + auto result1 = (*module)(inputs); + EXPECT_TRUE(result1.toTensor().equal(expected_forward)); + + { + auto interpreterManager = + std::make_shared(1); + + // Checking that a module containing prim::Print() still works fine + // after Python environment was created. + auto result2 = (*module)(inputs); + EXPECT_TRUE(result2.toTensor().equal(expected_forward)); + } + + // Checking that a module containing prim::Print() still works fine + // after Python environment was created and then destroyed. + auto result3 = (*module)(inputs); + EXPECT_TRUE(result3.toTensor().equal(expected_forward)); +} + int main(int argc, char* argv[]) { ::testing::InitGoogleTest(&argc, argv); int rc = RUN_ALL_TESTS(); diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 0112ee320e975..9282ec38a322e 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1576,6 +1576,12 @@ void initJITBindings(PyObject* module) { throw std::runtime_error(e.what()); } }); + + // On exit we need to reset the print handler to default one, + // because otherwise prim::Print() instruction won't work for JIT modules. + auto atexit = py::module_::import("atexit"); + atexit.attr("register")( + py::cpp_function([]() { setPrintHandler(getDefaultPrintHandler()); })); } } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/runtime/print_handler.cpp b/torch/csrc/jit/runtime/print_handler.cpp index e3e8585bca56e..9452589f9e390 100644 --- a/torch/csrc/jit/runtime/print_handler.cpp +++ b/torch/csrc/jit/runtime/print_handler.cpp @@ -6,9 +6,15 @@ namespace torch { namespace jit { -std::atomic print_handler([](const std::string& str) { - std::cout << str; -}); +namespace { + +std::atomic print_handler(getDefaultPrintHandler()); + +} // namespace + +PrintHandler getDefaultPrintHandler() { + return [](const std::string& s) { std::cout << s; }; +} PrintHandler getPrintHandler() { return print_handler.load(); diff --git a/torch/csrc/jit/runtime/print_handler.h b/torch/csrc/jit/runtime/print_handler.h index d9ba851fa1abd..2f1f3ee92e069 100644 --- a/torch/csrc/jit/runtime/print_handler.h +++ b/torch/csrc/jit/runtime/print_handler.h @@ -11,8 +11,9 @@ namespace jit { using PrintHandler = void (*)(const std::string&); -TORCH_API void setPrintHandler(PrintHandler ph); +TORCH_API PrintHandler getDefaultPrintHandler(); TORCH_API PrintHandler getPrintHandler(); +TORCH_API void setPrintHandler(PrintHandler ph); } // namespace jit } // namespace torch