Support power of 2 scaling factors in float8 training and use e4m3 everywhere #1670

Merged: 16 commits merged from the powerof2 branch into pytorch:main on Feb 10, 2025

Conversation

@danielvegamyhre (Contributor, Author) commented Feb 5, 2025:

Summary

Add support for power of 2 scaling factors in float8 training with dynamic scaling.

Behavior:

  • Default on in the Float8LinearConfig returned from recipe_name_to_linear_config for the rowwise scaling recipe.
  • Default off for other cases.

Test Plan
Updated test cases to ensure power of 2 scaling does not impact numerics for axiswise dynamic scaling (eager and compiled)
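
For reference, the core operation here is rounding each dynamically computed scale down to the nearest power of two, so only the exponent bits of the scale carry information. A minimal sketch of what this looks like (illustrative only, not the exact torchao API; it assumes a PyTorch build with float8 dtypes and the usual max-representable / amax formulation of dynamic scaling):

import torch

amax = torch.tensor([350.0], dtype=torch.float32)        # example abs-max of a tensor
scale = torch.finfo(torch.float8_e4m3fn).max / amax      # dynamic scale for e4m3 (448 / 350 = 1.28)
rounded = torch.exp2(torch.floor(torch.log2(scale)))     # round down to the nearest power of 2
print(scale, rounded)                                    # expected: tensor([1.2800]) tensor([1.])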

pytorch-bot bot commented Feb 5, 2025:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1670

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 21e8061 with merge base 8afd10e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 5, 2025
@danielvegamyhre danielvegamyhre changed the title Support power of 2 scaling factors in float8 training Support power of 2 scaling factors in float8 training via boolean param in Float8LinearConfig Feb 5, 2025
@danielvegamyhre danielvegamyhre marked this pull request as draft February 5, 2025 20:03
@danielvegamyhre danielvegamyhre added the topic: new feature Use this tag if this PR adds a new feature label Feb 5, 2025
@danielvegamyhre danielvegamyhre changed the title Support power of 2 scaling factors in float8 training via boolean param in Float8LinearConfig Support power of 2 scaling factors in float8 training Feb 5, 2025
@danielvegamyhre danielvegamyhre marked this pull request as ready for review February 5, 2025 21:37
@@ -36,6 +36,7 @@ def hp_tensor_to_float8_dynamic(
     device_mesh=None,
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
     axiswise_dim: Optional[int] = None,
+    power_of_2_scale: bool = False,
danielvegamyhre (Contributor, Author) commented:

Note for reviewer: this param list is getting pretty long, and 4 of the 9 params can be derived from the Float8LinearConfig. Any thoughts on refactoring to pass in the Float8LinearConfig directly?

Contributor commented:

sounds reasonable

danielvegamyhre (Contributor, Author) commented:

Cool, I'll do that in a follow-up so Less can begin scale testing after we merge this ASAP.

@danielvegamyhre danielvegamyhre requested a review from vkuzo February 5, 2025 21:47
torchao/float8/config.py: outdated review thread (resolved)
@vkuzo (Contributor) left a comment:

looks great! commented on suggested naming, any chance we could also check that this does not regress performance in torchtitan?

@danielvegamyhre (Contributor, Author) commented Feb 5, 2025:

looks great! commented on suggested naming, any chance we could also check that this does not regress performance in torchtitan?

My shared devgpu currently has too much other usage to test with Llama3 8b, so I used a much smaller debug model (800M params).

For rowwise with power of 2 scales, memory is flat but there is an ~8% regression in TPS. I'm wondering if this is because a small model like this has higher performance variance or if there is a real issue; I need to look into it further.

Llama3 model configs: n_layers=4, dim=4096, n_heads=16

Tested on 4 H100s.

Row wise without power of 2 scales:

[rank0]:2025-02-05 15:05:21,682 - root - INFO - step:  1  loss:  8.2016  memory: 10.28GiB(10.82%)  tps: 584  mfu: 0.31%
[rank0]:2025-02-05 15:05:21,682 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-05 15:05:21,914 - root - INFO - step:  2  loss:  9.2027  memory: 11.67GiB(12.28%)  tps: 70,711  mfu: 38.00%
[rank0]:2025-02-05 15:05:22,146 - root - INFO - step:  3  loss:  8.4319  memory: 11.67GiB(12.28%)  tps: 70,914  mfu: 38.11%
[rank0]:2025-02-05 15:05:22,375 - root - INFO - step:  4  loss: 13.0116  memory: 11.67GiB(12.28%)  tps: 71,446  mfu: 38.40%
[rank0]:2025-02-05 15:05:22,604 - root - INFO - step:  5  loss: 10.0891  memory: 11.67GiB(12.28%)  tps: 71,662  mfu: 38.51%
[rank0]:2025-02-05 15:05:22,835 - root - INFO - step:  6  loss:  8.8140  memory: 11.67GiB(12.28%)  tps: 71,041  mfu: 38.18%
[rank0]:2025-02-05 15:05:23,068 - root - INFO - step:  7  loss:  7.9921  memory: 11.67GiB(12.28%)  tps: 70,531  mfu: 37.91%
[rank0]:2025-02-05 15:05:23,297 - root - INFO - step:  8  loss:  7.5519  memory: 11.67GiB(12.28%)  tps: 71,670  mfu: 38.52%
[rank0]:2025-02-05 15:05:23,525 - root - INFO - step:  9  loss:  7.4012  memory: 11.67GiB(12.28%)  tps: 71,808  mfu: 38.59%
[rank0]:2025-02-05 15:05:23,754 - root - INFO - step: 10  loss:  7.2013  memory: 11.67GiB(12.28%)  tps: 71,647  mfu: 38.51%

Row wise with power of 2 scales:

[rank0]:2025-02-05 15:02:52,539 - root - INFO - step:  1  loss:  8.2104  memory:  9.85GiB(10.37%)  tps: 1,981  mfu: 1.06%
[rank0]:2025-02-05 15:02:52,539 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-05 15:02:52,792 - root - INFO - step:  2  loss:  9.2376  memory: 11.20GiB(11.79%)  tps: 64,845  mfu: 34.85%
[rank0]:2025-02-05 15:02:53,046 - root - INFO - step:  3  loss:  8.6284  memory: 11.20GiB(11.79%)  tps: 64,742  mfu: 34.80%
[rank0]:2025-02-05 15:02:53,297 - root - INFO - step:  4  loss: 11.2887  memory: 11.20GiB(11.79%)  tps: 65,266  mfu: 35.08%
[rank0]:2025-02-05 15:02:53,548 - root - INFO - step:  5  loss:  9.4400  memory: 11.20GiB(11.79%)  tps: 65,429  mfu: 35.16%
[rank0]:2025-02-05 15:02:53,800 - root - INFO - step:  6  loss:  8.5271  memory: 11.20GiB(11.79%)  tps: 65,117  mfu: 35.00%
[rank0]:2025-02-05 15:02:54,055 - root - INFO - step:  7  loss:  7.8088  memory: 11.20GiB(11.79%)  tps: 64,426  mfu: 34.63%
[rank0]:2025-02-05 15:02:54,305 - root - INFO - step:  8  loss:  7.4392  memory: 11.20GiB(11.79%)  tps: 65,452  mfu: 35.18%
[rank0]:2025-02-05 15:02:54,555 - root - INFO - step:  9  loss:  7.3227  memory: 11.20GiB(11.79%)  tps: 65,783  mfu: 35.35%
[rank0]:2025-02-05 15:02:54,805 - root - INFO - step: 10  loss:  7.0642  memory: 11.20GiB(11.79%)  tps: 65,620  mfu: 35.27%

@vkuzo (Contributor) commented Feb 5, 2025:

sounds like that's worth a follow-up

two things to check I can think of:

  1. does it reproduce on full size LLaMa 3 8B on 8 H100s?
  2. does the regression go away if we use bit shifting instead of exp and log?


if round_scales_to_power_of_2:
    # rounds down to the nearest power of 2.
    res = torch.exp2(torch.floor(torch.log2(res)))
Contributor commented:

this should be the same as setting the mantissa to all-zeroes (maybe with some special handling for inf/nan), and can be implemented with bit shifting. Do you want to try to see if that resolves the regression?

Contributor commented:

Didn't test, but something like this, for float32:

res = res.view(torch.uint32)
res = (res >> 23) << 23
res = res.view(torch.float)

danielvegamyhre (Contributor, Author) commented:

uint32 doesn't support bitshift ops apparently, so I had to use int32. Unit tests pass though, and the TPS regression is gone. Will the sign bit affect anything? I did some manual tests in the interpreter and rounding seemed to work as expected.

[rank0]:2025-02-05 16:11:30,663 - root - INFO - step:  1  loss:  8.2105  memory:  9.69GiB(10.20%)  tps: 610  mfu: 0.33%
[rank0]:2025-02-05 16:11:30,663 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-05 16:11:30,896 - root - INFO - step:  2  loss:  9.2258  memory: 11.02GiB(11.60%)  tps: 70,207  mfu: 37.73%
[rank0]:2025-02-05 16:11:31,129 - root - INFO - step:  3  loss:  8.5120  memory: 11.02GiB(11.60%)  tps: 70,377  mfu: 37.82%
[rank0]:2025-02-05 16:11:31,361 - root - INFO - step:  4  loss: 11.7253  memory: 11.02GiB(11.60%)  tps: 70,885  mfu: 38.10%
[rank0]:2025-02-05 16:11:31,591 - root - INFO - step:  5  loss:  9.3686  memory: 11.02GiB(11.60%)  tps: 71,365  mfu: 38.35%
[rank0]:2025-02-05 16:11:31,823 - root - INFO - step:  6  loss:  8.5610  memory: 11.02GiB(11.60%)  tps: 70,634  mfu: 37.96%
[rank0]:2025-02-05 16:11:32,059 - root - INFO - step:  7  loss:  7.7763  memory: 11.02GiB(11.60%)  tps: 69,681  mfu: 37.45%
[rank0]:2025-02-05 16:11:32,287 - root - INFO - step:  8  loss:  7.4649  memory: 11.02GiB(11.60%)  tps: 71,963  mfu: 38.68%
[rank0]:2025-02-05 16:11:32,517 - root - INFO - step:  9  loss:  7.2956  memory: 11.02GiB(11.60%)  tps: 71,188  mfu: 38.26%
[rank0]:2025-02-05 16:11:32,749 - root - INFO - step: 10  loss:  7.1085  memory: 11.02GiB(11.60%)  tps: 70,748  mfu: 38.02%
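
For readers following along, here is a self-contained sketch of the int32 bit-twiddling approach described above (the helper name is hypothetical and this is not necessarily the exact code merged in this PR):

import torch

def round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
    # Zero out the 23 mantissa bits of each float32 value by reinterpreting the
    # bits as int32, shifting right then left by 23, and reinterpreting back.
    # Assumes scales are positive and finite; under this scheme NaN would become inf.
    assert scale.dtype == torch.float32
    bits = scale.view(torch.int32)
    bits = (bits >> 23) << 23
    return bits.view(torch.float32)

x = torch.tensor([0.0, 0.75, 3.5, float("inf")])
print(round_scale_down_to_power_of_2(x))  # expected: tensor([0.0000, 0.5000, 2.0000, inf])

Since the right-then-left shift only clears the low 23 bits, the top 9 bits (including the sign bit) are preserved either way, and for the positive scales used here the sign bit is always zero.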

else:
    raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

return res.to(torch.float32)
if round_scales_to_power_of_2:
Contributor commented:

if we're using bit shifting, IMO it would be good to

  1. wrap this into a function
  2. assert the input is float32
  3. add tests just around this function, testing that 0, positive finite number, infinity, nan are all handled correctly

it's ok as is, but numerical correctness is IMO a good place to be super explicit and eliminate potential confusion
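
A hedged sketch of what such unit tests might look like, written against the exp2/floor/log2 formulation shown in the diff above (the helper and test names here are hypothetical; the actual tests live in test/float8/test_float8_utils.py):

import pytest
import torch

def round_down_to_power_of_2(x: torch.Tensor) -> torch.Tensor:
    # exp2/floor/log2 formulation; float32 only, per the review suggestion.
    assert x.dtype == torch.float32
    return torch.exp2(torch.floor(torch.log2(x)))

@pytest.mark.parametrize(
    "inp, expected",
    [
        (0.0, 0.0),                    # log2(0) = -inf, exp2(-inf) = 0
        (6.0, 4.0),                    # positive finite value rounds down to a power of 2
        (float("inf"), float("inf")),  # inf propagates unchanged
    ],
)
def test_round_down_to_power_of_2(inp, expected):
    out = round_down_to_power_of_2(torch.tensor([inp], dtype=torch.float32))
    assert out.item() == expected

def test_nan_propagates():
    out = round_down_to_power_of_2(torch.tensor([float("nan")], dtype=torch.float32))
    assert torch.isnan(out).all()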

@danielvegamyhre (Contributor, Author) commented Feb 6, 2025:

Makes sense, done. I also did another round of torchtitan benchmarks with the final implementation:

Float8 row wise without power of 2:

[rank0]:2025-02-06 12:13:42,245 - root - INFO - step:  1  loss: 12.2341  memory: 47.97GiB(50.49%)  tps: 755  mfu: 4.42%
[rank0]:2025-02-06 12:13:42,246 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-06 12:13:54,758 - root - INFO - step: 10  loss: 10.0339  memory: 62.87GiB(66.17%)  tps: 5,893  mfu: 34.51%
[rank0]:2025-02-06 12:14:08,326 - root - INFO - step: 20  loss:  8.4962  memory: 62.87GiB(66.17%)  tps: 6,038  mfu: 35.36%
[rank0]:2025-02-06 12:14:21,886 - root - INFO - step: 30  loss:  7.6160  memory: 62.87GiB(66.17%)  tps: 6,042  mfu: 35.38%

Float8 row wise with power of 2:

[rank0]:2025-02-06 12:10:54,300 - root - INFO - step:  1  loss: 12.2512  memory: 47.97GiB(50.49%)  tps: 347  mfu: 2.03%
[rank0]:2025-02-06 12:10:54,301 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-06 12:11:06,505 - root - INFO - step: 10  loss: 10.1018  memory: 62.87GiB(66.17%)  tps: 6,041  mfu: 35.38%
[rank0]:2025-02-06 12:11:20,063 - root - INFO - step: 20  loss:  8.6927  memory: 62.87GiB(66.17%)  tps: 6,043  mfu: 35.39%
[rank0]:2025-02-06 12:11:33,621 - root - INFO - step: 30  loss:  7.6843  memory: 62.87GiB(66.17%)  tps: 6,042  mfu: 35.38%

danielvegamyhre (Contributor, Author) commented:

This is ready for another look when you have time; the CI error on H100s is unrelated:

docker: Error response from daemon: failed to create task for container: failed to create shim task: OCI runtime create failed: runc create failed: unable to start container process: error during container init: error running prestart hook #0: exit status 1, stdout: , stderr: Auto-detected mode as 'legacy'
nvidia-container-cli: error parsing IMEX info: unsupported IMEX channel value: all: unknown.

I think it may be caused by using a legacy container image without a certain IMEX env var set? See NVIDIA/nvidia-container-toolkit#797.

Anyway, I'll try retriggering CI, and in the meantime I'll take a look at the Triton kernels torch.compile generates for the exp2(floor(log2(x))) approach and see if I can tell why it's slow.

@danielvegamyhre (Contributor, Author) commented Feb 6, 2025:

I just looked into why exp2(floor(log2(x))) is slow, and it's actually an easy fix: when we do the rounding, the scale is still in fp64, and that is what causes the slowdown.

If we convert to fp32 before doing the rounding, instead of at the end when we return (return res.to(torch.float32)), the TPS regression is eliminated.

Maybe it is simply because these rounding ops are slower at double the bit width (we can only achieve 50% of FP32 max TFLOPS in FP64 on H100)? I'm surprised the effect is pronounced enough to cause an ~8% regression in overall TPS when rounding fp64 scales, though. I haven't looked into the generated Triton kernels yet; I'm prioritizing shipping this first.
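
To make the ordering concrete, here is a sketch of the fix being described (variable names follow the diff shown earlier; this is illustrative rather than the exact patch):

res = res.to(torch.float32)  # convert from fp64 to fp32 first
if round_scales_to_power_of_2:
    # rounding now happens on fp32 values, avoiding the slow fp64 path
    res = torch.exp2(torch.floor(torch.log2(res)))
return res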

Benchmark data:

When scale is still float64 when rounding:

[rank0]:2025-02-06 13:24:25,812 - root - INFO - step:  1  loss: 12.2439  memory: 47.97GiB(50.49%)  tps: 863  mfu: 5.06%
[rank0]:2025-02-06 13:24:25,812 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-06 13:24:38,548 - root - INFO - step: 10  loss:  9.9485  memory: 62.87GiB(66.17%)  tps: 5,789  mfu: 33.90%
[rank0]:2025-02-06 13:24:52,685 - root - INFO - step: 20  loss:  8.4416  memory: 62.87GiB(66.17%)  tps: 5,795  mfu: 33.93%
[rank0]:2025-02-06 13:25:06,827 - root - INFO - step: 30  loss:  7.6019  memory: 62.87GiB(66.17%)  tps: 5,793  mfu: 33.92%
[rank0]:2025-02-06 13:25:20,968 - root - INFO - step: 40  loss:  7.4452  memory: 62.87GiB(66.17%)  tps: 5,793  mfu: 33.93%

When scale is converted to fp32 before rounding:

[rank0]:2025-02-06 13:22:44,780 - root - INFO - step:  1  loss: 12.2436  memory: 47.97GiB(50.49%)  tps: 859  mfu: 5.03%
[rank0]:2025-02-06 13:22:44,781 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-06 13:22:56,675 - root - INFO - step: 10  loss:  9.9731  memory: 62.87GiB(66.17%)  tps: 6,199  mfu: 36.30%
[rank0]:2025-02-06 13:23:09,869 - root - INFO - step: 20  loss:  8.5158  memory: 62.87GiB(66.17%)  tps: 6,209  mfu: 36.36%
[rank0]:2025-02-06 13:23:23,060 - root - INFO - step: 30  loss:  7.5902  memory: 62.87GiB(66.17%)  tps: 6,211  mfu: 36.37%
[rank0]:2025-02-06 13:23:36,270 - root - INFO - step: 40  loss:  7.3799  memory: 62.87GiB(66.17%)  tps: 6,202  mfu: 36.32%

@danielvegamyhre danielvegamyhre force-pushed the powerof2 branch 3 times, most recently from b7f48ab to 36b8ac6 Compare February 6, 2025 20:18
@danielvegamyhre danielvegamyhre force-pushed the powerof2 branch 2 times, most recently from 03256a6 to e9d4ce5 Compare February 6, 2025 20:41
@danielvegamyhre danielvegamyhre force-pushed the powerof2 branch 3 times, most recently from 11d2b45 to 7f25616 Compare February 6, 2025 22:06
@vkuzo (Contributor) commented Feb 7, 2025:

looks good!

it would also be great to verify loss curves still match bfloat16 on LLaMa 3 8B via torchtitan - I think they should, but good to verify for any numerics related change

@danielvegamyhre (Contributor, Author) commented Feb 7, 2025:

looks good!

it would also be great to verify loss curves still match bfloat16 on LLaMa 3 8B via torchtitan - I think they should, but good to verify for any numerics related change

Ran Llama3 8b training runs on 8 H100s, comparing bf16 eager vs float8 rowwise compiled; the loss curves are virtually identical:

[Screenshot: loss curves, bf16 eager vs float8 rowwise compiled, captured 2025-02-07]

@danielvegamyhre (Contributor, Author) commented:

Hmm, I can't seem to reproduce the CI issue locally. I copy-pasted the env setup commands from the CI container logs and ran the test:

conda create -n venv python=3.9 -y
conda activate venv
python -m pip install --upgrade pip
pip install torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121
pip install -r dev-requirements.txt
pip install .

pytest test/float8/test_float8_utils.py 

============================================== test session starts ==============================================
platform linux -- Python 3.9.21, pytest-7.4.0, pluggy-1.5.0
rootdir: /data/users/danvm/ao
plugins: hypothesis-6.125.2
collected 17 items                                                                                              

test/float8/test_float8_utils.py .................                                                        [100%]

=============================================== warnings summary ================================================
test/float8/test_float8_utils.py::test_non_float32_input[invalid_dtype3]
test/float8/test_float8_utils.py::test_non_float32_input[invalid_dtype4]
test/float8/test_float8_utils.py::test_non_float32_input[invalid_dtype5]
test/float8/test_float8_utils.py::test_non_float32_input[invalid_dtype6]
test/float8/test_float8_utils.py::test_non_float32_input[invalid_dtype7]
  /data/users/danvm/ao/test/float8/test_float8_utils.py:59: DeprecationWarning: an integer is required (got type float).  Implicit conversion to integers using __int__ is deprecated, and may be removed in a future version of Python.
    non_float32_tensor = torch.tensor([3.0], dtype=invalid_dtype)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================== 17 passed, 5 warnings in 2.43s =========================================

@danielvegamyhre (Contributor, Author) commented Feb 7, 2025:

@vkuzo I think to unblock Less, we should comment out the test case and merge this, since the issue is not related to the rounding function. Then we can continue investigating the root cause.

To clarify what the exact issue is: simply creating a tensor with the largest fp32 subnormal value in the problematic CI env truncates it to 0, before it can even be passed into the rounding function.

As mentioned in #1670 (comment) the issue does not reproduce locally even when recreating the CI env as closely as possible, so I'm still investigating.
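
For concreteness, a minimal repro sketch of the failure mode described above (this assumes the value in question is the largest float32 subnormal, i.e. the value just below torch.finfo(torch.float32).smallest_normal):

import torch

smallest_normal = torch.finfo(torch.float32).smallest_normal     # ~1.1754944e-38
largest_subnormal = smallest_normal * (1.0 - 2.0 ** -23)         # largest value below the normal range
x = torch.tensor([largest_subnormal], dtype=torch.float32)
print(x.item())  # expected: a nonzero subnormal; in the problematic CI env it reportedly comes out as 0.0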

@danielvegamyhre (Contributor, Author) commented:

@vkuzo Updated this PR to use the e4m3 dtype for all inputs/weights/grad outputs in the float8 rowwise recipe.

Confirmed training accuracy is not negatively impacted by comparing against the bf16 eager baseline (bf16 ran for 330 steps and fp8 for 110 steps, but you can see the loss curves are virtually identical).

[Screenshot: loss curves, bf16 eager baseline vs float8 rowwise with e4m3 everywhere, captured 2025-02-10]

cc @lessw2020

@danielvegamyhre danielvegamyhre changed the title Support power of 2 scaling factors in float8 training Support power of 2 scaling factors in float8 training and use e4m3 everywhere Feb 10, 2025
@vkuzo (Contributor) commented Feb 10, 2025:

I think this is safe to land, thanks for adding this!

@danielvegamyhre danielvegamyhre merged commit 32a51ec into pytorch:main Feb 10, 2025
17 checks passed