diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index d59fad498..2bf91372f 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -2969,12 +2969,12 @@ def forward_and_backward_from_trace(trace: Trace, torch_autograd=False) -> Forwa Example: >>> import torch - >>> from thunder import compile, last_traces + >>> from thunder import jit, last_traces >>> from thunder.core.transforms import forward_and_backward_from_trace >>> def f(x): ... return torch.sin(x) >>> x = torch.tensor(3.0) - >>> cf = compile(f) + >>> cf = jit(f) >>> out = cf(x) >>> trace = last_traces(cf)[0] >>> forward_and_backward_from_trace(trace)