[torchao float8tensor] #1415 (Draft)

Wants to merge 15 commits into base: crpa/subclass-tensor-ops from crpa/subclass-torchao_float8tensor

Conversation

@crcrpar (Collaborator) commented Nov 8, 2024

Before submitting
  • Was this discussed/approved via a GitHub issue? (not needed for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes # (issue).

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in a GitHub issue, there is a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@crcrpar (Collaborator, Author) commented Nov 8, 2024

As of abf0167, the following failure occurs while rendering the backward trace to Python source:

thunder/core/jit_ext.py:750: in _general_jit_torch_autograd_function_apply_lookaside
    @wraps(bwd_trace_impl.python_callable())
thunder/core/trace.py:497: in python_callable
    python_str = self.python(**kwargs)
thunder/core/trace.py:363: in python
    import_ctx, call_ctx, object_ctx = self._gather_ctxs()
thunder/core/trace.py:322: in _gather_ctxs
    bsym_import_ctx, bsym_call_ctx, bsym_object_ctx = bsym.gather_ctxs()
thunder/core/symbol.py:650: in gather_ctxs
    return self.import_ctx(), self._get_call_ctx(), self.object_ctx()
thunder/core/symbol.py:591: in import_ctx
    self._out_printables, self._arg_printables, self._kwarg_printables  # type: ignore
thunder/core/symbol.py:545: in _arg_printables
    return tuple(
thunder/core/symbol.py:546: in <genexpr>
    codeutils.to_printable(trace, x, import_ctx=self._import_ctx, object_ctx=self._object_ctx)
thunder/core/codeutils.py:144: in to_printable
    flat, spec = tree_flatten(x, namespace="")
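For readers of the traceback: the failing frame builds a wrapper around the backward trace, and calling python_callable() forces the trace to be rendered to Python source, which walks every BoundSymbol argument through codeutils.to_printable and tree_flatten. A rough, hypothetical sketch of that pattern (not the actual lookaside code, just the shape of it):

from functools import wraps

def wrap_backward(bwd_trace_impl):
    # python_callable() renders the trace to Python source and compiles it;
    # rendering calls codeutils.to_printable on every BoundSymbol argument,
    # which is where the frames above end up in tree_flatten.
    bwd_fn = bwd_trace_impl.python_callable()

    @wraps(bwd_fn)
    def backward(*args, **kwargs):
        return bwd_fn(*args, **kwargs)

    return backward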

Commits (each signed off by Masaki Kozuki <[email protected]>):

  • next, function with tensor creation in it
  • revert wrong patch
  • supply unpacks with traces generated within the lookaside

@crcrpar force-pushed the crpa/subclass-torchao_float8tensor branch from abf0167 to e7ca8b7 on November 21, 2024 03:14

@crcrpar (Collaborator, Author) commented Nov 21, 2024

This is a portion of the initial trace as of dd075db. Notably, the imports lack torchao even though the trace constructs the Float8Tensor subclass.

# Constructed by Remove context manager prims
import thunder
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(input, weight):
  # input: "cuda:0 f32[16, 32]"
  # weight: "cuda:0 f32[64, 32]"

  # /home/mkozuki/.pyenv/versions/3.10.13/envs/torchdev-3.10/lib/python3.10/site-packages/torchao/float8/float8_utils.py:113: 	        amax = torch.max(torch.abs(x))
  t3 = ltorch.abs(input)  # t3: "cuda:0 f32[16, 32]"
    # t3 = prims.abs(input)  # t3: "cuda:0 f32[16, 32]"
  amax = ltorch.torch_max(t3, None, False)  # amax: "cuda:0 f32[]"
    # amax = ltorch.amax(t3, [0, 1], False)  # amax: "cuda:0 f32[]"
      # t3 = ltorch.to(t3, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None)  # t3: "cuda:0 f32[16, 32]"
      # amax = prims.amax(t3, (0, 1))  # amax: "cuda:0 f32[]"
      # amax = ltorch.to(amax, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None)  # amax: "cuda:0 f32[]"

  # /home/mkozuki/.pyenv/versions/3.10.13/envs/torchdev-3.10/lib/python3.10/site-packages/torchao/float8/float8_utils.py:48: 	    amax = amax.to(torch.float64)
  t5 = ltorch.to(amax, torch.float64, None, device=None, dtype=None, copy=False, memory_format=None)  # t5: "cuda:0 f64[]"
    # t5 = prims.convert_element_type(amax, dtypes.float64)  # t5: "cuda:0 f64[]"

  # /home/mkozuki/.pyenv/versions/3.10.13/envs/torchdev-3.10/lib/python3.10/site-packages/torchao/float8/float8_utils.py:50: 	        res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
  t10 = ltorch.clamp(t5, 1e-12, None)  # t10: "cuda:0 f64[]"
    # t7 = ltorch.ne(t5, t5)  # t7: "cuda:0 b8[]"
      # t7 = prims.ne(t5, t5)  # t7: "cuda:0 b8[]"
    # t8 = ltorch.gt(t5, 1e-12)  # t8: "cuda:0 b8[]"
      # t8 = prims.gt(t5, 1e-12)  # t8: "cuda:0 b8[]"
    # t9 = ltorch.where(t8, t5, 1e-12)  # t9: "cuda:0 f64[]"
      # t9 = prims.where(t8, t5, 1e-12)  # t9: "cuda:0 f64[]"
    # t10 = ltorch.where(t7, t5, t9)  # t10: "cuda:0 f64[]"
      # t10 = prims.where(t7, t5, t9)  # t10: "cuda:0 f64[]"
  res = ltorch.true_divide(448.0, t10)  # res: "cuda:0 f64[]"
    # res = prims.div(448.0, t10)  # res: "cuda:0 f64[]"

  # /home/mkozuki/.pyenv/versions/3.10.13/envs/torchdev-3.10/lib/python3.10/site-packages/torchao/float8/float8_utils.py:59: 	    return res.to(torch.float32)
  scale = ltorch.to(res, torch.float32, None, device=None, dtype=None, copy=False, memory_format=None)  # scale: "cuda:0 f32[]"
    # scale = prims.convert_element_type(res, dtypes.float32)  # scale: "cuda:0 f32[]"

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/proxies.py:1964: 	                self.requires_grad,
  input_fp8 = _ToFloat8ConstrFunc_138480459367136_0(input, scale, torch.float8_e4m3fn, LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), _GemmInputRole_4, None)  # input_fp8: "cuda:0 f32[16, 32]"
    # t31 = ltorch.mul(input, scale)  # t31: "cuda:0 f32[16, 32]"
      # t30 = prims.broadcast_in_dim(scale, (16, 32), ())  # t30: "cuda:0 f32[16, 32]"
      # t31 = prims.mul(input, t30)  # t31: "cuda:0 f32[16, 32]"
    # t39 = ltorch.clamp(t31, -448.0, 448.0)  # t39: "cuda:0 f32[16, 32]"
      # t32 = ltorch.ne(t31, t31)  # t32: "cuda:0 b8[16, 32]"
        # t32 = prims.ne(t31, t31)  # t32: "cuda:0 b8[16, 32]"
      # t33 = ltorch.gt(t31, -448.0)  # t33: "cuda:0 b8[16, 32]"
        # t33 = prims.gt(t31, -448.0)  # t33: "cuda:0 b8[16, 32]"
      # t34 = ltorch.where(t33, t31, -448.0)  # t34: "cuda:0 f32[16, 32]"
        # t34 = prims.where(t33, t31, -448.0)  # t34: "cuda:0 f32[16, 32]"
      # t35 = ltorch.where(t32, t31, t34)  # t35: "cuda:0 f32[16, 32]"
        # t35 = prims.where(t32, t31, t34)  # t35: "cuda:0 f32[16, 32]"
      # t36 = ltorch.ne(t35, t35)  # t36: "cuda:0 b8[16, 32]"
        # t36 = prims.ne(t35, t35)  # t36: "cuda:0 b8[16, 32]"
      # t37 = ltorch.lt(t35, 448.0)  # t37: "cuda:0 b8[16, 32]"
        # t37 = prims.lt(t35, 448.0)  # t37: "cuda:0 b8[16, 32]"
      # t38 = ltorch.where(t37, t35, 448.0)  # t38: "cuda:0 f32[16, 32]"
        # t38 = prims.where(t37, t35, 448.0)  # t38: "cuda:0 f32[16, 32]"
      # t39 = ltorch.where(t36, t35, t38)  # t39: "cuda:0 f32[16, 32]"
        # t39 = prims.where(t36, t35, t38)  # t39: "cuda:0 f32[16, 32]"
    # t40 = ltorch.to(t39, torch.float8_e4m3fn, None, device=None, dtype=None, copy=False, memory_format=None)  # t40: "cuda:0 f8_e4m3fn[16, 32]"
      # t40 = prims.convert_element_type(t39, dtypes.float8_e4m3fn)  # t40: "cuda:0 f8_e4m3fn[16, 32]"
    # t41 = prims.tensor_subclass_ctor(_torch__C__TensorMeta_6, 't25', (16, 32), devices.Device("cuda:0"), dtypes.float32, False, [t40, scale], [dtypes.float32, LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), _GemmInputRole_4, None])  # t41: "cuda:0 f32[16, 32]"
    # input_fp8 = prims.shallow_copy(t41)  # input_fp8: "cuda:0 f32[16, 32]"
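
For reference, a minimal setup along these lines could produce a trace like the one above (a hedged sketch, not the script used in this PR; it assumes torchao's convert_to_float8_training helper and infers the module shape from the proxies input: f32[16, 32] and weight: f32[64, 32]):

import torch
import torch.nn as nn
import thunder
from torchao.float8 import convert_to_float8_training

# Hypothetical repro: a single bias-free linear layer converted to torchao float8 training.
model = nn.Sequential(nn.Linear(32, 64, bias=False)).to("cuda")
convert_to_float8_training(model)  # swaps nn.Linear for torchao's Float8Linear

jitted = thunder.jit(model)
x = torch.randn(16, 32, device="cuda", requires_grad=True)
out = jitted(x)
print(thunder.last_traces(jitted)[0])  # inspect an early computation trace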

@crcrpar (Collaborator, Author) commented Nov 24, 2024

As of 9d5d1aa, the transform produces the trace below, and converting it to an FX graph then fails with a SyntaxError:

# Constructed by Transform for execution (took 7 milliseconds)
import thunder.core.devices as devices
import thunder.core.dtypes as dtypes
import torch
from torch import Tensor
from torchao.float8.float8_tensor import Float8Tensor
from torchao.float8.float8_tensor import ScaledMMConfig
from torchao.float8.float8_tensor import LinearMMConfig
from torchao.float8.float8_tensor import GemmInputRole
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def tmp__ToFloat8ConstrFunc_140114140809152_0(input, scale):
  t32 = torch.mul(input, scale)  # t32: "cuda:0 f32[16, 32]"
    # t32 = ltorch.mul(input, scale)  # t32: "cuda:0 f32[16, 32]"
      # t0 = prims.broadcast_in_dim(scale, (16, 32), ())  # t0: "cuda:0 f32[16, 32]"
      # t32 = prims.mul(input, t0)  # t32: "cuda:0 f32[16, 32]"
  t40 = torch.clamp(t32, -448.0, 448.0)  # t40: "cuda:0 f32[16, 32]"
    # t40 = ltorch.clamp(t32, -448.0, 448.0)  # t40: "cuda:0 f32[16, 32]"
      # t2 = ltorch.ne(t32, t32)  # t2: "cuda:0 b8[16, 32]"
        # t2 = prims.ne(t32, t32)  # t2: "cuda:0 b8[16, 32]"
      # t3 = ltorch.gt(t32, -448.0)  # t3: "cuda:0 b8[16, 32]"
        # t3 = prims.gt(t32, -448.0)  # t3: "cuda:0 b8[16, 32]"
      # t4 = ltorch.where(t3, t32, -448.0)  # t4: "cuda:0 f32[16, 32]"
        # t4 = prims.where(t3, t32, -448.0)  # t4: "cuda:0 f32[16, 32]"
      # t5 = ltorch.where(t2, t32, t4)  # t5: "cuda:0 f32[16, 32]"
        # t5 = prims.where(t2, t32, t4)  # t5: "cuda:0 f32[16, 32]"
      # t6 = ltorch.ne(t5, t5)  # t6: "cuda:0 b8[16, 32]"
        # t6 = prims.ne(t5, t5)  # t6: "cuda:0 b8[16, 32]"
      # t7 = ltorch.lt(t5, 448.0)  # t7: "cuda:0 b8[16, 32]"
        # t7 = prims.lt(t5, 448.0)  # t7: "cuda:0 b8[16, 32]"
      # t8 = ltorch.where(t7, t5, 448.0)  # t8: "cuda:0 f32[16, 32]"
        # t8 = prims.where(t7, t5, 448.0)  # t8: "cuda:0 f32[16, 32]"
      # t40 = ltorch.where(t6, t5, t8)  # t40: "cuda:0 f32[16, 32]"
        # t40 = prims.where(t6, t5, t8)  # t40: "cuda:0 f32[16, 32]"
  t41 = Tensor.to(t40, copy=False, dtype=torch.float8_e4m3fn)  # t41: "cuda:0 f8_e4m3fn[16, 32]"
    # t41 = ltorch.to(t40, None, None, device=None, dtype=torch.float8_e4m3fn, copy=False, memory_format=None)  # t41: "cuda:0 f8_e4m3fn[16, 32]"
      # t41 = prims.convert_element_type(t40, dtypes.float8_e4m3fn)  # t41: "cuda:0 f8_e4m3fn[16, 32]"
  input_fp8 = Float8Tensor(t41, scale, torch.float32, LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), _GemmInputRole_1, None)  # input_fp8: "cuda:0 f32[16, 32]"
  return input_fp8
thunder/transforms/tensor_subclasses.py:461: in __call__
    fx, sequencified_cosmeticized_out, orig_output, _ = self.convert_trace_to_fx_graph_and_get_fake_result(trace)
thunder/transforms/tensor_subclasses.py:436: in convert_trace_to_fx_graph_and_get_fake_result
    fx: GraphModule = make_fx(f_with_wrap_and_unwrap)(*desugared_args)
../../crcrpar/pytorch/torch/fx/experimental/proxy_tensor.py:2188: in wrapped
    return make_fx_tracer.trace(f, *args)
../../crcrpar/pytorch/torch/fx/experimental/proxy_tensor.py:2126: in trace
    return self._trace_inner(f, *args)
../../crcrpar/pytorch/torch/fx/experimental/proxy_tensor.py:2097: in _trace_inner
    t = dispatch_trace(
../../crcrpar/pytorch/torch/_compile.py:32: in inner
    return disable_fn(*args, **kwargs)
../../crcrpar/pytorch/torch/_dynamo/eval_frame.py:721: in _fn
    return fn(*args, **kwargs)
../../crcrpar/pytorch/torch/fx/experimental/proxy_tensor.py:1173: in dispatch_trace
    return fx._lazy_graph_module._make_graph_module(tracer.root, graph, name)
../../crcrpar/pytorch/torch/fx/_lazy_graph_module.py:61: in _make_graph_module
    return graph_module_cls(*args, **kwargs)
../../crcrpar/pytorch/torch/fx/graph_module.py:511: in __init__
    self.graph = graph
../../crcrpar/pytorch/torch/nn/modules/module.py:2036: in __setattr__
    super().__setattr__(name, value)
../../crcrpar/pytorch/torch/fx/graph_module.py:558: in graph
    self.recompile()
../../crcrpar/pytorch/torch/fx/graph_module.py:808: in recompile
    cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)
../../crcrpar/pytorch/torch/fx/graph_module.py:92: in _forward_from_src
    return _method_from_src(
../../crcrpar/pytorch/torch/fx/graph_module.py:102: in _method_from_src
    _exec_with_source(src, globals_copy, co_fields)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

src = "\n\n\ndef forward(self, arg0_1, arg1_1):\n    mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1);  arg0_1 = None\n    cl..._dim': None});  _tensor_constant0 = arg1_1 = linear_mmconfig = None\n    return (output_wrapper_for_fx_tracing,)\n    "
globals = {'NoneType': <class 'NoneType'>, 'device': <class 'torch.device'>, 'fx_pytree': <module 'torch.fx._pytree' from '/home/mkozuki/ghq/github.com/crcrpar/pytorch/torch/fx/_pytree.py'>, 'inf': inf, ...}
co_fields = {'co_filename': '/home/mkozuki/ghq/github.com/crcrpar/pytorch/torch/fx/experimental/proxy_tensor.py', 'co_firstlineno': 1181, 'co_name': 'wrapped'}

    def _exec_with_source(src: str, globals: Dict[str, Any], co_fields=None):
        key = _loader.cache(src, globals, co_fields)
>       exec(compile(src, key, "exec"), globals)
E         File "<eval_with_key>.0 from /home/mkozuki/ghq/github.com/crcrpar/pytorch/torch/fx/experimental/proxy_tensor.py:1181 in wrapped", line 13
E           output_wrapper_for_fx_tracing = thunder_transforms_tensor_subclasses_OutputWrapperForFxTracing({'_data': _tensor_constant0, '_scale': arg1_1}, {'_orig_dtype': torch.float32, '_linear_mm_config': linear_mmconfig, '_gemm_input_role': <GemmInputRole.INPUT: 'input'>, '_axiswise_dim': None});  _tensor_constant0 = arg1_1 = linear_mmconfig = None
E                                                                                                                                                                                                                                                   ^
E       SyntaxError: invalid syntax

../../crcrpar/pytorch/torch/fx/graph_module.py:88: SyntaxError
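
The SyntaxError itself comes from the enum's default repr being interpolated into the FX-generated source: <GemmInputRole.INPUT: 'input'> is not valid Python, so _exec_with_source cannot compile the module. A standalone illustration of the same failure mode (the enum below is a stand-in defined for this sketch, not torchao's class):

import enum

class GemmInputRole(enum.Enum):  # stand-in for torchao.float8.float8_tensor.GemmInputRole
    INPUT = "input"

# Interpolating the default Enum repr into generated source, as the FX codegen
# apparently did above, yields text the compiler cannot parse.
src = f"gemm_input_role = {GemmInputRole.INPUT!r}\n"
print(src)  # gemm_input_role = <GemmInputRole.INPUT: 'input'>

try:
    compile(src, "<generated>", "exec")
except SyntaxError as exc:
    print("SyntaxError:", exc)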

@crcrpar force-pushed the crpa/subclass-torchao_float8tensor branch from 9d5d1aa to bccf751 on November 24, 2024 08:15