[torchao float8tensor] #1415 (Draft)
crcrpar wants to merge 15 commits into crpa/subclass-tensor-ops from crpa/subclass-torchao_float8tensor
+378 −39
Conversation
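For orientation, the traces discussed below come from jitting a small float8 linear workload. A hypothetical repro in that shape might look like the sketch below; the `convert_to_float8_training` entry point and the exact wiring are assumptions for illustration, not taken from this PR's diff:

```python
import torch
import thunder
from torchao.float8 import convert_to_float8_training  # assumed entry point

# Shapes follow the traces below: input f32[16, 32], weight f32[64, 32].
model = torch.nn.Linear(32, 64, bias=False, device="cuda")
convert_to_float8_training(model)  # swaps in torchao's Float8Linear

jitted = thunder.jit(model)
out = jitted(torch.randn(16, 32, device="cuda"))
```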
As of abf0167

crcrpar force-pushed the crpa/subclass-tensor-ops branch from 3fa8e2d to d5fb9fe (November 19, 2024 06:41)
Commits, each signed off by Masaki Kozuki <[email protected]>, with the captured subjects:
- next, function with tensor creation in it
- revert wrong patch
- supply unpacks with traces generated within the lookaside
crcrpar force-pushed the crpa/subclass-torchao_float8tensor branch from abf0167 to e7ca8b7 (November 21, 2024 03:14)
Signed-off-by: Masaki Kozuki <[email protected]>
This is a portion of the initial trace as of dd075db:

```python
# Constructed by Remove context manager prims
import thunder
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast


@torch.no_grad()
@no_autocast
def computation(input, weight):
    # input: "cuda:0 f32[16, 32]"
    # weight: "cuda:0 f32[64, 32]"

    # /home/mkozuki/.pyenv/versions/3.10.13/envs/torchdev-3.10/lib/python3.10/site-packages/torchao/float8/float8_utils.py:113: amax = torch.max(torch.abs(x))
    t3 = ltorch.abs(input)  # t3: "cuda:0 f32[16, 32]"
    # t3 = prims.abs(input)  # t3: "cuda:0 f32[16, 32]"
    amax = ltorch.torch_max(t3, None, False)  # amax: "cuda:0 f32[]"
    # amax = ltorch.amax(t3, [0, 1], False)  # amax: "cuda:0 f32[]"
    # t3 = ltorch.to(t3, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None)  # t3: "cuda:0 f32[16, 32]"
    # amax = prims.amax(t3, (0, 1))  # amax: "cuda:0 f32[]"
    # amax = ltorch.to(amax, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None)  # amax: "cuda:0 f32[]"

    # /home/mkozuki/.pyenv/versions/3.10.13/envs/torchdev-3.10/lib/python3.10/site-packages/torchao/float8/float8_utils.py:48: amax = amax.to(torch.float64)
    t5 = ltorch.to(amax, torch.float64, None, device=None, dtype=None, copy=False, memory_format=None)  # t5: "cuda:0 f64[]"
    # t5 = prims.convert_element_type(amax, dtypes.float64)  # t5: "cuda:0 f64[]"

    # /home/mkozuki/.pyenv/versions/3.10.13/envs/torchdev-3.10/lib/python3.10/site-packages/torchao/float8/float8_utils.py:50: res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
    t10 = ltorch.clamp(t5, 1e-12, None)  # t10: "cuda:0 f64[]"
    # t7 = ltorch.ne(t5, t5)  # t7: "cuda:0 b8[]"
    # t7 = prims.ne(t5, t5)  # t7: "cuda:0 b8[]"
    # t8 = ltorch.gt(t5, 1e-12)  # t8: "cuda:0 b8[]"
    # t8 = prims.gt(t5, 1e-12)  # t8: "cuda:0 b8[]"
    # t9 = ltorch.where(t8, t5, 1e-12)  # t9: "cuda:0 f64[]"
    # t9 = prims.where(t8, t5, 1e-12)  # t9: "cuda:0 f64[]"
    # t10 = ltorch.where(t7, t5, t9)  # t10: "cuda:0 f64[]"
    # t10 = prims.where(t7, t5, t9)  # t10: "cuda:0 f64[]"
    res = ltorch.true_divide(448.0, t10)  # res: "cuda:0 f64[]"
    # res = prims.div(448.0, t10)  # res: "cuda:0 f64[]"

    # /home/mkozuki/.pyenv/versions/3.10.13/envs/torchdev-3.10/lib/python3.10/site-packages/torchao/float8/float8_utils.py:59: return res.to(torch.float32)
    scale = ltorch.to(res, torch.float32, None, device=None, dtype=None, copy=False, memory_format=None)  # scale: "cuda:0 f32[]"
    # scale = prims.convert_element_type(res, dtypes.float32)  # scale: "cuda:0 f32[]"

    # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/proxies.py:1964: self.requires_grad,
    input_fp8 = _ToFloat8ConstrFunc_138480459367136_0(input, scale, torch.float8_e4m3fn, LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), _GemmInputRole_4, None)  # input_fp8: "cuda:0 f32[16, 32]"
    # t31 = ltorch.mul(input, scale)  # t31: "cuda:0 f32[16, 32]"
    # t30 = prims.broadcast_in_dim(scale, (16, 32), ())  # t30: "cuda:0 f32[16, 32]"
    # t31 = prims.mul(input, t30)  # t31: "cuda:0 f32[16, 32]"
    # t39 = ltorch.clamp(t31, -448.0, 448.0)  # t39: "cuda:0 f32[16, 32]"
    # t32 = ltorch.ne(t31, t31)  # t32: "cuda:0 b8[16, 32]"
    # t32 = prims.ne(t31, t31)  # t32: "cuda:0 b8[16, 32]"
    # t33 = ltorch.gt(t31, -448.0)  # t33: "cuda:0 b8[16, 32]"
    # t33 = prims.gt(t31, -448.0)  # t33: "cuda:0 b8[16, 32]"
    # t34 = ltorch.where(t33, t31, -448.0)  # t34: "cuda:0 f32[16, 32]"
    # t34 = prims.where(t33, t31, -448.0)  # t34: "cuda:0 f32[16, 32]"
    # t35 = ltorch.where(t32, t31, t34)  # t35: "cuda:0 f32[16, 32]"
    # t35 = prims.where(t32, t31, t34)  # t35: "cuda:0 f32[16, 32]"
    # t36 = ltorch.ne(t35, t35)  # t36: "cuda:0 b8[16, 32]"
    # t36 = prims.ne(t35, t35)  # t36: "cuda:0 b8[16, 32]"
    # t37 = ltorch.lt(t35, 448.0)  # t37: "cuda:0 b8[16, 32]"
    # t37 = prims.lt(t35, 448.0)  # t37: "cuda:0 b8[16, 32]"
    # t38 = ltorch.where(t37, t35, 448.0)  # t38: "cuda:0 f32[16, 32]"
    # t38 = prims.where(t37, t35, 448.0)  # t38: "cuda:0 f32[16, 32]"
    # t39 = ltorch.where(t36, t35, t38)  # t39: "cuda:0 f32[16, 32]"
    # t39 = prims.where(t36, t35, t38)  # t39: "cuda:0 f32[16, 32]"
    # t40 = ltorch.to(t39, torch.float8_e4m3fn, None, device=None, dtype=None, copy=False, memory_format=None)  # t40: "cuda:0 f8_e4m3fn[16, 32]"
    # t40 = prims.convert_element_type(t39, dtypes.float8_e4m3fn)  # t40: "cuda:0 f8_e4m3fn[16, 32]"
    # t41 = prims.tensor_subclass_ctor(_torch__C__TensorMeta_6, 't25', (16, 32), devices.Device("cuda:0"), dtypes.float32, False, [t40, scale], [dtypes.float32, LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), _GemmInputRole_4, None])  # t41: "cuda:0 f32[16, 32]"
    # input_fp8 = prims.shallow_copy(t41)  # input_fp8: "cuda:0 f32[16, 32]"
```
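For reference, the scale computation this trace captures corresponds to roughly the following eager-mode PyTorch, a minimal sketch mirroring the torchao.float8.float8_utils lines cited in the trace comments (`compute_fp8_scale` is an illustrative name, not torchao's):

```python
import torch

def compute_fp8_scale(x: torch.Tensor, float8_dtype=torch.float8_e4m3fn, eps: float = 1e-12) -> torch.Tensor:
    # float8_utils.py:113: amax = torch.max(torch.abs(x)) -> t3, amax in the trace
    amax = torch.max(torch.abs(x))
    # float8_utils.py:48: amax = amax.to(torch.float64) -> t5
    amax = amax.to(torch.float64)
    # float8_utils.py:50: res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS) -> t10, res
    # (torch.finfo(torch.float8_e4m3fn).max == 448.0, matching the literal in the trace)
    res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=eps)
    # float8_utils.py:59: return res.to(torch.float32) -> scale
    return res.to(torch.float32)

x = torch.randn(16, 32)
scale = compute_fp8_scale(x)
# The trace then scales, clamps to [-448.0, 448.0], and casts to float8 (t31..t40):
x_fp8 = torch.clamp(x * scale, -448.0, 448.0).to(torch.float8_e4m3fn)
```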
Four more commits, each signed off by Masaki Kozuki <[email protected]>.
As of 9d5d1aa:

```python
# Constructed by Transform for execution (took 7 milliseconds)
import thunder.core.devices as devices
import thunder.core.dtypes as dtypes
import torch
from torch import Tensor
from torchao.float8.float8_tensor import Float8Tensor
from torchao.float8.float8_tensor import ScaledMMConfig
from torchao.float8.float8_tensor import LinearMMConfig
from torchao.float8.float8_tensor import GemmInputRole
from thunder.executors.torchex import no_autocast


@torch.no_grad()
@no_autocast
def tmp__ToFloat8ConstrFunc_140114140809152_0(input, scale):
    t32 = torch.mul(input, scale)  # t32: "cuda:0 f32[16, 32]"
    # t32 = ltorch.mul(input, scale)  # t32: "cuda:0 f32[16, 32]"
    # t0 = prims.broadcast_in_dim(scale, (16, 32), ())  # t0: "cuda:0 f32[16, 32]"
    # t32 = prims.mul(input, t0)  # t32: "cuda:0 f32[16, 32]"
    t40 = torch.clamp(t32, -448.0, 448.0)  # t40: "cuda:0 f32[16, 32]"
    # t40 = ltorch.clamp(t32, -448.0, 448.0)  # t40: "cuda:0 f32[16, 32]"
    # t2 = ltorch.ne(t32, t32)  # t2: "cuda:0 b8[16, 32]"
    # t2 = prims.ne(t32, t32)  # t2: "cuda:0 b8[16, 32]"
    # t3 = ltorch.gt(t32, -448.0)  # t3: "cuda:0 b8[16, 32]"
    # t3 = prims.gt(t32, -448.0)  # t3: "cuda:0 b8[16, 32]"
    # t4 = ltorch.where(t3, t32, -448.0)  # t4: "cuda:0 f32[16, 32]"
    # t4 = prims.where(t3, t32, -448.0)  # t4: "cuda:0 f32[16, 32]"
    # t5 = ltorch.where(t2, t32, t4)  # t5: "cuda:0 f32[16, 32]"
    # t5 = prims.where(t2, t32, t4)  # t5: "cuda:0 f32[16, 32]"
    # t6 = ltorch.ne(t5, t5)  # t6: "cuda:0 b8[16, 32]"
    # t6 = prims.ne(t5, t5)  # t6: "cuda:0 b8[16, 32]"
    # t7 = ltorch.lt(t5, 448.0)  # t7: "cuda:0 b8[16, 32]"
    # t7 = prims.lt(t5, 448.0)  # t7: "cuda:0 b8[16, 32]"
    # t8 = ltorch.where(t7, t5, 448.0)  # t8: "cuda:0 f32[16, 32]"
    # t8 = prims.where(t7, t5, 448.0)  # t8: "cuda:0 f32[16, 32]"
    # t40 = ltorch.where(t6, t5, t8)  # t40: "cuda:0 f32[16, 32]"
    # t40 = prims.where(t6, t5, t8)  # t40: "cuda:0 f32[16, 32]"
    t41 = Tensor.to(t40, copy=False, dtype=torch.float8_e4m3fn)  # t41: "cuda:0 f8_e4m3fn[16, 32]"
    # t41 = ltorch.to(t40, None, None, device=None, dtype=torch.float8_e4m3fn, copy=False, memory_format=None)  # t41: "cuda:0 f8_e4m3fn[16, 32]"
    # t41 = prims.convert_element_type(t40, dtypes.float8_e4m3fn)  # t41: "cuda:0 f8_e4m3fn[16, 32]"
    input_fp8 = Float8Tensor(t41, scale, torch.float32, LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), _GemmInputRole_1, None)  # input_fp8: "cuda:0 f32[16, 32]"
    return input_fp8
```

The tensor-subclass transform then fails while converting this trace to an FX graph:

```
thunder/transforms/tensor_subclasses.py:461: in __call__
fx, sequencified_cosmeticized_out, orig_output, _ = self.convert_trace_to_fx_graph_and_get_fake_result(trace)
thunder/transforms/tensor_subclasses.py:436: in convert_trace_to_fx_graph_and_get_fake_result
fx: GraphModule = make_fx(f_with_wrap_and_unwrap)(*desugared_args)
../../crcrpar/pytorch/torch/fx/experimental/proxy_tensor.py:2188: in wrapped
return make_fx_tracer.trace(f, *args)
../../crcrpar/pytorch/torch/fx/experimental/proxy_tensor.py:2126: in trace
return self._trace_inner(f, *args)
../../crcrpar/pytorch/torch/fx/experimental/proxy_tensor.py:2097: in _trace_inner
t = dispatch_trace(
../../crcrpar/pytorch/torch/_compile.py:32: in inner
return disable_fn(*args, **kwargs)
../../crcrpar/pytorch/torch/_dynamo/eval_frame.py:721: in _fn
return fn(*args, **kwargs)
../../crcrpar/pytorch/torch/fx/experimental/proxy_tensor.py:1173: in dispatch_trace
return fx._lazy_graph_module._make_graph_module(tracer.root, graph, name)
../../crcrpar/pytorch/torch/fx/_lazy_graph_module.py:61: in _make_graph_module
return graph_module_cls(*args, **kwargs)
../../crcrpar/pytorch/torch/fx/graph_module.py:511: in __init__
self.graph = graph
../../crcrpar/pytorch/torch/nn/modules/module.py:2036: in __setattr__
super().__setattr__(name, value)
../../crcrpar/pytorch/torch/fx/graph_module.py:558: in graph
self.recompile()
../../crcrpar/pytorch/torch/fx/graph_module.py:808: in recompile
cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)
../../crcrpar/pytorch/torch/fx/graph_module.py:92: in _forward_from_src
return _method_from_src(
../../crcrpar/pytorch/torch/fx/graph_module.py:102: in _method_from_src
_exec_with_source(src, globals_copy, co_fields)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src = "\n\n\ndef forward(self, arg0_1, arg1_1):\n mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = None\n cl..._dim': None}); _tensor_constant0 = arg1_1 = linear_mmconfig = None\n return (output_wrapper_for_fx_tracing,)\n "
globals = {'NoneType': <class 'NoneType'>, 'device': <class 'torch.device'>, 'fx_pytree': <module 'torch.fx._pytree' from '/home/mkozuki/ghq/github.com/crcrpar/pytorch/torch/fx/_pytree.py'>, 'inf': inf, ...}
co_fields = {'co_filename': '/home/mkozuki/ghq/github.com/crcrpar/pytorch/torch/fx/experimental/proxy_tensor.py', 'co_firstlineno': 1181, 'co_name': 'wrapped'}
def _exec_with_source(src: str, globals: Dict[str, Any], co_fields=None):
key = _loader.cache(src, globals, co_fields)
> exec(compile(src, key, "exec"), globals)
E File "<eval_with_key>.0 from /home/mkozuki/ghq/github.com/crcrpar/pytorch/torch/fx/experimental/proxy_tensor.py:1181 in wrapped", line 13
E output_wrapper_for_fx_tracing = thunder_transforms_tensor_subclasses_OutputWrapperForFxTracing({'_data': _tensor_constant0, '_scale': arg1_1}, {'_orig_dtype': torch.float32, '_linear_mm_config': linear_mmconfig, '_gemm_input_role': <GemmInputRole.INPUT: 'input'>, '_axiswise_dim': None}); _tensor_constant0 = arg1_1 = linear_mmconfig = None
E ^
E SyntaxError: invalid syntax
../../crcrpar/pytorch/torch/fx/graph_module.py:88: SyntaxError
```
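The failure appears to come from FX code generation embedding the repr of a GemmInputRole enum member, `<GemmInputRole.INPUT: 'input'>`, directly into the generated forward source; that token sequence is not valid Python, so `exec(compile(...))` raises the SyntaxError. A minimal sketch of the same failure mode (the enum here is a stand-in, not thunder's or torchao's code):

```python
import enum

class GemmInputRole(enum.Enum):
    INPUT = "input"

# FX emits Python source for GraphModule.forward; interpolating an enum's
# repr() into that source yields "<GemmInputRole.INPUT: 'input'>", which
# cannot be compiled.
src = f"def forward(self, x):\n    role = {GemmInputRole.INPUT!r}\n    return x\n"
try:
    exec(compile(src, "<eval_with_key>", "exec"))
except SyntaxError as exc:
    print(exc)  # invalid syntax
```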
Signed-off-by: Masaki Kozuki <[email protected]>
crcrpar force-pushed the crpa/subclass-torchao_float8tensor branch from 9d5d1aa to bccf751 (November 24, 2024 08:15)
Before submitting

What does this PR do?
Fixes # (issue).

PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?
Make sure you had fun coding 🙃