[docs] Update examples in the PyTorch+IREE guide (iree-org#18620)
Progresses towards iree-org#18564.

Fixes an example involving `shark_turbine.aot` that produced a segfault
due to a call to `VmModule.wrap_buffer()` when run.

Also removes the `aot.jittable` examples that were present in the doc and
updates the recommended PyTorch version.
vinayakdsci authored Sep 27, 2024
1 parent b5b4ab7 commit 2e382a7
Showing 1 changed file with 4 additions and 88 deletions.
docs/website/docs/guides/ml-frameworks/pytorch.md (92 changes: 4 additions & 88 deletions)
@@ -65,11 +65,12 @@ graph LR
 
 ## :octicons-download-16: Prerequisites
 
-Install a recent version of PyTorch (`2.3.0+`, prerelease as of April 2024):
+Install a recent version of PyTorch
+(`2.4.1`, latest stable release as of September 2024):
 
 ``` shell
 python -m pip install \
-  --pre --index-url https://download.pytorch.org/whl/test/cpu torch==2.3.0
+  --index-url https://download.pytorch.org/whl/test/cpu torch==2.4.1
 ```
 
 Install iree-turbine:
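
Not part of the diff: after running the updated command above, it is worth confirming which build pip actually resolved, since the install still pulls from PyTorch's `test/cpu` index:

``` shell
python -c "import torch; print(torch.__version__)"
```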
@@ -218,7 +219,7 @@ binary = export_output.compile(save_to=None)
 # Use the IREE runtime API to test the compiled program.
 config = ireert.Config("local-task")
 vm_module = ireert.load_vm_module(
-    ireert.VmModule.wrap_buffer(config.vm_instance, binary.map_memory()),
+    ireert.VmModule.copy_buffer(config.vm_instance, binary.map_memory()),
     config,
 )
 input = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
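
A note on why the one-line change above fixes the crash: `VmModule.wrap_buffer()` aliases the caller's mapped memory without copying it, so the loaded module can end up dereferencing freed memory once that mapping is released, while `VmModule.copy_buffer()` snapshots the bytes into memory owned by the VM instance. Below is a minimal end-to-end sketch of the corrected flow; the `Scale` module, its shapes, and the expected output are illustrative assumptions rather than content from the diff, while the export/compile/load calls mirror the example being patched:

```python
import numpy as np
import torch
import shark_turbine.aot as aot
import iree.runtime as ireert

# Illustrative stand-in for the guide's example module (an assumption).
class Scale(torch.nn.Module):
    def forward(self, x):
        return x * 2.0

# Export the module with an example input, then compile it in memory.
export_output = aot.export(Scale(), torch.empty(4, dtype=torch.float32))
binary = export_output.compile(save_to=None)

config = ireert.Config("local-task")
vm_module = ireert.load_vm_module(
    # copy_buffer copies the compiled bytes into VM-owned memory, so the
    # module stays valid even after `binary`'s mapping goes away.
    # wrap_buffer would alias that mapping instead, which is the segfault
    # this commit fixes in the guide.
    ireert.VmModule.copy_buffer(config.vm_instance, binary.map_memory()),
    config,
)

input = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
print(vm_module.main(input).to_host())  # expected: [2. 4. 6. 8.]
```

The only load-time difference from the crashing version is the `copy_buffer()` call; the rest follows the guide's existing pattern.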
@@ -297,59 +298,6 @@ exported function (typically called "forward"), while more complex programs can
 have multiple computation functions, initialization functions, "backward"
 methods for training, state management functions, debugging functions, etc.
 
-* Each instance method on a `aot.CompiledModule`-derived class is exported.
-  These instance methods can include calls to other `aot` components, such as
-  `aot.jittable` compute functions:
-
-    ```python
-    class GetOnesModule(aot.CompiledModule):
-        @aot.jittable
-        def compute_ones():
-            return torch.ones(3)
-
-        def get_ones(self):
-            return self.compute_ones()
-    ```
-
-* Instance methods can use `aot.AbstractTensor` to specify data types:
-
-    ```python hl_lines="8-9"
-    class IntSumModule(aot.CompiledModule):
-        @aot.jittable
-        def compute_sum(a, b):
-            return a + b
-
-        def sum_int32(
-            self,
-            a=aot.AbstractTensor(2, dtype=torch.int32),
-            b=aot.AbstractTensor(2, dtype=torch.int32),
-        ):
-            return self.compute_sum(a, b)
-    ```
-
-* Shapes can be made dynamic using `aot.AbstractTensor` and `aot.jittable`
-  constraints:
-
-    ```python hl_lines="8-9 14-16"
-    class DynamicSumModule(aot.CompiledModule):
-        @aot.jittable
-        def compute_sum(a, b):
-            return a + b
-
-        def sum_dynamic(
-            self,
-            a=aot.AbstractTensor(None),
-            b=aot.AbstractTensor(None),
-        ):
-            return self.compute_sum(
-                a,
-                b,
-                constraints=[
-                    a.dynamic_dim(0) == b.dynamic_dim(0),
-                ],
-            )
-    ```
-
 #### :material-variable: Global variables
 
 _Global variables_ are used to represent persistent state within a program
@@ -374,38 +322,6 @@ their values independently at runtime.
             self.value = new_value
     ```
 
-* All named parameters on a `nn.Module` can be exported using
-  `export_parameters()`:
-
-    ```python hl_lines="12 18-26"
-    class SimpleParams(torch.nn.Module):
-        def __init__(self):
-            super().__init__()
-            self.classifier = torch.nn.Linear(20, 30)
-
-        def forward(self, x):
-            return self.classifier(x)
-
-    m = SimpleParams()
-
-    class SimpleParamsModule(aot.CompiledModule):
-        params = aot.export_parameters(m)
-        compute = aot.jittable(m.forward)
-
-        def run(self, x=aot.AbstractTensor(128, 20)):
-            return self.compute(x)
-
-        # torch.nn.Linear has 'weight' and 'bias' variables:
-        # https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
-        # Add getters for both exported parameters.
-
-        def get_weight(self):
-            return self.params["classifier.weight"]
-
-        def get_bias(self):
-            return self.params["classifier.bias"]
-    ```
-
 #### :octicons-code-16: Samples
 
 | Code samples | |
