
Expand layerwise upcasting with optional white-list to allow Torch/GPU to perform native fp8 ops where possible #10635

Open
vladmandic opened this issue Jan 23, 2025 · 4 comments
Labels: enhancement, performance, wip

Comments

@vladmandic
Contributor

PR #10347 adds native torch fp8 as a storage dtype and performs upcasting/downcasting to the compute dtype in pre-forward/post-forward hooks as needed.

However, modern GPU architectures (starting with Hopper in 2022) actually implement many ops natively in fp8, and torch is extending the number of supported ops with each release.
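
For reference, a minimal sketch of what native fp8 compute looks like in torch today, via the private torch._scaled_mm API (its exact signature and return type have changed across releases, and it needs an Ada/Hopper-class GPU), purely as an illustration:

    import torch

    # illustration only: scaled fp8 matmul via the private torch._scaled_mm API
    # (signature/return type differ across torch releases; needs Ada/Hopper hardware)
    a = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
    b = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn).t()  # second operand must be column-major
    scale = torch.tensor(1.0, device="cuda")  # per-tensor scales in fp32
    out = torch._scaled_mm(a, b, scale_a=scale, scale_b=scale, out_dtype=torch.bfloat16)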

Right now we're still pretty far from being able to execute everything natively in fp8, but the request here is to allow a white-list of layers for which the upcast/downcast can be skipped, since their operations can actually be executed natively.

That list would need to be per GPU architecture, so it's unlikely to be a static one - but just being able to specify a couple of the most common layers that take most of the compute time would be very beneficial.

cc @a-r-r-o-w @sayakpaul @DN6

@a-r-r-o-w added the enhancement, wip and performance labels on Jan 23, 2025
@a-r-r-o-w
Member

Thanks for starting the discussion @vladmandic! I think we will definitely be looking into fp8 matmul support at least. It has been known to work quite well on Ada and Hopper for a while now, so there is a good signal that it will be a nice feature.

Is something like this what you're referring to?

With the current diffusers.hooks.layerwise_casting.apply_layerwise_casting implementation, it should be possible to directly specify both storage_dtype and compute_dtype as torch.float8*, since it can take any torch.nn.Module and apply the pre/post-forward casting (given that the layer it is applied to only uses supported fp8 ops). This would require digging through the modeling code yourself to find the layers to apply this to, which does seem inconvenient for end users - we can give this some thought after the current release schedule is complete.
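
For example, something along these lines should already be possible (a minimal sketch; block is just a placeholder for a hand-picked submodule believed to use only fp8-capable ops):

    import torch
    from diffusers.hooks.layerwise_casting import apply_layerwise_casting

    # placeholder for a hand-picked submodule believed to use only fp8-capable ops
    block = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)

    # storage and compute dtype are both fp8, so the pre-forward upcast effectively
    # becomes a no-op and the layer runs in fp8 (this only works if every op the
    # layer uses has an fp8 kernel on the target GPU/torch version)
    apply_layerwise_casting(
        block,
        storage_dtype=torch.float8_e4m3fn,
        compute_dtype=torch.float8_e4m3fn,
    )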

@vladmandic
Contributor Author

Is something like this what you're referring to?

Yes, but that's the long run.
In the short run, I was thinking that some layers might just work as-is in fp8 without the upcast in pre-forward.
I don't want to specify a compute dtype per layer, that's a nightmare.

Instead, something like this?
We can already control which layers should not get converted using:

    skip_modules_pattern: Union[str, Tuple[str, ...]] = "auto",
    skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,

The ask here is to implement something similar for layers that should get converted but not upcast during compute.
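
For illustration only, a rough sketch of how such a white-list could be matched (the parameter names native_fp8_modules_pattern / native_fp8_modules_classes are made up here, mirroring the existing skip_modules_* ones, and are not part of diffusers):

    import re
    from typing import Optional, Tuple, Type, Union

    import torch

    # hypothetical helper mirroring the existing skip_modules_* matching; decides whether
    # a submodule should keep fp8 as its compute dtype (i.e. no pre-forward upcast)
    def should_skip_upcast(
        name: str,
        module: torch.nn.Module,
        native_fp8_modules_pattern: Union[str, Tuple[str, ...]] = (),
        native_fp8_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
    ) -> bool:
        patterns = (native_fp8_modules_pattern,) if isinstance(native_fp8_modules_pattern, str) else native_fp8_modules_pattern
        if any(re.search(pattern, name) for pattern in patterns):
            return True
        if native_fp8_modules_classes is not None and isinstance(module, native_fp8_modules_classes):
            return True
        return False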

@a-r-r-o-w
Member

a-r-r-o-w commented Jan 23, 2025

I see, and that should be more doable for this release. Do you have certain layers in mind where this would be beneficial/work with fp8 ops? I could give it a try and check the impact on quality/speed (generally, we want to be careful about the features exposed if they can have a negative impact on quality).

@vladmandic
Contributor Author

Do you have certain layers in mind where this would be beneficial/work with fp8 ops?

Not really, but I was thinking of doing some profiling with torch and simply trying it on the costliest ones.
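
Something like this, presumably, to surface the costliest layers/ops (a minimal sketch with a stand-in model; the real transformer forward would go inside the profiler context):

    import torch
    from torch.profiler import ProfilerActivity, profile

    # stand-in for the actual transformer forward pass being profiled
    model = torch.nn.Sequential(
        torch.nn.Linear(4096, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 4096)
    ).to(device="cuda", dtype=torch.float16)
    x = torch.randn(1, 4096, device="cuda", dtype=torch.float16)

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
        model(x)

    # sort by accumulated GPU time to find the costliest ops, i.e. the first candidates
    # to try running natively in fp8 (the sort key is named device_time_total on newer torch)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))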
