Fix invalid IR from scalarrepl-param-hlsl in ReplaceConstantWithInst #6556


amaiorano
Collaborator

ReplaceConstantWithInst(C, V) replaces uses of C in the current function with V. If such a use C is an instruction I, then it replaces uses of C in I with V. However, this function did not check that V dominates I before performing the replacement. As a result, it could replace uses of C in instructions that come before the definition of V.

The fix is to lazily compute the dominator tree in ReplaceConstantWithInst so that we can guard the replacement with that dominance check.

@amaiorano amaiorano requested a review from a team as a code owner April 22, 2024 21:24
@amaiorano
Collaborator Author

Note that this was a joint effort with @dneto0. See the test file in the commit. Also see this compiler explorer link with the same example that segfaults.

The following explains the fix:

The bug is in ReplaceConstantWithInst(Constant *C, Value *V, Builder..).

Here's the relevant IR:

while.end.i:                                      ; preds = %while.body.i
  %10 = bitcast [10 x i32]* %arr1_copy.i to i8*, !dbg !33
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* %10, i8* bitcast ([10 x i32]* @arr1 to i8*), i64 40, i32 1, i1 false) #0, !dbg !33
  %11 = bitcast [10 x i32]* %arr1_copy.i to i8*, !dbg !34
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* bitcast ([10 x i32]* @arr1 to i8*), i8* %11, i64 40, i32 1, i1 false) #0, !dbg !34
  store i32 0, i32* %i.i.1.i, align 4, !dbg !35, !tbaa !12
  %12 = load %struct.ByteAddressBuffer, %struct.ByteAddressBuffer* @"\01?buff@@3UByteAddressBuffer@@A", !dbg !37

There are two memcpys. (dest is first arg, src is second arg)

  • ReplaceMemcpy(Value=@arr1,Src=%11) tries to replace the second memcpy: roughly memcpy(@arr1, %11) a.k.a. memcpy(@arr1, %arr1_copy.i), where %arr1_copy.i is the function-local array variable.
  • It sees that Value is a constant (because it's the address of the global variable @arr1)
  • It sees that Src is not a constant, because it's the result of an alloca (a.k.a. stack variable address)
  • So it falls into the code path with the comment // Replace Constant with a non-Constant. and calls ReplaceConstantWithInst

Now let's enter ReplaceConstantWithInst(C=@arr1, V=%arr1_copy.i). Its job is to replace uses of C, in the current function, with the value V. It traverses the uses of C, skipping over any that aren't in the current function (i.e. the function containing the instruction that generates V). If such a use C is an instruction I, then it replaces uses of C in I with V. Remember C=@arr1, and one of the uses is the first memcpy: memcpy(bitcast of %arr1_copy.i, bitcast of @arr1).

After these replacements, the first part of that basic block is this:

while.end.i:                                      ; preds = %while.body.i
  %10 = bitcast [10 x i32]* %arr1_copy.i to i8*, !dbg !32
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* %10, i8* %12, i64 40, i32 1, i1 false) #0, !dbg !32
  %11 = bitcast [10 x i32]* %arr1_copy.i to i8*, !dbg !33
  %12 = bitcast [10 x i32]* %arr1_copy.i to i8*, !dbg !33
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* %12, i8* %11, i64 40, i32 1, i1 false) #0, !dbg !33
  store i32 0, i32* %i.i.1.i, align 4, !dbg !34, !tbaa !12

We see that %12 is in the first memcpy, but defined two instructions later. That's bad.
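The def-before-use violation can be made concrete with a small standalone model. This is only a sketch: `ToyInst`, `firstUseBeforeDef`, and `brokenBlock` are hypothetical names invented for illustration, not LLVM or DXC APIs, and the model covers a single basic block with values named by strings.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Toy stand-in for an instruction (NOT llvm::Instruction): an optional result
// name plus the names of the local values it reads. Globals are not listed.
struct ToyInst {
  std::string result;                 // empty if nothing is defined
  std::vector<std::string> operands;  // local values read by the instruction
};

// Returns the index of the first instruction that reads a value which has not
// been defined yet, or -1 if every use follows its definition.
// `defined` holds the values already available on entry to the block.
int firstUseBeforeDef(const std::vector<ToyInst> &block,
                      std::set<std::string> defined) {
  for (std::size_t i = 0; i < block.size(); ++i) {
    for (const std::string &op : block[i].operands)
      if (defined.count(op) == 0)
        return static_cast<int>(i);
    if (!block[i].result.empty()) {
      bool inserted = defined.insert(block[i].result).second;
      assert(inserted && "SSA: each value is defined exactly once");
      (void)inserted;
    }
  }
  return -1;
}

// Model of the broken while.end.i block shown above.
std::vector<ToyInst> brokenBlock() {
  return {
      {"%10", {"%arr1_copy.i"}},
      {"", {"%10", "%12"}},  // first memcpy: reads %12 before it is defined
      {"%11", {"%arr1_copy.i"}},
      {"%12", {"%arr1_copy.i"}},
      {"", {"%12", "%11"}},  // second memcpy
  };
}
```

Calling `firstUseBeforeDef(brokenBlock(), {"%arr1_copy.i"})` reports index 1, the first memcpy, which is the instruction that reads %12 too early.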

The bug is in this statement: If such a use C is an instruction I, then it replaces uses of C in I with V. That's only safe to do if V dominates I.

So the fix uses a big hammer: compute the dominator tree so we can guard the replacement with that dominance check. Although it may seem heavyweight to recompute the dominator tree repeatedly, it's the safe thing to do. Earlier in the code it says:

 // SROA_Parameter_HLSL has no access to a domtree, if one is needed, it'll
 // be generated

So it seems in the spirit to compute the dominator tree on demand in this pass.

Note that early in ReplaceMemcpy, it has a promising comment and mechanism to early-out:

  // If the source of the memcpy (Src) doesn't dominate all users of dest (V),
  // full replacement isn't possible without complicated PHI insertion
  // This will likely replace with ld/st which will be replaced in mem2reg
  if (Instruction *SrcI = dyn_cast<Instruction>(Src))
    if (!DominateAllUsers(SrcI, V, DT))
      return false;

But DominateAllUsers doesn't quite do the job for us. It calls DominateAllUsersDom (where it has a dominator tree), and it does this code:

// Use `DT` to trace all users and make sure `I`'s BB dominates them all
static bool DominateAllUsersDom(Instruction *I, Value *V, DominatorTree *DT) {
  BasicBlock *BB = I->getParent();
  Function *F = I->getParent()->getParent();
  for (auto U = V->user_begin(); U != V->user_end();) {
    Instruction *UI = dyn_cast<Instruction>(*(U++));
    // If not an instruction or from a differnt function, nothing to check, move
    // along.
    if (!UI || UI->getParent()->getParent() != F)
      continue;

    if (!DT->dominates(BB, UI->getParent()))
      return false;
      
    if (isa<GetElementPtrInst>(UI) || isa<BitCastInst>(UI)) {
      if (!DominateAllUsersDom(I, UI, DT))
        return false;
    }
  }
  return true;
}

First, in our case the value V is a constant (@arr1), and its use in the first memcpy is i8* bitcast ([10 x i32]* @arr1 to i8*), which is itself an llvm::Constant. The loop skips that user, because UI is null: the llvm::Constant is not an instruction. So I think it returns true, and that's why we get into trouble later.

One local fix here could be if the use is a constant, then recurse and follow the users of that enclosing constant, until you hit an instruction. This has a problem that it could blow up the search.
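As a rough illustration of that idea, here is a standalone sketch of following constant users outward until instructions are reached. `ToyValue`, `collectInstructionUsers`, and `instructionUsersOfArr1` are hypothetical names, not LLVM APIs. The visited set is an assumption added here so that each constant is expanded at most once; whether that is enough to keep the search cheap in the real pass is not something the discussion establishes.

```cpp
#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

// Toy stand-in for an llvm::Value with a user list (NOT the LLVM class).
struct ToyValue {
  std::string name;
  bool isInstruction;             // false models a constant / constant expr
  std::vector<ToyValue *> users;
};

// Walk users of `v`; recurse through constant users until an instruction is
// reached. `visited` ensures each user is expanded only once.
void collectInstructionUsers(ToyValue *v,
                             std::unordered_set<ToyValue *> &visited,
                             std::vector<ToyValue *> &out) {
  assert(v != nullptr);
  for (ToyValue *u : v->users) {
    if (!visited.insert(u).second)
      continue;  // already seen through another path
    if (u->isInstruction)
      out.push_back(u);
    else
      collectInstructionUsers(u, visited, out);
  }
}

// Models @arr1 -> bitcast constant expression -> both memcpys.
std::vector<std::string> instructionUsersOfArr1() {
  ToyValue memcpy1{"memcpy#1", true, {}};
  ToyValue memcpy2{"memcpy#2", true, {}};
  ToyValue cast{"bitcast(@arr1)", false, {&memcpy1, &memcpy2}};
  ToyValue arr1{"@arr1", false, {&cast}};

  std::unordered_set<ToyValue *> visited;
  std::vector<ToyValue *> found;
  collectInstructionUsers(&arr1, visited, found);

  std::vector<std::string> names;
  for (ToyValue *f : found)
    names.push_back(f->name);
  return names;
}
```

In this model the walk reaches both memcpys through the shared bitcast constant expression, which is the point where a dominance test would then have to be applied.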

And even if you got out far enough to reach an instruction you would get to the first memcpy. And then the test if (!DT->dominates(BB, UI->getParent())) only compares at a basic block level. But both memcpys are in the same basic block, and a node always dominates itself. So again it would decide incorrectly. You have to check dominance on an instruction level.
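The difference between the two granularities can be sketched in a standalone model. Everything here is hypothetical and for illustration only: `ToyPos`, `blockDominates`, `instDominates`, and `toyIdom` are not LLVM APIs, and the three-block CFG (entry, while.body.i, while.end.i in a straight line) is an assumed simplification of the real function.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Position of an instruction: which block it is in, and its index there.
struct ToyPos {
  int block;
  int index;
};

// idom[b] is the immediate dominator of block b; the root is its own idom.
// Block a dominates block b if a is reached by walking up from b.
bool blockDominates(const std::vector<int> &idom, int a, int b) {
  assert(a >= 0 && b >= 0 && static_cast<std::size_t>(b) < idom.size());
  while (true) {
    if (a == b)
      return true;  // a block always dominates itself
    if (idom[b] == b)
      return false;  // reached the root without meeting a
    b = idom[b];
  }
}

// Instruction-level dominance: within one block, order decides.
bool instDominates(const std::vector<int> &idom, ToyPos def, ToyPos use) {
  if (def.block != use.block)
    return blockDominates(idom, def.block, use.block);
  return def.index < use.index;
}

// Assumed toy CFG: 0 = entry, 1 = while.body.i, 2 = while.end.i.
std::vector<int> toyIdom() { return {0, 0, 1}; }
```

With %12 at position {2, 3} and the first memcpy at {2, 1}, `blockDominates(toyIdom(), 2, 2)` is true while `instDominates(toyIdom(), {2, 3}, {2, 1})` is false, which is the distinction the block-level check misses.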

Collaborator

@llvm-beanz llvm-beanz left a comment


To avoid the tools/clang/test/DXC directory becoming a dumping ground, can you move these tests into a new subdirectory, maybe tools/clang/test/DXC/Passes/ScalarReplHLSL?

Member

@pow2clk pow2clk left a comment


I'd like to understand why the existing dominance check was insufficient.

@amaiorano
Collaborator Author

@pow2clk, you asked good questions yesterday about why this change is needed. Here's a deeper dive into what's happening when we run DXC against the example code, scalarrepl-param-hlsl-const-to-local-and-back.hlsl. Hopefully this makes the rationale for the patch more clear.

You asked about why the initial call to DominateAllUsers wasn't good enough to determine that we shouldn't proceed with replacing the memcpy. I set a breakpoint in DominateAllUsers.

Call stack:

> dxcompiler.dll!DominateAllUsers(llvm::Instruction * I=0x0000028f0f546768, llvm::Value * V=0x0000028f0d72a318, llvm::DominatorTree * DT=0x0000000000000000) Line 3857	C++
 	dxcompiler.dll!ReplaceMemcpy(llvm::Value * V=0x0000028f0d72a318, llvm::Value * Src=0x0000028f0f546768, llvm::MemCpyInst * MC=0x0000028f0f546da0, hlsl::DxilFieldAnnotation * annotation=0x0000000000000000, hlsl::DxilTypeSystem & typeSys={...}, const llvm::DataLayout & DL={...}, llvm::DominatorTree * DT=0x0000000000000000) Line 3556	C++
 	dxcompiler.dll!`anonymous namespace'::SROA_Helper::LowerMemcpy(llvm::Value * V=0x0000028f0d72a318, hlsl::DxilFieldAnnotation * annotation=0x0000000000000000, hlsl::DxilTypeSystem & typeSys={...}, const llvm::DataLayout & DL={...}, llvm::DominatorTree * DT=0x0000000000000000, bool bAllowReplace=true) Line 4074	C++
 	dxcompiler.dll!`anonymous namespace'::SROAGlobalAndAllocas(hlsl::HLModule & HLM={...}, bool bHasDbgInfo=false) Line 1950	C++
 	dxcompiler.dll!`anonymous namespace'::SROA_Parameter_HLSL::runOnModule(llvm::Module & M={...}) Line 4381	C++
 	dxcompiler.dll!`anonymous namespace'::MPPassManager::runOnModule(llvm::Module & M={...}) Line 1669	C++
...

Looking at the call stack, we see that we are processing globals and allocas:

GV->dump()
@arr1 = internal global [10 x i32] zeroinitializer, align 4

Which corresponds to:

static int arr1[10] = (int[10])0;

In SROA_Helper::LowerMemcpy, we determine that GV has a non-undef initializer, and is MemcopyDestOnce, so we will replace the source of the memcpy. The memcpy and Src are:

MC->dump()
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* bitcast ([10 x i32]* @arr1 to i8*), i8* %11, i64 40, i32 1, i1 false) #0, !dbg !34

Src->dump()
  %arr1_copy.i = alloca [10 x i32], align 4

Note that Src is an llvm::AllocaInst of arr1_copy.

In HLSL, this refers to:

  int arr1_copy[10] = arr1; // constant to local
  arr1 = arr1_copy; // local to constant

In IR, this looks like:

@arr1 = internal global [10 x i32] zeroinitializer, align 4
...

; Function Attrs: nounwind
define void @main(<4 x float>* noalias) #0 {
entry:
...
  %arr1_copy.i = alloca [10 x i32], align 4
...
while.end.i:                                      ; preds = %while.body.i
  %10 = bitcast [10 x i32]* %arr1_copy.i to i8*, !dbg !33
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* %10, i8* bitcast ([10 x i32]* @arr1 to i8*), i64 40, i32 1, i1 false) #0, !dbg !33
  %11 = bitcast [10 x i32]* %arr1_copy.i to i8*, !dbg !34
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* bitcast ([10 x i32]* @arr1 to i8*), i8* %11, i64 40, i32 1, i1 false) #0, !dbg !34
...

We're currently lowering the 2nd memcpy - the one from arr1_copy to arr1.

In SROA_Helper::LowerMemcpy we are here, in the call to ReplaceMemcpy:

        } else if (!isa<CallInst>(Src)) {
          // Resource ptr should not be replaced.
          // Need to make sure src not updated after current memcpy.
          // Check Src only have 1 store now.
          // If Src has more than 1 store but only used once by memcpy, check if
          // the stores dominate the memcpy.
          hlutil::PointerStatus SrcPS(Src, size, /*bLdStOnly*/ false);
          SrcPS.analyze(typeSys, bStructElt);
          if (SrcPS.storedType != hlutil::PointerStatus::StoredType::Stored ||
              (SrcPS.loadedType ==
                   hlutil::PointerStatus::LoadedType::MemcopySrcOnce &&
               allStoresDominateInst(Src, MC, DT))) {
            if (ReplaceMemcpy(V, Src, MC, annotation, typeSys, DL, DT)) { <------------ HERE
              if (V->user_empty())
                return true;
              return LowerMemcpy(V, annotation, typeSys, DL, DT, bAllowReplace);
            }
          }

It's true that Src (arr1_copy) is not updated after the current memcpy, and that all stores of Src dominate the memcpy, as the only store to arr1_copy is made just before it from arr1 (the first memcpy). Here's SrcPS:

SrcPS
{storedType=MemcopyDestOnce (3) loadedType=MemcopySrcOnce (1) StoredOnceValue=0x0000000000000000 <NULL> ...}
    storedType: MemcopyDestOnce (3)
    loadedType: MemcopySrcOnce (1)
    StoredOnceValue: 0x0000000000000000 <NULL>
    memcpySet: {set_={...} vector_={ size=2 } }
    StoringMemcpy: 0x0000028f0f546cb0 68
    LoadingMemcpy: 0x0000028f0f546da0 68
    AccessingFunction: 0x0000028f0f483a18 FunctionVal (2)
    HasMultipleAccessingFunctions: false
    Size: 40
    Ptr: 0x0000028f0f546768 45
    bLoadStoreOnly: false

So now we're in ReplaceMemcpy.

  // If the source of the memcpy (Src) doesn't dominate all users of dest (V),
  // full replacement isn't possible without complicated PHI insertion
  // This will likely replace with ld/st which will be replaced in mem2reg
  if (Instruction *SrcI = dyn_cast<Instruction>(Src))
    if (!DominateAllUsers(SrcI, V, DT))
      return false;

So this is supposed to check whether Src, the AllocaInst of arr1_copy (%arr1_copy.i = alloca [10 x i32], align 4), dominates all users of the dest value V, arr1 (@arr1 = internal global [10 x i32] zeroinitializer, align 4). If it doesn't, we bail on full replacement.

At the very start of the call to DominateAllUsers, there is this check:

  // The Entry Block dominates everything, trivially true
  if (&F->getEntryBlock() == I->getParent())
    return true;

Indeed, the alloca of arr1_copy is in the entry block of the function F:

; Function Attrs: nounwind
define void @main(<4 x float>* noalias) #0 {
entry:
  %agg.result.0 = alloca <4 x float>
  %1 = alloca float
  store float 0.000000e+00, float* %1
  %i.i.1.i = alloca i32, align 4
  %i.i.i = alloca i32, align 4
  %cond.i = alloca i32, align 4
  %arr1_copy.i = alloca [10 x i32], align 4 <------ HERE

I assume the reason for this is that allocas are usually the first thing in a function's entry block, so if it's there, it necessarily dominates all users of the memcpy's destination.

Our use-case is one where there are 2 memcpys, where the first copies from arr1 to arr1_copy, then the second from arr1_copy back to arr1:

  %10 = bitcast [10 x i32]* %arr1_copy.i to i8*, !dbg !33
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* %10, i8* bitcast ([10 x i32]* @arr1 to i8*), i64 40, i32 1, i1 false) #0, !dbg !33
  %11 = bitcast [10 x i32]* %arr1_copy.i to i8*, !dbg !34
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* bitcast ([10 x i32]* @arr1 to i8*), i8* %11, i64 40, i32 1, i1 false) #0, !dbg !34

We're processing the second memcpy's destination, arr1, and sure enough, the src, arr1_copy does indeed dominate all uses of arr1, including that first memcpy. So we return true from DominateAllUsers.

The problem isn't that Src fails to dominate uses of Dest; rather, in replacing Dest (arr1) with Src (arr1_copy), we also rewrite the memcpy that comes before the one we're currently processing. Without the patch, the following:

while.end.i:                                      ; preds = %while.body.i
  %10 = bitcast [10 x i32]* %arr1_copy.i to i8*, !dbg !33
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* %10, i8* bitcast ([10 x i32]* @arr1 to i8*), i64 40, i32 1, i1 false) #0, !dbg !33
  %11 = bitcast [10 x i32]* %arr1_copy.i to i8*, !dbg !34
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* bitcast ([10 x i32]* @arr1 to i8*), i8* %11, i64 40, i32 1, i1 false) #0, !dbg !34

Turns into this, right after the call to ReplaceConstantWithInst(C = arr1, Src = alloca of arr1_copy):

while.end.i:                                      ; preds = %while.body.i
  %10 = bitcast [10 x i32]* %arr1_copy.i to i8*, !dbg !32
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* %10, i8* %12, i64 40, i32 1, i1 false) #0, !dbg !32
  %11 = bitcast [10 x i32]* %arr1_copy.i to i8*, !dbg !33
  %12 = bitcast [10 x i32]* %arr1_copy.i to i8*, !dbg !33
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* %12, i8* %11, i64 40, i32 1, i1 false) #0, !dbg !33

Note that the first call to memcpy changed so that its source went from i8* bitcast ([10 x i32]* @arr1 to i8*) to the new %12, but %12 is defined below it, which is bad.

Now let's see why ReplaceConstantWithInst ends up wanting to replace the source of that first memcpy with %12, which is defined below it. In ReplaceConstantWithInst(C = arr1, Src = alloca of arr1_copy, Builder), we iterate over the users of C, arr1. At one point, one such user is:

U->dump()
i8* bitcast ([10 x i32]* @arr1 to i8*)

This use is the one shown above that is the source of that first memcpy. This is not, itself, an llvm::Instruction, but is, instead, an llvm::UnaryConstantExpr. As such, we fall into the else branch of the user iteration loop:

static bool ReplaceConstantWithInst(Constant *C, Value *V,
                                    IRBuilder<> &Builder) {
  bool replacedAll = true;
  Function *F = Builder.GetInsertBlock()->getParent();
  Instruction *VInst = dyn_cast<Instruction>(V);
  DominatorTree DT;
  if (VInst) {
    DT.recalculate(*F);
  }
  for (auto it = C->user_begin(); it != C->user_end();) {
    User *U = *(it++);
    if (Instruction *I = dyn_cast<Instruction>(U)) {
      if (I->getParent()->getParent() != F)
        continue;
      if (VInst && DT.dominates(VInst, I)) {
        I->replaceUsesOfWith(C, V);
      } else {
        replacedAll = false;
      }
    } else {
      <----------------------------------- HERE
      // Skip unused ConstantExpr.
      if (U->user_empty())
        continue;
      ConstantExpr *CE = cast<ConstantExpr>(U);
      Instruction *Inst = CE->getAsInstruction();
      Builder.Insert(Inst);
      Inst->replaceUsesOfWith(C, V);
      if (!ReplaceConstantWithInst(CE, Inst, Builder)) {
        replacedAll = false;
      }
    }
  }
  C->removeDeadConstantUsers();
  return replacedAll;
}

Since the use, U, is an llvm::UnaryConstantExpr, we use getAsInstruction, which creates a new llvm::Instruction that "implements the same operation as the ConstantExpr", and does not belong to a BB. Right after the call to CE->getAsInstruction() (and after inserting it via the Builder; more on this below), Inst looks like this:

Inst->dump()
  %12 = bitcast [10 x i32]* @arr1 to i8*, !dbg !33

We then call Inst->replaceUsesOfWith(C, V); on it to replace the use of arr1 with arr1_copy, and now it looks like this:

Inst->dump()
  %12 = bitcast [10 x i32]* %arr1_copy.i to i8*, !dbg !33

Before we recurse into ReplaceConstantWithInst, let's talk about where this new instruction got inserted. We end up inserting this instruction at the current position in the current BB, which is while.end.i, right before the 2nd memcpy, because the IRBuilder was constructed at that memcpy in ReplaceMemcpy:

  if (Constant *C = dyn_cast<Constant>(V)) {
    updateLifetimeForReplacement(V, Src);
    if (TyV == TySrc) {
      if (isa<Constant>(Src)) {
        V->replaceAllUsesWith(Src);
      } else {
        // Replace Constant with a non-Constant.
        IRBuilder<> Builder(MC);   <---------------------------- HERE
        if (!ReplaceConstantWithInst(C, Src, Builder)) {
          return false;
        }
      }
    } else {

So this new instruction is inserted after the first memcpy, which is about to be rewritten to use it. We now recurse, calling ReplaceConstantWithInst(C = CE, V = Inst, Builder), so that we can replace the llvm::UnaryConstantExpr, that is bitcast [10 x i32]* @arr1 to i8*, with our new instruction, %12 = bitcast [10 x i32]* %arr1_copy.i to i8*. This is what ultimately leads to the first memcpy's source being changed to %12, which is bad.

So the bug is that we should not replace uses of a value C in I with another value V if V does not dominate I; put the other way around, we should only replace uses of C in I with V if V dominates I. This is exactly what this patch does:

static bool ReplaceConstantWithInst(Constant *C, Value *V,
                                    IRBuilder<> &Builder) {
  bool replacedAll = true;
  Function *F = Builder.GetInsertBlock()->getParent();
  Instruction *VInst = dyn_cast<Instruction>(V);
  DominatorTree DT;
  if (VInst) {
    DT.recalculate(*F);
  }
  for (auto it = C->user_begin(); it != C->user_end();) {
    User *U = *(it++);
    if (Instruction *I = dyn_cast<Instruction>(U)) {
      if (I->getParent()->getParent() != F)
        continue;
      if (VInst && DT.dominates(VInst, I)) { <------------ HERE
        I->replaceUsesOfWith(C, V);
      } else {
        replacedAll = false;
      }
    } else {
      // Skip unused ConstantExpr.
      if (U->user_empty())
        continue;
      ConstantExpr *CE = cast<ConstantExpr>(U);
      Instruction *Inst = CE->getAsInstruction();
      Builder.Insert(Inst);
      Inst->replaceUsesOfWith(C, V);
      if (!ReplaceConstantWithInst(CE, Inst, Builder)) {
        replacedAll = false;
      }
    }
  }
  C->removeDeadConstantUsers();
  return replacedAll;
}
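To see the effect of the guard in isolation, here is a standalone single-block model of the guarded replacement. It is a sketch under stated assumptions, not the DXC implementation: `SketchInst`, `replaceUsesIfDominated`, and `blockAfterInsertingCast` are hypothetical names, operands are plain strings, and dominance is reduced to "defined at a smaller index in the same block".

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy stand-in for an instruction (NOT llvm::Instruction).
struct SketchInst {
  std::string result;                 // empty if nothing is defined
  std::vector<std::string> operands;
};

// Replace operand `from` with `to`, but only in instructions that come after
// the definition of `to` (at index toDefIndex). Returns false if any use had
// to be left alone, mirroring the replacedAll flag in the patch.
bool replaceUsesIfDominated(std::vector<SketchInst> &block,
                            const std::string &from, const std::string &to,
                            std::size_t toDefIndex) {
  assert(toDefIndex < block.size());
  bool replacedAll = true;
  for (std::size_t i = 0; i < block.size(); ++i) {
    for (std::string &op : block[i].operands) {
      if (op != from)
        continue;
      if (toDefIndex < i)
        op = to;  // the definition precedes this use: safe
      else
        replacedAll = false;  // the use precedes the definition: skip it
    }
  }
  return replacedAll;
}

// while.end.i right after %12 has been inserted before the second memcpy.
std::vector<SketchInst> blockAfterInsertingCast() {
  return {
      {"%10", {"%arr1_copy.i"}},
      {"", {"%10", "bitcast(@arr1)"}},  // first memcpy
      {"%11", {"%arr1_copy.i"}},
      {"%12", {"%arr1_copy.i"}},        // new instruction
      {"", {"bitcast(@arr1)", "%11"}},  // second memcpy
  };
}
```

Running `replaceUsesIfDominated(block, "bitcast(@arr1)", "%12", 3)` on that block rewrites only the second memcpy and returns false, so the first memcpy keeps reading the constant expression.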

Now maybe there's another way to solve this by making sure to insert our new instruction above all uses? Maybe instead of creating a builder around the 2nd memcpy, we should create it around the alloca of the Src arr1_copy? I'm not sure, as there are three places in ReplaceMemcpy that create builders, and they all do so around the current memcpy instruction.

Hopefully this clarifies what's going on with this bug, and why our patch works this way. Basically, while replacing constants, we may inject new instructions just above the current memcpy, and must then be careful not to use them in instructions that come before that point.

@amaiorano amaiorano force-pushed the fix-SROA-replace-constant-with-non-constant branch from dd202d2 to 407ce70 Compare April 25, 2024 18:08
@amaiorano
Collaborator Author

@llvm-beanz addressed all your comments and 👍 each one.

@amaiorano amaiorano merged commit 773b012 into microsoft:main May 6, 2024
12 checks passed