Expand layerwise upcasting with optional white-list to allow Torch/GPU to perform native fp8 ops where possible #10635
Labels: enhancement, performance, wip
PR #10347 adds native torch fp8 as a storage dtype and performs upcasting/downcasting to the compute dtype in pre-forward/post-forward hooks as needed.
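As a rough illustration of that mechanism (not the actual API from the PR; all names below are made up), a plain-PyTorch version using forward hooks could look like this:

```python
import torch
import torch.nn as nn


def apply_cast_hooks(module: nn.Module,
                     storage_dtype: torch.dtype = torch.float8_e4m3fn,
                     compute_dtype: torch.dtype = torch.bfloat16) -> None:
    """Keep leaf-module weights in storage_dtype; cast to compute_dtype
    only for the duration of each forward call."""
    def upcast(mod, args):
        mod.to(compute_dtype)

    def downcast(mod, args, output):
        mod.to(storage_dtype)

    for sub in module.modules():
        # only hook leaf modules that own parameters (Linear, Conv, norms, ...)
        if list(sub.children()) or not any(True for _ in sub.parameters(recurse=False)):
            continue
        sub.to(storage_dtype)
        sub.register_forward_pre_hook(upcast)
        sub.register_forward_hook(downcast)
```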
However, modern GPU architectures (starting with Hopper in 2022) do implement many ops natively in fp8, and torch extends the set of supported ops with each release.
Right now we're still pretty far from being able to execute everything natively in fp8, but the request here is to allow a whitelist of layers for which the upcast/downcast can be skipped, because those operations can actually be executed natively (see the sketch below).
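A minimal sketch of what that could look like on top of the hook approach above; the function and parameter names are hypothetical, and skipping the cast only pays off where torch actually has an fp8 kernel for the op (e.g. scaled fp8 matmul on Hopper), so this only shows the API shape:

```python
import torch
import torch.nn as nn


def apply_cast_hooks_with_whitelist(module, storage_dtype, compute_dtype,
                                    native_fp8_whitelist=()):
    def upcast(mod, args):
        mod.to(compute_dtype)

    def downcast(mod, args, output):
        mod.to(storage_dtype)

    for sub in module.modules():
        if list(sub.children()) or not any(True for _ in sub.parameters(recurse=False)):
            continue
        sub.to(storage_dtype)
        if native_fp8_whitelist and isinstance(sub, tuple(native_fp8_whitelist)):
            continue  # stays in fp8; no upcast/downcast around forward
        sub.register_forward_pre_hook(upcast)
        sub.register_forward_hook(downcast)
```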
That list would need to be per GPU architecture, so it's unlikely to be a static one, but just being able to specify a couple of the most common layers that take most of the compute time would already be very beneficial (a capability-keyed lookup is sketched below).
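One way to keep the whitelist per-architecture without hardcoding it in user code would be a lookup keyed on CUDA compute capability (fp8 tensor cores shipped with Ada, sm_89, and Hopper, sm_90). The table contents here are purely illustrative; which layers are actually safe to run natively in fp8 on each architecture would need to be established per torch release:

```python
import torch
import torch.nn as nn

# illustrative mapping only, not a verified support matrix
_NATIVE_FP8_LAYERS = {
    (8, 9): (nn.Linear,),               # Ada
    (9, 0): (nn.Linear, nn.LayerNorm),  # Hopper (made-up example entry)
}


def default_native_fp8_whitelist(device="cuda"):
    if not torch.cuda.is_available():
        return ()
    cap = torch.cuda.get_device_capability(device)
    # empty whitelist on architectures without native fp8 support
    return _NATIVE_FP8_LAYERS.get(cap, ())
```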
cc @a-r-r-o-w @sayakpaul @DN6