[Triton] Add tl.gather with a naive codegen implementation #5262

Open
wants to merge 15 commits into main
Conversation

@Mogball (Collaborator) commented Nov 26, 2024

This PR adds a tl.gather builtin that implements a local gather along a single axis, with semantics matching torch.gather. tl.gather generates a tt.gather op, which is piped through the compiler mostly untouched at the moment, since the codegen is very naive.
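
For illustration, here is a minimal kernel sketch of how the builtin could be called from the Python frontend. The kernel, pointer arguments, and block sizes are hypothetical, and the tl.gather(src, index, axis) call shape is assumed from the frontend checks quoted later in this thread:

import triton
import triton.language as tl

@triton.jit
def gather_kernel(src_ptr, idx_ptr, out_ptr, M: tl.constexpr, N: tl.constexpr):
    # Load an (M, N) source block and an (M, N) integer index block.
    offs = tl.arange(0, M)[:, None] * N + tl.arange(0, N)[None, :]
    src = tl.load(src_ptr + offs)
    idx = tl.load(idx_ptr + offs)
    # Gather along axis 0: out[i, j] = src[idx[i, j], j], as with torch.gather.
    out = tl.gather(src, idx, 0)
    tl.store(out_ptr + offs, out)

On the host, this would be launched on same-shape CUDA tensors (with power-of-two M and N and an integer index tensor), e.g. gather_kernel[(1,)](src, idx, out, M=src.shape[0], N=src.shape[1]).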

The tt.gather op is implemented by writing the source tensor into shared memory and then gathering out of shared memory, so it requires scratch space to be allocated. In a follow-up, I will implement an optimized layout rule for the op that ensures the gather axis fits into a single warp, allowing the gather to be implemented using warp shuffles.

There are other avenues for optimization as well: a tt.gather(tt.load) where the load has only one use can be lowered into a DMA from global memory to shared memory, followed by a gather directly out of shared memory.

@Mogball changed the title from "Add tl.gather with a naive codegen implementation" to "[Triton] Add tl.gather with a naive codegen implementation" on Nov 26, 2024

// CHECK-LABEL: @gather_op
tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32>) -> tensor<512x4xf32> {
// CHECK-NEXT: %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512x4xf32>
%0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512x4xf32>
Collaborator:

Just starting to look at this, but shouldn't the index tensor be a 1D tensor if we only index along one dimension?

Collaborator (author):

Gather along a single axis d means that each column along dim d in the output is composed of elements from the corresponding column in the source tensor, e.g. out[i,j] = src[idx[i,j],j] for axis=0.
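
To restate that with a runnable reference (plain PyTorch, not part of the PR; shapes chosen arbitrarily):

import torch

src = torch.arange(12.0).reshape(4, 3)   # (4, 3) source tensor
idx = torch.tensor([[3, 0, 1],
                    [2, 2, 0]])          # (2, 3) index tensor

out = torch.gather(src, 0, idx)          # out[i, j] = src[idx[i, j], j]

# Explicit loop with the same semantics.
ref = torch.empty_like(out)
for i in range(idx.shape[0]):
    for j in range(idx.shape[1]):
        ref[i, j] = src[idx[i, j], j]
assert torch.equal(out, ref)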

Collaborator:

Ah I see, OK, yeah, that looks like what PyTorch does. I'll let @apgoucher confirm that it is what he wants, but it makes sense to me.

@ThomasRaoux (Collaborator) left a comment:

Looks great! It's probably worth having @apgoucher check that the semantics are what he had in mind, but other than that it looks good to go.

}

// Synchronize the whole CTA.
// TODO(jeff): Should we teach Membar that gather synchronizes?
Contributor:

Membar cannot insert any "internal" synchronization barriers

Collaborator (author):

This isn't super important, but what I mean is that we can teach Membar that certain ops implicitly act as a synchronization, which causes the analysis to reset pending memory transactions up to those ops.

(Two resolved review threads on lib/Dialect/Triton/IR/Ops.cpp.)
Comment on lines +1698 to +1699
assert index.type.shape[d] <= src.type.shape[
d], f"index dim {axis} cannot be greater than the corresponding source dim"
Contributor:

This is a bit strange. You're allowing the gather op to implicitly slice the src tensor to match the index tensor? If we're going to allow this I think it should be its own operation.

Collaborator:

I thought that's what we wanted?

I guess broadcasting could be supported

@peterbell10 (Contributor), Nov 27, 2024:

> I thought that's what we wanted?

I 100% expect that the gather axis can be a different shape, but it's not normal for the other dimensions to be allowed to be a different shape. I find it very surprising coming from numpy/pytorch semantics.

Also, I don't think this behavior is compatible with broadcasting, as it would be ambiguous: if the index has a dimension of size 1, we can't tell whether it's supposed to be a slice or whether it should be broadcast.
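
A small illustration of that ambiguity in plain PyTorch (hypothetical shapes, not code from the PR):

import torch

src = torch.randn(4, 3)
idx = torch.zeros(2, 1, dtype=torch.long)

# Reading the size-1 index dim as an implicit slice: gather only from column 0.
sliced = torch.gather(src[:, :1], 0, idx)            # shape (2, 1)

# Reading it as a broadcast: expand the indices across all 3 columns first.
broadcast = torch.gather(src, 0, idx.expand(2, 3))   # shape (2, 3)

# The two readings give different shapes (and values), so a size-1 index dim
# cannot unambiguously mean both.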

@peterbell10 (Contributor), Nov 27, 2024:

I suppose there is an advantage to fusing the gather op with the slice, since in general a standalone slice op could have to go through shared memory and transfer redundant data. Perhaps this could be a pattern-matched lowering instead of implicit behavior of tt.gather, though?

Contributor:

Huh, I guess torch.gather actually does have this behavior. Now I'm not sure what to think, haha. It feels wrong that there are huge chunks of the input tensor that get completely ignored; it feels to me like two operations.

I'm also a bit confused about what the use case for this would be, as there's no way to create a slice that doesn't start at 0.

Collaborator (author):

I don't really have an opinion here on the semantics of tl.gather, so let me know what you two prefer!

Collaborator:

Actually, I read too fast. I agree that I would expect the other dimensions to match the input dimensions. Unless we have a specific use for it, I think we should restrict the dimensions to match.

assert index.dtype.is_int(), "index must be an integer tensor"

rank = len(src.type.shape)
assert len(index.type.shape) == rank, "source and index tensors must have the same rank"
Contributor:

Would be nice to support broadcasting.

Collaborator (author):

Can you elaborate on what the broadcasting semantics would be?

(Resolved review thread on python/test/unit/language/test_core.py, marked outdated.)
@peterbell10 self-requested a review, November 27, 2024 01:23
@peterbell10 (Contributor) left a comment:

Sorry, misclicked on the approval.
