Support power of 2 scaling factors in float8 training and use e4m3 everywhere #1670
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1670. Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 21e8061 with merge base 8afd10e.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```diff
@@ -36,6 +36,7 @@ def hp_tensor_to_float8_dynamic(
     device_mesh=None,
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
     axiswise_dim: Optional[int] = None,
+    power_of_2_scale: bool = False,
```
Note for reviewer: this param list is getting pretty long, and 4 of the 9 params can be derived from the Float8LinearConfig. Any thoughts on refactoring to pass in the Float8LinearConfig directly?
sounds reasonable
Cool, I'll do that in a follow-up so Less can begin scale testing ASAP after we merge this.
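For context, a rough sketch of what the discussed refactor could look like (hypothetical, simplified names; not the actual torchao signature):

```python
from dataclasses import dataclass
from typing import Optional

import torch

# Hypothetical, simplified stand-in for Float8LinearConfig; field names here
# are illustrative assumptions, not the real torchao config.
@dataclass(frozen=True)
class SketchFloat8Config:
    scaling_granularity: str = "TENSORWISE"
    axiswise_dim: Optional[int] = None
    power_of_2_scale: bool = False


# "After" shape of the proposed refactor: scaling options are read from the
# config object instead of being passed as four separate keyword arguments.
def hp_tensor_to_float8_dynamic_sketch(
    hp_tensor: torch.Tensor,
    float8_dtype: torch.dtype,
    config: SketchFloat8Config,
    device_mesh=None,
) -> torch.Tensor:
    # ... scale computation would use config.scaling_granularity,
    # config.axiswise_dim, and config.power_of_2_scale ...
    return hp_tensor  # placeholder body; the real function returns a float8 tensor
```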
Looks great! Commented on suggested naming. Any chance we could also check that this does not regress performance in torchtitan?
My shared devgpu has too much other usage currently to test with Llama3 8b, so I used a much smaller debug model (800M params). For rowwise with power of 2 scales, memory is flat but there is an ~8% regression in TPS. I'm wondering if this is because a small model like this has higher performance variance or if there is a real issue; I need to look into it further.

Llama3 debug model config: n_layers=4, dim=4096, n_heads=16. Tested on 4 H100s.

Row wise without power of 2 scales:
Row wise with power of 2 scales:
Sounds like that's worth a follow-up. Two things I can think of to check:
torchao/float8/float8_utils.py (outdated):

```python
if round_scales_to_power_of_2:
    # rounds down to the nearest power of 2
    res = torch.exp2(torch.floor(torch.log2(res)))
```
this should be the same as setting the mantissa to all-zeroes (maybe with some special handling for inf/nan), and can be implemented with bit shifting. Do you want to try to see if that resolves the regression?
Didn't test, but something like this for float32:

```python
res = res.view(torch.uint32)
res = (res >> 23) << 23
res = res.view(torch.float)
```
uint32 doesn't support bitshift ops apparently, so I had to use int32. Unit tests pass though, and the TPS regression is gone. Will the sign bit affect anything? I did some manual tests in the interpreter and the rounding seemed to work as expected.
[rank0]:2025-02-05 16:11:30,663 - root - INFO - step: 1 loss: 8.2105 memory: 9.69GiB(10.20%) tps: 610 mfu: 0.33%
[rank0]:2025-02-05 16:11:30,663 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-05 16:11:30,896 - root - INFO - step: 2 loss: 9.2258 memory: 11.02GiB(11.60%) tps: 70,207 mfu: 37.73%
[rank0]:2025-02-05 16:11:31,129 - root - INFO - step: 3 loss: 8.5120 memory: 11.02GiB(11.60%) tps: 70,377 mfu: 37.82%
[rank0]:2025-02-05 16:11:31,361 - root - INFO - step: 4 loss: 11.7253 memory: 11.02GiB(11.60%) tps: 70,885 mfu: 38.10%
[rank0]:2025-02-05 16:11:31,591 - root - INFO - step: 5 loss: 9.3686 memory: 11.02GiB(11.60%) tps: 71,365 mfu: 38.35%
[rank0]:2025-02-05 16:11:31,823 - root - INFO - step: 6 loss: 8.5610 memory: 11.02GiB(11.60%) tps: 70,634 mfu: 37.96%
[rank0]:2025-02-05 16:11:32,059 - root - INFO - step: 7 loss: 7.7763 memory: 11.02GiB(11.60%) tps: 69,681 mfu: 37.45%
[rank0]:2025-02-05 16:11:32,287 - root - INFO - step: 8 loss: 7.4649 memory: 11.02GiB(11.60%) tps: 71,963 mfu: 38.68%
[rank0]:2025-02-05 16:11:32,517 - root - INFO - step: 9 loss: 7.2956 memory: 11.02GiB(11.60%) tps: 71,188 mfu: 38.26%
[rank0]:2025-02-05 16:11:32,749 - root - INFO - step: 10 loss: 7.1085 memory: 11.02GiB(11.60%) tps: 70,748 mfu: 38.02%
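For reference, a minimal, hedged sketch of the int32 bit-manipulation approach described above (a standalone illustration, not the exact code in this PR):

```python
import torch


def round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
    """Round each float32 value's magnitude down to the nearest power of 2.

    Clearing the 23 mantissa bits keeps only the sign and exponent, which for
    positive finite normal values is exactly "round down to a power of 2".
    Zero stays zero and inf stays inf (their mantissas are already zero);
    subnormals flush to zero and NaN becomes inf, so those cases would need
    special handling if they can occur.
    """
    assert scale.dtype == torch.float32, "expected a float32 scale tensor"
    bits = scale.view(torch.int32)  # reinterpret the raw bits
    bits = (bits >> 23) << 23       # zero out the mantissa bits
    return bits.view(torch.float32)


x = torch.tensor([0.0, 0.1, 3.0, 6.5, 1024.0], dtype=torch.float32)
# values round down to: 0.0, 0.0625, 2.0, 4.0, 1024.0
print(round_scale_down_to_power_of_2(x))
```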
```python
else:
    raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

return res.to(torch.float32)
if round_scales_to_power_of_2:
```
if we're using bit shifting, IMO it would be good to
- wrap this into a function
- assert the input is float32
- add tests just around this function, testing that 0, positive finite number, infinity, nan are all handled correctly
it's ok as is, but numerical correctness is IMO a good place to be super explicit and eliminate potential confusion
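For illustration, a hedged sketch of the kind of edge-case tests being asked for here, written against a helper like the one sketched earlier (names and cases are assumptions, not the PR's final tests):

```python
import pytest
import torch


def round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
    # Same bit-manipulation helper as the earlier sketch.
    assert scale.dtype == torch.float32, "expected a float32 scale tensor"
    return ((scale.view(torch.int32) >> 23) << 23).view(torch.float32)


@pytest.mark.parametrize(
    "value, expected",
    [
        (0.0, 0.0),                    # zero is preserved
        (0.1, 0.0625),                 # positive finite rounds down to 2**-4
        (4.0, 4.0),                    # exact powers of 2 are unchanged
        (float("inf"), float("inf")),  # inf already has a zero mantissa
    ],
)
def test_round_down_to_power_of_2(value, expected):
    out = round_scale_down_to_power_of_2(torch.tensor([value], dtype=torch.float32))
    assert torch.equal(out, torch.tensor([expected], dtype=torch.float32))


def test_nan_needs_special_handling():
    # Raw mantissa-zeroing turns NaN into inf, so NaN would need explicit
    # handling if it can reach the rounding function.
    out = round_scale_down_to_power_of_2(torch.tensor([float("nan")], dtype=torch.float32))
    assert torch.isinf(out).all()


def test_non_float32_input_rejected():
    with pytest.raises(AssertionError):
        round_scale_down_to_power_of_2(torch.tensor([1.0], dtype=torch.bfloat16))
```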
Makes sense, done. I also did another round of torchtitan benchmarks with the final implementation:
Float8 row wise without power of 2:
[rank0]:2025-02-06 12:13:42,245 - root - INFO - step: 1 loss: 12.2341 memory: 47.97GiB(50.49%) tps: 755 mfu: 4.42%
[rank0]:2025-02-06 12:13:42,246 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-06 12:13:54,758 - root - INFO - step: 10 loss: 10.0339 memory: 62.87GiB(66.17%) tps: 5,893 mfu: 34.51%
[rank0]:2025-02-06 12:14:08,326 - root - INFO - step: 20 loss: 8.4962 memory: 62.87GiB(66.17%) tps: 6,038 mfu: 35.36%
[rank0]:2025-02-06 12:14:21,886 - root - INFO - step: 30 loss: 7.6160 memory: 62.87GiB(66.17%) tps: 6,042 mfu: 35.38%
Float8 row wise with power of 2:
[rank0]:2025-02-06 12:10:54,300 - root - INFO - step: 1 loss: 12.2512 memory: 47.97GiB(50.49%) tps: 347 mfu: 2.03%
[rank0]:2025-02-06 12:10:54,301 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-06 12:11:06,505 - root - INFO - step: 10 loss: 10.1018 memory: 62.87GiB(66.17%) tps: 6,041 mfu: 35.38%
[rank0]:2025-02-06 12:11:20,063 - root - INFO - step: 20 loss: 8.6927 memory: 62.87GiB(66.17%) tps: 6,043 mfu: 35.39%
[rank0]:2025-02-06 12:11:33,621 - root - INFO - step: 30 loss: 7.6843 memory: 62.87GiB(66.17%) tps: 6,042 mfu: 35.38%
this is ready for another look when you have time - CI error on H100s is unrelated:
docker: Error response from daemon: failed to create task for container: failed to create shim task: OCI runtime create failed: runc create failed: unable to start container process: error during container init: error running prestart hook #0: exit status 1, stdout: , stderr: Auto-detected mode as 'legacy'
nvidia-container-cli: error parsing IMEX info: unsupported IMEX channel value: all: unknown.
I think it may be caused by using a legacy container image without a certain IMEX env var set? NVIDIA/nvidia-container-toolkit#797
Anyway, I'll try retriggering CI, and in the meantime I'll take a look at the triton kernels compile generates for the exp2(floor(log2(x))) approach and see if I can tell why it's slow.
I just looked into why exp2(floor(log2(x))) is slow, and it's actually an easy fix. When we do the rounding on this line, the scale is still in fp64 (torchao/float8/float8_utils.py, line 49 at 1d75c8f). If we convert to fp32 before doing the rounding, instead of at the end when we return (line 50 at 1d75c8f: `return res.to(torch.float32)`), the regression goes away - see the benchmark data below.
Maybe simply because these rounding ops are slower at double the bit width (FP64 peak TFLOPs is only about 50% of FP32 on H100)? I'm surprised the effect is pronounced enough to cause an ~8% regression in overall TPS when rounding fp64 scales, though. I haven't looked into the generated triton kernels yet; prioritizing shipping this first.
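A minimal sketch of the ordering change being described, using a simplified stand-in for torchao's amax-to-scale logic (the real function has more parameters and dtype handling; the clamp epsilon below is an illustrative assumption):

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # ~448.0
EPS = 1e-12  # illustrative clamp value, not necessarily torchao's


def amax_to_scale_sketch(amax: torch.Tensor, round_scales_to_power_of_2: bool) -> torch.Tensor:
    # amax may arrive as float64 (accumulated in higher precision), so `res`
    # starts out as a float64 tensor here.
    res = E4M3_MAX / torch.clamp(amax, min=EPS)
    # Convert to float32 *before* rounding: doing exp2/floor/log2 in float64
    # is what caused the TPS regression discussed above.
    res = res.to(torch.float32)
    if round_scales_to_power_of_2:
        res = torch.exp2(torch.floor(torch.log2(res)))
    return res


amax = torch.tensor([3.5], dtype=torch.float64)
print(amax_to_scale_sketch(amax, round_scales_to_power_of_2=True))  # scale rounded down to a power of 2
```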
Benchmark data:
When scale is still float64 when rounding:
[rank0]:2025-02-06 13:24:25,812 - root - INFO - step: 1 loss: 12.2439 memory: 47.97GiB(50.49%) tps: 863 mfu: 5.06%
[rank0]:2025-02-06 13:24:25,812 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-06 13:24:38,548 - root - INFO - step: 10 loss: 9.9485 memory: 62.87GiB(66.17%) tps: 5,789 mfu: 33.90%
[rank0]:2025-02-06 13:24:52,685 - root - INFO - step: 20 loss: 8.4416 memory: 62.87GiB(66.17%) tps: 5,795 mfu: 33.93%
[rank0]:2025-02-06 13:25:06,827 - root - INFO - step: 30 loss: 7.6019 memory: 62.87GiB(66.17%) tps: 5,793 mfu: 33.92%
[rank0]:2025-02-06 13:25:20,968 - root - INFO - step: 40 loss: 7.4452 memory: 62.87GiB(66.17%) tps: 5,793 mfu: 33.93%
When scale is converted to fp32 before rounding:
[rank0]:2025-02-06 13:22:44,780 - root - INFO - step: 1 loss: 12.2436 memory: 47.97GiB(50.49%) tps: 859 mfu: 5.03%
[rank0]:2025-02-06 13:22:44,781 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-06 13:22:56,675 - root - INFO - step: 10 loss: 9.9731 memory: 62.87GiB(66.17%) tps: 6,199 mfu: 36.30%
[rank0]:2025-02-06 13:23:09,869 - root - INFO - step: 20 loss: 8.5158 memory: 62.87GiB(66.17%) tps: 6,209 mfu: 36.36%
[rank0]:2025-02-06 13:23:23,060 - root - INFO - step: 30 loss: 7.5902 memory: 62.87GiB(66.17%) tps: 6,211 mfu: 36.37%
[rank0]:2025-02-06 13:23:36,270 - root - INFO - step: 40 loss: 7.3799 memory: 62.87GiB(66.17%) tps: 6,202 mfu: 36.32%
Looks good! It would also be great to verify loss curves still match bfloat16 on LLaMa 3 8B via torchtitan - I think they should, but it's good to verify for any numerics-related change.
Hmm, I can't seem to reproduce the CI issue locally. I copy-pasted the env setup commands from the CI container logs and ran the test:

```bash
conda create -n venv python=3.9 -y
conda activate venv
python -m pip install --upgrade pip
pip install torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121
pip install -r dev-requirements.txt
pip install .
pytest test/float8/test_float8_utils.py
```
============================================== test session starts ==============================================
platform linux -- Python 3.9.21, pytest-7.4.0, pluggy-1.5.0
rootdir: /data/users/danvm/ao
plugins: hypothesis-6.125.2
collected 17 items
test/float8/test_float8_utils.py ................. [100%]
=============================================== warnings summary ================================================
test/float8/test_float8_utils.py::test_non_float32_input[invalid_dtype3]
test/float8/test_float8_utils.py::test_non_float32_input[invalid_dtype4]
test/float8/test_float8_utils.py::test_non_float32_input[invalid_dtype5]
test/float8/test_float8_utils.py::test_non_float32_input[invalid_dtype6]
test/float8/test_float8_utils.py::test_non_float32_input[invalid_dtype7]
/data/users/danvm/ao/test/float8/test_float8_utils.py:59: DeprecationWarning: an integer is required (got type float). Implicit conversion to integers using __int__ is deprecated, and may be removed in a future version of Python.
non_float32_tensor = torch.tensor([3.0], dtype=invalid_dtype)
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================== 17 passed, 5 warnings in 2.43s =========================================
@vkuzo I think to unblock Less, we should comment out the test case and merge this, since the issue is not related to the rounding function. Then we can continue investigating the root cause. To clarify what the exact issue is: simply creating a tensor with the fp32 largest subnormal value in the problematic CI env truncates it to 0, before it can even be passed into the rounding function. As mentioned in #1670 (comment), the issue does not reproduce locally even when recreating the CI env as closely as possible, so I'm still investigating.
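For concreteness, a minimal sketch of the behavior being described, assuming the largest positive float32 subnormal is 2**-126 - 2**-149 (an illustration, not the PR's actual test code):

```python
import torch

# Largest positive float32 subnormal value.
largest_subnormal = 2.0**-126 - 2.0**-149

x = torch.tensor([largest_subnormal], dtype=torch.float32)
# Expected: tensor([1.1755e-38]); in the problematic CI environment the value
# reportedly comes out as 0 before it ever reaches the rounding function.
print(x, x.eq(0.0))
```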
@vkuzo Updated this PR to include using the e4m3 dtype for all inputs/weights/grad outputs for the float8 rowwise recipe. Confirmed training accuracy is not negatively impacted by comparing to a bf16 eager baseline (I ran bf16 for 330 steps and fp8 for 110 steps, but you can see the loss curves are virtually identical). cc @lessw2020
I think this is safe to land, thanks for adding this!
Summary
Add support for power of 2 scaling factors in float8 training with dynamic scaling.

Behavior: enabled in the Float8LinearConfig returned from recipe_name_to_linear_config for rowwise scaling.

Test Plan
Updated test cases to ensure power of 2 scaling does not impact numerics for axiswise dynamic scaling (eager and compiled).
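A hedged usage sketch (not taken from this PR) of how the rowwise recipe described above might be applied; the recipe name string and exact import locations are assumptions and may differ across torchao versions:

```python
import torch
from torchao.float8 import convert_to_float8_training, recipe_name_to_linear_config

# Assumed recipe name; the returned Float8LinearConfig is described in this PR
# as carrying the power-of-2 scale rounding behavior for rowwise scaling.
config = recipe_name_to_linear_config("rowwise")

model = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False)).to(torch.bfloat16)
convert_to_float8_training(model, config=config)
```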