
add new branch heyi_cast_transpose and update cast_transpose optimizations #89

Open

eliotwang wants to merge 9 commits into dev from heyi_cast_transpose

Conversation


@eliotwang eliotwang commented Nov 4, 2024

Description

Optimize the cast_transpose kernel by:

  1. refactoring the kernel to support more flexible parameter tuning;
  2. using assembly instructions to optimize the conversion to fp8 output types;
  3. analyzing performance on different input shapes and providing empirical parameter configurations, ensuring that the HIP kernel outperforms the Triton kernel in all tests.

Steps to run:

docker pull rocm/pytorch:latest

docker run -it --network=host -v /home/yigex/heyi:/workspace --device=/dev/kfd --device=/dev/dri --group-add video --security-opt seccomp=unconfined --ipc=host --name heyi_te_wx rocm/pytorch:latest

mkdir heyi && cd heyi

git clone -b heyi_cast_transpose --recursive https://github.com/eliotwang/TransformerEngine.git

cd TransformerEngine

export NVTE_FRAMEWORK=pytorch
export NVTE_ROCM_ARCH=gfx942

pip install .

mkdir tests/cpp/build && cd tests/cpp/build

cmake .. && make

./operator/test_operator

[ RUN ] OperatorTest/CTTestSuite.TestCastTranspose/bfloat16Xfloat8e5m2X2048X12288
GPU execution time: 54.7456 us
[ OK ] OperatorTest/CTTestSuite.TestCastTranspose/bfloat16Xfloat8e5m2X2048X12288 (2854 ms)
[ RUN ] OperatorTest/CTTestSuite.TestCastTranspose/bfloat16Xfloat8e5m2X256X65536
GPU execution time: 39.4748 us
[ OK ] OperatorTest/CTTestSuite.TestCastTranspose/bfloat16Xfloat8e5m2X256X65536 (1690 ms)

Performance:

(Performance comparison attached as an image in the original PR.)


nvte_cast_transpose(input.data(), output_c.data(), output_t.data(), 0);

int warm_iter = 3;

Collaborator

I suggest that we keep this test file as is. This file was just for unit testing. We can measure performance when running it with rocprof.

const size_t num_blocks = kernel_config.num_blocks;

size_t load_size;
size_t store_size;

Collaborator

We would like to keep the original NVTE CUDA code while adding the new code for ROCm. You can use __HIP_PLATFORM_AMD__ to guard the code. Here is an example: https://github.com/ROCm/TransformerEngine/blob/dev/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu#L72-L80
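
For illustration, a minimal sketch of that guard pattern; the sizes shown are placeholders, not the PR's actual values:

// Sketch: keep the upstream NVTE CUDA defaults untouched and add the
// ROCm-specific values behind __HIP_PLATFORM_AMD__ (numbers are illustrative).
size_t load_size;
size_t store_size;
#ifdef __HIP_PLATFORM_AMD__
  // ROCm path added by this PR
  load_size  = 16;
  store_size = 8;
#else
  // Original upstream NVTE CUDA path, left unchanged
  load_size  = 8;
  store_size = 8;
#endif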

Contributor

+1

Lines 280 to 286 should also be guarded by #ifdef __HIP_PLATFORM_AMD__.

Basically, after removing the ROCm/AMD-specific ifdefs, we would like our repo to be exactly the same as the upstream NVTE.

constexpr size_t block_size = __BLOCK_SIZE__;

} // namespace

__device__ OType convert_from_fp32(float v) {

Author

This function appears to be more complete. When compiling with the gfx940 macro enabled, it results in an error related to function overloading

@eliotwang
Author

  1. Revert tests/cpp/operator/test_cast_transpose.cu to its original version;
  2. Add __HIP_PLATFORM_AMD__ to guard modified code in transformer_engine/common/transpose/cast_transpose.cu;
  3. Modify the hip_f8 funcs in hip_float8.h to ensure the code runs correctly on MI308X;
  4. Adopt the hip_f8 funcs in hip_float8.h for precision conversion and remove the custom conversion functions.

@BruceXcluding BruceXcluding requested review from BruceXcluding and wangye805 and removed request for BruceXcluding November 12, 2024 04:25
@eliotwang
Author

Fix the cast_transpose issue where the tensor width and height are not divisible by the tile size. Passes pytest, including test_float8tensor.py, test_numerics.py, fused_attn/test_fused_attn.py, and test_sanity.py.

…version, add __HIP_PLATFORM_AMD__ to guard modified code in transformer_engine/common/transpose/cast_transpose.cu, use hip_float8 implementation and remove custom conversion functions
@eliotwang eliotwang force-pushed the heyi_cast_transpose branch from 0acaf23 to 54d606a on January 7, 2025 at 13:42

Contributor

@wangye805 wangye805 left a comment

Sorry for the late review.

I was only able to review the cast_transpose part, not the cast_transpose_fusion part. I gave some comments on the coding style, but I'm still quite confused about the big picture of how wpt_size and iter_size work.


bool do_general_config = true;

#ifdef __HIP_PLATFORM_AMD__
if((std::is_same<OutputType, fp8e5m2>::value) || (std::is_same<OutputType, fp8e4m3>::value)){

Contributor

can use if constexpr since OutputType is determined at compile time
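
For illustration, a minimal sketch of that suggestion, mirroring the quoted condition (not the PR's actual code):

#ifdef __HIP_PLATFORM_AMD__
  // Evaluated at compile time, so the discarded branch is never instantiated.
  if constexpr (std::is_same<OutputType, fp8e5m2>::value ||
                std::is_same<OutputType, fp8e4m3>::value) {
    // FP8 output: take the tuned ROCm configuration path
  } else {
    // Other output types: keep the general configuration
  }
#endif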

auto get_n_tiles = [=] (size_t load_size, size_t store_size) -> int {
constexpr size_t threads_per_warp = static_cast<size_t>(THREADS_PER_WARP);
size_t nvec_in = load_size / sizeof(InputType);
size_t nvec_out = store_size / sizeof(OutputType);

Contributor

Actually the sizeof(OutputType) is already fixed under this if condition
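
In other words, since fp8e5m2 and fp8e4m3 are both one byte, the division is a no-op inside this branch; a minimal sketch of the simplification:

// Inside the FP8-only branch, sizeof(OutputType) == 1, so:
size_t nvec_out = store_size;  // equivalent to store_size / sizeof(OutputType) here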

size_t nvec_out = store_size / sizeof(OutputType);
size_t n_tiles = DIVUP(row_length, nvec_in * threads_per_warp) *
DIVUP(num_rows, nvec_out * threads_per_warp);
return n_tiles;

Contributor

Should be the same indent as line 301

auto get_n_blocks = [=] (size_t n_tiles, size_t cast_transpose_num_threads, size_t wpt_size) -> int {
size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP;
size_t n_blocks = DIVUP(n_tiles * wpt_size, n_warps_per_block);
return n_blocks;

Contributor

indent issue

// Number of CUDA blocks
num_blocks = (row_length / row_tile_elements) * (num_rows / col_tile_elements);
rtc_block_size = THREADS_PER_WARP * wpt_size;
do_general_config =!(row_length % row_tile_elements == 0 && num_rows % col_tile_elements == 0);

Contributor

So do_general_config is just !aligned()?

Author

Under the AMD framework, if the output type is FP8 and the current tile_size meets certain conditions, the optimized configuration will be used; otherwise, it will fall under the do_general_config case.
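
Restated with the names from the quoted hunk, the condition amounts to the following (a sketch, not the PR's exact code):

// Fall back to the general configuration whenever the shape does not divide
// evenly into tiles; otherwise use the optimized FP8 configuration.
const bool aligned = (row_length % row_tile_elements == 0) &&
                     (num_rows   % col_tile_elements == 0);
do_general_config = !aligned;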

kernel_configs.emplace_back(row_length, num_rows, itype_size, otype_size, load_size,
store_size, sm_count);
};
add_config(8, 8);

Contributor

Regarding the big picture, in my opinion, you are trying to introduce a wpt_size and an iter_size so that wpt_size * iter_size <= THREADS_PER_WARP, in contrast to warps_per_tile * num_iterations = THREADS_PER_WARP, so that each thread can save some registers and overall shared memory. Is my understanding correct?

If so, I would expect add_config() to add wpt_size and iter_size to the cost model. But here we don't see additional parameters for add_config. Why?
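
For illustration, a hypothetical shape of add_config with the two tunables exposed, along the lines the reviewer describes; the extra parameters and the commented-out emplace_back fields are assumptions, not code from the PR:

auto add_config = [&](size_t load_size, size_t store_size,
                      size_t wpt_size, size_t iter_size) {
  // A cost model could then weigh the register/shared-memory savings that come
  // from wpt_size * iter_size <= THREADS_PER_WARP against achieved bandwidth.
  kernel_configs.emplace_back(row_length, num_rows, itype_size, otype_size,
                              load_size, store_size, sm_count
                              /*, wpt_size, iter_size */);
};
add_config(8, 8, /*wpt_size=*/8, /*iter_size=*/4);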

Author

Yes, when wpt_size * iter_size <= THREADS_PER_WARP, it can save some registers and overall shared memory, thus improving the occupancy of the CU. How can we reflect the impact of wpt_size and iter_size on performance in the cost model?

Contributor

Emm, we probably first need to understand how latency is affected by wpt_size*iter_size before changing the cost model.

By the way, from https://github.com/eliotwang/TransformerEngine/blob/b9359d65666bbcbd6734376e180c8fbe513c50ee/transformer_engine/common/transpose/cast_transpose.cu#L312, it seems that wpt_size*iter_size is still tied to THREADS_PER_WARP, which is 32?

};

wpt_size = 8;
iter_size = THREADS_PER_WARP / wpt_size;

Contributor

if iter_size x wpt_size = THREADS_PER_WARP here, why do we need to add ITER_SIZE into the subsequent RTC params?

Author

This is an added parameter for optimization, aimed at testing the impact of different wpt_size and iter_size configurations on performance.
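
Presumably the tunable reaches the RTC-compiled kernel the same way as the __BLOCK_SIZE__ placeholder quoted earlier; a sketch, with placeholder names assumed rather than taken from the PR:

// Substituted into the kernel source before runtime compilation,
// alongside __BLOCK_SIZE__:
constexpr size_t block_size = __BLOCK_SIZE__;
constexpr size_t iter_size  = __ITER_SIZE__;
constexpr size_t wpt_size   = __WPT_SIZE__;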

…ization implementation for cast_transpose and cast_transpose_fusion. Organize the newly added code within the management scope of __HIP_PLATFORM_AMD__.