Skip to content

Commit

Permalink
[JIT] python IR bindings: consolidate tests, add short docs in OVERVI…
Browse files Browse the repository at this point in the history
…EW.md (pytorch#118319)

Document the existence of python IR bindings; quick comments about it; and consolidate tests in one file to serve as examples to users.
Pull Request resolved: pytorch#118319
Approved by: https://github.com/eellison
  • Loading branch information
davidberard98 authored and pytorchmergebot committed Jan 27, 2024
1 parent 9bce208 commit 40c0879
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 68 deletions.
72 changes: 72 additions & 0 deletions test/jit/test_python_ir.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
# Owner(s): ["oncall: jit"]

import torch
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import IS_MACOS

import numpy as np
import unittest

if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
Expand All @@ -18,3 +23,70 @@ def trace_me(arg):
real_strides = list(t.stride())
type_strides = value.type().strides()
self.assertEqual(real_strides, type_strides)

def test_permute_inputs_binding(self):
@torch.jit.script
def foo(i, j, k):
pass

g = foo.graph

idxs = []
for i, inp in enumerate(g.inputs()):
inp.setDebugName(f"inp{i}")
idxs.append(i)

permuted_idxs = list(np.random.permutation(idxs))
g.permuteInputs(permuted_idxs)
for i, inp in enumerate(g.inputs()):
self.assertEqual(f"inp{permuted_idxs[i]}", inp.debugName())

@unittest.skipIf(IS_MACOS, "Failing on MacOS only")
def test_python_ir_utils(self):
@torch.jit.script
def foo(inp):
x = inp + 1
y = x / 2
z = y * y
return z

add_node = foo.graph.findNode("aten::add")
div_node = foo.graph.findNode("aten::div")

with foo.graph.insert_point_guard(add_node):
with foo.graph.insert_point_guard(div_node):
foo.graph.insertConstant("goodbye")
foo.graph.insertConstant("hello")
with foo.graph.insert_point_guard(foo.graph.findNode("aten::mul")):
foo.graph.insertConstant("hello")
FileCheck().check("hello").check("goodbye").check("hello").run(foo.graph)

self.assertTrue(add_node.matches(add_node.schema()))
self.assertFalse(add_node.matches(div_node.schema()))

def test_python_ir_utils_graph(self):
@torch.jit.script
def unrolled_mul(x: torch.Tensor, y: int):
out = x
for _ in range(y - 1):
out = out + x
return out

@torch.jit.script
def foo(x):
return x * 4

g = foo.graph
muls = g.findAllNodes("aten::mul")
scalar_muls = filter(lambda x: x.matches("aten::mul(Tensor self, Scalar other) -> Tensor"), muls)
mul_constant_int = filter(lambda x: isinstance(list(x.inputs())[1].toIValue(), int), scalar_muls)
for mul in mul_constant_int:
with g.insert_point_guard(mul):
outputs = g.insertGraph(unrolled_mul.graph, list(mul.inputs()))
assert len(outputs) == len(list(mul.outputs()))
for new_out, old_out in zip(outputs, g.outputs()):
old_out.replaceAllUsesWith(new_out)
mul.destroy()

FileCheck().check_not("aten::mul").check("aten::add").run(foo.graph)
self.assertEqual(foo(torch.ones([2, 2])), torch.ones([2, 2]) * 4)
69 changes: 1 addition & 68 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
suppress_warnings, BUILD_WITH_CAFFE2, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, TestCase, \
freeze_rng_state, slowTest, TemporaryFileName, \
enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \
skipIfCrossRef, IS_MACOS, skipIfTorchDynamo
skipIfCrossRef, skipIfTorchDynamo
from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, disable_autodiff_subgraph_inlining, \
_trace, do_input_map, get_execution_plan, make_global, \
execWrapper, _inline_everything, _tmp_donotuse_dont_inline_everything, \
Expand Down Expand Up @@ -1763,73 +1763,6 @@ def doit(x, y):
for node in g.nodes():
self.assertTrue(g2.findNode(node.kind()) is not None)

def test_permute_inputs_binding(self):
@torch.jit.script
def foo(i, j, k):
pass

g = foo.graph

idxs = []
for i, inp in enumerate(g.inputs()):
inp.setDebugName(f"inp{i}")
idxs.append(i)

permuted_idxs = list(np.random.permutation(idxs))
g.permuteInputs(permuted_idxs)
for i, inp in enumerate(g.inputs()):
self.assertEqual(f"inp{permuted_idxs[i]}", inp.debugName())

@unittest.skipIf(IS_MACOS, "Failing on MacOS only")
def test_python_ir_utils(self):
@torch.jit.script
def foo(inp):
x = inp + 1
y = x / 2
z = y * y
return z

add_node = foo.graph.findNode("aten::add")
div_node = foo.graph.findNode("aten::div")

with foo.graph.insert_point_guard(add_node):
with foo.graph.insert_point_guard(div_node):
foo.graph.insertConstant("goodbye")
foo.graph.insertConstant("hello")
with foo.graph.insert_point_guard(foo.graph.findNode("aten::mul")):
foo.graph.insertConstant("hello")
FileCheck().check("hello").check("goodbye").check("hello").run(foo.graph)

self.assertTrue(add_node.matches(add_node.schema()))
self.assertFalse(add_node.matches(div_node.schema()))

def test_python_ir_utils_graph(self):
@torch.jit.script
def unrolled_mul(x: torch.Tensor, y: int):
out = x
for _ in range(y - 1):
out = out + x
return out

@torch.jit.script
def foo(x):
return x * 4

g = foo.graph
muls = g.findAllNodes("aten::mul")
scalar_muls = filter(lambda x: x.matches("aten::mul(Tensor self, Scalar other) -> Tensor"), muls)
mul_constant_int = filter(lambda x: isinstance(list(x.inputs())[1].toIValue(), int), scalar_muls)
for mul in mul_constant_int:
with g.insert_point_guard(mul):
outputs = g.insertGraph(unrolled_mul.graph, list(mul.inputs()))
assert len(outputs) == len(list(mul.outputs()))
for new_out, old_out in zip(outputs, g.outputs()):
old_out.replaceAllUsesWith(new_out)
mul.destroy()

FileCheck().check_not("aten::mul").check("aten::add").run(foo.graph)
self.assertEqual(foo(torch.ones([2, 2])), torch.ones([2, 2]) * 4)

@unittest.skipIf(IS_SANDCASTLE, "gtest runs these in sandcastle")
@unittest.skipIf(RUN_CUDA, "covered by test_cpp_cuda")
@unittest.skipIf(not torch._C._jit_has_cpp_tests(), "Tests were not built, use BUILD_TEST=1")
Expand Down
11 changes: 11 additions & 0 deletions torch/csrc/jit/OVERVIEW.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ Sections start with a reference to the source file where the code related to the
- [Testing Autodiff](#testing-autodiff)
- [Python Printer](#python-printer)
- [Python Bindings](#python-bindings)
- [Graph Manipulation](#graph-manipulation)

<!-- tocstop -->

Expand Down Expand Up @@ -1522,3 +1523,13 @@ def forward(self,
# Python Bindings

TODO: Script Module, torch.jit.trace, __constant__ handling, weak script modules

## Graph Manipulation

Python bindings for manipulating TorchScript IR exists in [python_ir.cpp](https://github.com/pytorch/pytorch/blob/58e7ec5843e63ee044e0a4f5aa2583a056a64078/torch/csrc/jit/python/python_ir.cpp#L4). In general, graph structures should look the same as the representation described above in [Core Program Representation](#core-program-representation).

Things to watch out for:
* You may need to first inline your graph (`torch._C._jit_pass_inline`) or recursively traverse CallFunction nodes (`for x in graph.findAllNodes("prim::CallFunction")`) if you want to recursively modify your graph and the functions it calls
* To insert a graph after node n, use the context manager `with graph.insert_point_guard(new_node)`

See more examples in [test_python_ir.py](https://github.com/pytorch/pytorch/blob/main/test/jit/test_python_ir.py)
3 changes: 3 additions & 0 deletions torch/testing/_internal/dynamo_test_failures.py
Original file line number Diff line number Diff line change
Expand Up @@ -7454,6 +7454,9 @@
"TestSymbolicShapeAnalysis.test_if_propagation", # test_jit
"TestPeephole.test_normalized_rsub", # test_jit
"TestPythonIr.test_param_strides", # test_jit
"TestPythonIr.test_permute_inputs_binding", # test_jit
"TestPythonIr.test_python_ir_utils", # test_jit
"TestPythonIr.test_python_ir_utils_graph", # test_jit
"TestComplex.test_complex_list_sum", # test_jit
"TestUnion.test_union_redundant_arguments_are_skipped_optional", # test_jit
"TestNnapiBackend.test_conv2d", # test_jit
Expand Down

0 comments on commit 40c0879

Please sign in to comment.