Fix the transpose scheduler for DID loop split. #3927

Open · wants to merge 5 commits into base branch wjy/rfactor

Conversation

wujingyue (Collaborator)

For #2563

@wujingyue (Collaborator, Author)

!test

github-actions bot commented Feb 20, 2025

Review updated until commit c43d5e1

Description

  • Rename getShapeInReference to getLoopDomainSizes

  • Update loop domain handling in transpose scheduler

  • Improve error handling in vectorization factor calculation

  • Fix tensor view string representation and allocation domain setting


Changes walkthrough 📝

Relevant files

Enhancement

  csrc/scheduler/transpose.cpp: Update transpose scheduler functions (+16/-15)
    • Renamed getShapeInReference to getLoopDomainSizes
    • Updated loop domain handling in multiple functions
    • Improved error handling in getVectorizationFactorTransposeGroup

  tests/cpp/test_multidevice_sharding.cpp: Update transpose test case (+12/-12)
    • Updated the Transpose test to use symbolic tensors
    • Corrected tensor splitting and parallelization
    • Simplified expected output validation

Bug fix

  csrc/scheduler/vectorize_helper.cpp: Improve vectorization factor calculation (+12/-4)
    • Added the ir/printer.h include
    • Updated getVectorizationFactorTransposeGroup to handle parallelized axes

  csrc/tensor_view.cpp: Fix tensor view string representation (+2/-2)
    • Updated the toString method to use indent
    • Corrected allocation domain setting in toString

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Function Renaming

    The function getShapeInReference has been renamed to getLoopDomainSizes. Ensure that this change is intentional and that all references to getShapeInReference have been updated accordingly.

    std::pair<std::vector<int64_t>, int64_t> getLoopDomainSizes(
        HeuristicDataCache* data_cache,
        SchedulerRuntimeInfo& runtime_info,
        TensorView* reference,
        scheduler_tools::TransposeDomainMap& domain_map) {
      auto ref_loop = reference->getLoopDomain();
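    To make the rename concrete, here is a minimal sketch of what sizing the loop domain (rather than the logical domain) amounts to. This is not the PR's implementation; it assumes SchedulerRuntimeInfo exposes an expression evaluator for concretizing extents, as other nvFuser scheduler helpers do, and it continues from the ref_loop variable shown above.

      // Sketch only, not the actual getLoopDomainSizes body. After a DID
      // loop split, the loop domain differs from the logical domain, so
      // sizes must be read from the loop domain.
      std::vector<int64_t> loop_sizes;
      loop_sizes.reserve(ref_loop.size());
      for (IterDomain* id : ref_loop) {
        // Assumption: runtime_info exposes an ExpressionEvaluator that can
        // concretize the extent of each loop IterDomain.
        auto extent = runtime_info.expressionEvaluator().evaluate(id->extent());
        loop_sizes.push_back(extent.as<int64_t>());
      }
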
    Vectorization Factor Check

    The code now includes a check to ensure that the innermost dimension is not parallelized before vectorization. Verify that this check is necessary and that it does not introduce any unintended behavior.

    }
    
    /////////////////////////////
    // Step 2: global schedule //
    /////////////////////////////
    
    // make tile
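
    The excerpt above cuts off before the check itself. As a rough illustration only (a hypothetical helper, not the PR's code), a guard of this kind skips device-parallelized loop axes when picking the innermost dimension to vectorize; IterDomain::isDeviceDim() is assumed here to flag DID axes.

      // Sketch: return the innermost loop position that is safe to vectorize,
      // skipping device-parallelized axes such as DIDx, which carry no
      // contiguous local data on any single device.
      int64_t innermostVectorizablePos(const std::vector<IterDomain*>& loop) {
        for (int64_t i = static_cast<int64_t>(loop.size()) - 1; i >= 0; --i) {
          if (loop.at(i)->isDeviceDim()) {
            continue;
          }
          return i;
        }
        return -1;  // nothing vectorizable; caller falls back to a factor of 1
      }
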
    Test Case Update

    The test case for transpose has been updated to use a symbolic tensor and specific split operations. Ensure that these changes do not alter the intended behavior of the test and that it still effectively validates the transpose scheduler.

    TEST_F(MultiDeviceTest, Transpose) {
      auto fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      const auto num_devices = communicator_->size();
      auto mesh = DeviceMesh::createForNumDevices(num_devices);
    
      TensorView* in = makeSymbolicTensor(2);
      TensorView* out = transpose(in, 0, 1);
      in->split(0, num_devices, /*inner_split=*/false);
      in->axis(0)->parallelize(ParallelType::DIDx);
      out->split(1, num_devices, /*inner_split=*/false);
      out->axis(1)->parallelize(ParallelType::DIDx);
      out->reorder({1, 0});
      for (auto* tv : {in, out}) {
        tv->setDeviceMesh(mesh);
        tv->setAllocationDomain(tv->getLoopDomain(), true);
      }
      fusion->addInput(in);
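
    For readers unfamiliar with the outer split used above, a standalone sketch (plain C++, not nvFuser code) of the shape bookkeeping it implies: axis 0 of extent N becomes [num_devices, ceilDiv(N, num_devices)], and because the DIDx axis is sharded across devices, each device allocates only the inner portion.

      #include <cstdint>
      #include <vector>

      // Extent of the inner axis produced by an outer split by factor d.
      int64_t ceilDiv(int64_t n, int64_t d) {
        return (n + d - 1) / d;
      }

      // Per-device allocation shape of `in` (logically [N, M]) after
      // split(0, D, /*inner_split=*/false) with axis 0 parallelized on DIDx:
      // the DIDx axis has extent D but holds no local data.
      std::vector<int64_t> localShapeOfIn(int64_t N, int64_t M, int64_t D) {
        return {ceilDiv(N, D), M};
      }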

    @wujingyue (Collaborator, Author)

    !test

    @wujingyue (Collaborator, Author)

    !test --diff

    @wujingyue wujingyue requested a review from naoyam February 21, 2025 07:02
    @wujingyue (Collaborator, Author)

    The codegen diff is expected -- this PR changed the test.

    -  scheduler_utils::splitDims(reference1, tparams->split_before_tiling);
    +  scheduler_utils::splitDims(
    +      reference1, tparams->split_before_tiling, to_update);
    +  inner_most_pos1_in_ref1 = to_update[0];
    Collaborator

    Is this change a general bug fix not specific to DID?

    @wujingyue (Collaborator, Author) commented Feb 21, 2025

    I think it's a bug triggered by the DID loop split. According to #3927 (comment), inner_most_pos1_in_ref1 and inner_most_pos2_in_ref2 ought to be loop-domain positions but were misused to index getLogicalDomain() in getVectorizationFactorTransposeGroup. Therefore, I also made this change: https://github.com/NVIDIA/Fuser/pull/3927/files#diff-364f51b47b1cf80c14b75d95ac2882238a65aa0594d588cff2fc54526307ca55L941

    Collaborator

    I see. Looks like it was fine before because logical == loop. The change here makes sense.
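
    To make the distinction concrete, a hedged before/after sketch (variable names taken from the diff above, not the exact PR change):

      // Before (conceptually): a loop-domain position used to index the
      // logical domain -- correct only while loop == logical.
      //   IterDomain* id = reference1->getLogicalDomain().at(inner_most_pos1_in_ref1);
      // After (conceptually): index the loop domain the position was computed
      // against, which stays correct after a DID loop split.
      IterDomain* id = reference1->getLoopDomain().at(inner_most_pos1_in_ref1);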

    @wujingyue wujingyue requested a review from naoyam February 21, 2025 20:02
    @naoyam (Collaborator) left a comment

    The change looks good to me. I'll let @zasdfgbnm confirm.

    @zasdfgbnm Could you take a look as well?
