diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 5dd5ad57a444c..9e77eae0275a1 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -2934,6 +2934,20 @@ def forward(self, x, y): ) self.check_model(m, args) + def test_custom_op_add_output_path(self) -> None: + class M(torch.nn.Module): + def forward(self, x, y): + return torch.ops.aoti_custom_ops.custom_add(x, y) + + m = M().to(device=self.device) + args = ( + torch.randn(3, 3, device=self.device), + torch.randn(3, 3, device=self.device), + ) + with config.patch("aot_inductor.output_path", "model.so"): + with self.assertRaises(Exception): + self.check_model(m, args) + def test_custom_op_all_inputs(self) -> None: class MyModel(torch.nn.Module): # pyre-fixme[3]: Return type must be annotated. diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner.cpp b/torch/csrc/inductor/aoti_runner/model_container_runner.cpp index ac3422843a93c..af94e61d1bbe9 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner.cpp +++ b/torch/csrc/inductor/aoti_runner/model_container_runner.cpp @@ -74,6 +74,8 @@ AOTIModelContainerRunner::AOTIModelContainerRunner( json_filename, device_str == "cpu"); proxy_executor_handle_ = reinterpret_cast(proxy_executor_.get()); + } else { + proxy_executor_handle_ = nullptr; } AOTI_RUNTIME_ERROR_CODE_CHECK(create_func_( diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 474ea60594dfa..9e2494c818b39 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -1136,6 +1136,13 @@ AOTITorchError aoti_torch_proxy_executor_call_function( int num_tensors, AtenTensorHandle* flatten_tensor_args) { AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + if (!proxy_executor) { + throw std::runtime_error( + "Unable to find a proxy executor to run custom ops. Please check if " + "there is a json file generated in the same directory as the so, or use " + "torch._inductor.aoti_compile_and_package to package everything into a " + "PT2 artifact."); + } ProxyExecutor* executor = reinterpret_cast(proxy_executor); executor->call_function( extern_node_index,