Skip to content

Commit

Permalink
[aoti] Add error msg if we can't find a proxy executor (pytorch#140308)
Browse files Browse the repository at this point in the history
Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#140308
Approved by: https://github.com/desertfire
  • Loading branch information
angelayi authored and pytorchmergebot committed Nov 13, 2024
1 parent c61ccaf commit e754611
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 0 deletions.
14 changes: 14 additions & 0 deletions test/inductor/test_aot_inductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2934,6 +2934,20 @@ def forward(self, x, y):
)
self.check_model(m, args)

def test_custom_op_add_output_path(self) -> None:
class M(torch.nn.Module):
def forward(self, x, y):
return torch.ops.aoti_custom_ops.custom_add(x, y)

m = M().to(device=self.device)
args = (
torch.randn(3, 3, device=self.device),
torch.randn(3, 3, device=self.device),
)
with config.patch("aot_inductor.output_path", "model.so"):
with self.assertRaises(Exception):
self.check_model(m, args)

def test_custom_op_all_inputs(self) -> None:
class MyModel(torch.nn.Module):
# pyre-fixme[3]: Return type must be annotated.
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/inductor/aoti_runner/model_container_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ AOTIModelContainerRunner::AOTIModelContainerRunner(
json_filename, device_str == "cpu");
proxy_executor_handle_ =
reinterpret_cast<AOTIProxyExecutorHandle>(proxy_executor_.get());
} else {
proxy_executor_handle_ = nullptr;
}

AOTI_RUNTIME_ERROR_CODE_CHECK(create_func_(
Expand Down
7 changes: 7 additions & 0 deletions torch/csrc/inductor/aoti_torch/shim_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,13 @@ AOTITorchError aoti_torch_proxy_executor_call_function(
int num_tensors,
AtenTensorHandle* flatten_tensor_args) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
if (!proxy_executor) {
throw std::runtime_error(
"Unable to find a proxy executor to run custom ops. Please check if "
"there is a json file generated in the same directory as the so, or use "
"torch._inductor.aoti_compile_and_package to package everything into a "
"PT2 artifact.");
}
ProxyExecutor* executor = reinterpret_cast<ProxyExecutor*>(proxy_executor);
executor->call_function(
extern_node_index,
Expand Down

0 comments on commit e754611

Please sign in to comment.