[Triton] Add tl.gather with a naive codegen implementation #5262
base: main
Conversation
// CHECK-LABEL: @gather_op
tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32>) -> tensor<512x4xf32> {
  // CHECK-NEXT: %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512x4xf32>
  %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512x4xf32>
just starting to look at this but shouldn't the index tensor be a 1D tensor if we index only along 1 dimension?
Gather along a single axis d means that each column along dim d in the output is comprised of elements from the corresponding column in the source tensor, e.g. out[i, j] = src[idx[i, j], j] for axis=0.
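For reference, a minimal NumPy sketch of that indexing rule (an illustration only, not the Triton implementation):

import numpy as np

def reference_gather(src, idx, axis):
    # The output has the shape of idx. Along `axis` the coordinate is replaced
    # by the value stored in idx; all other coordinates pass through, so for
    # axis=0 this computes out[i, j] = src[idx[i, j], j].
    out = np.empty(idx.shape, dtype=src.dtype)
    for pos in np.ndindex(*idx.shape):
        src_pos = list(pos)
        src_pos[axis] = idx[pos]
        out[pos] = src[tuple(src_pos)]
    return out

src = np.arange(12).reshape(4, 3)
idx = np.array([[0, 2, 1], [3, 3, 0]])
print(reference_gather(src, idx, axis=0))  # out[i, j] == src[idx[i, j], j]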
ah I see, ok yeah looks like that's what pytorch does. I'll let @apgoucher confirm that it is what he wants, but it makes sense to me.
Looks great! It's probably worth having @apgoucher check that the semantics are what he had in mind, but other than that this looks good to go.
}

// Synchronize the whole CTA.
// TODO(jeff): Should we teach Membar that gather synchronizes?
Membar cannot insert any "internal" synchronization barriers
This isn't super important, but what I mean is we can teach membar that certain ops implicitly act as a synchronization, which causes the analysis to reset pending memory transactions up to those ops.
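To illustrate the idea only, here is a toy model of such an analysis. This is not Triton's actual Membar code; the op representation, the conflicts helper, and the acts_as_barrier hook are all made up for the sketch.

# Toy membar-style analysis: scan ops in program order, keep the set of
# shared-memory accesses not yet covered by a barrier, and record a barrier
# whenever a new op conflicts with a pending access. Ops that are known to
# synchronize the CTA internally reset the pending set instead.
def run_membar(ops, acts_as_barrier):
    pending = []   # shared-memory accesses not yet synchronized
    barriers = []  # ops before which a barrier must be inserted
    for op in ops:
        if acts_as_barrier(op):
            pending.clear()
            continue
        if any(conflicts(op, prev) for prev in pending):
            barriers.append(op["name"])
            pending.clear()
        if op.get("smem") is not None:
            pending.append(op)
    return barriers

def conflicts(op, prev):
    # Conflict iff both ops touch the same shared-memory buffer and at least
    # one of the two accesses is a write.
    return (op.get("smem") is not None and op["smem"] == prev["smem"]
            and "write" in (op.get("kind"), prev.get("kind")))

ops = [
    {"name": "local_store", "smem": "buf0", "kind": "write"},
    {"name": "tt.gather",   "smem": "buf0", "kind": "read"},
]
# By default a barrier is inserted before tt.gather.
print(run_membar(ops, acts_as_barrier=lambda op: False))
# If the analysis is taught that tt.gather synchronizes internally, it isn't.
print(run_membar(ops, acts_as_barrier=lambda op: op["name"] == "tt.gather"))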
assert index.type.shape[d] <= src.type.shape[
    d], f"index dim {axis} cannot be greater than the corresponding source dim"
This is a bit strange. You're allowing the gather op to implicitly slice the src tensor to match the index tensor? If we're going to allow this I think it should be its own operation.
I thought that's what we wanted?
I guess broadcasting could be supported
I thought that's what we wanted?
I 100% expect that the gather axis can be a different shape, but it's not normal for the other dimensions to be allowed to be a different shape. I find it very surprising coming from numpy/pytorch semantics.
Also I don't think this behavior is compatible with broadcasting as it would be ambiguous. If the index has a dimension of size 1 we can't tell if it's supposed to be a slice, or if it should be broadcasted.
I suppose there is an advantage to fusing the gather op with the slice, as in general I think a slice op could have to go through shared memory to transfer redundant data. Perhaps this could be a pattern-matched lowering instead of implicit behavior of tt.gather, though?
Huh, I guess torch.gather actually does have this behavior. Now I'm not sure what to think haha. It feels wrong that there are huge chunks of the input tensor that get completely ignored, and feels to me like two operations.
I'm also a bit confused what the use case for this would be, as there's no way to create a slice that doesn't start at 0.
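For concreteness, a small PyTorch example of the behavior being discussed: the index is smaller than the input along a non-gather dimension, so part of the input is simply never read.

import torch

src = torch.arange(12.0).reshape(4, 3)  # 4x3 source
idx = torch.tensor([[0, 2], [3, 1]])    # 2x2 index, narrower than src

# torch.gather only requires index.size(d) <= input.size(d) for d != dim,
# and the output takes the shape of the index tensor. Here
# out[i, j] = src[idx[i, j], j], so src[:, 2] is never read.
out = torch.gather(src, 0, idx)
print(out)  # tensor([[0., 7.], [9., 4.]])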
I don't really have an opinion here on the semantics of tl.gather, so let me know what you two prefer!
Actually, I read too fast. I agree that I would expect the other dimensions to match the input dimensions. Unless we have a specific use for it, I think we should restrict the dimensions to match.
assert index.dtype.is_int(), "index must be an integer tensor"

rank = len(src.type.shape)
assert len(index.type.shape) == rank, "source and index tensors must have the same rank"
Would be nice to support broadcasting.
Can you elaborate on what the broadcasting semantics would be?
Sorry, misclicked on the approval.
This PR adds a tl.gather builtin that implements a local gather along a single axis, with semantics matching torch.gather. tl.gather generates a tt.gather op, which is piped through the compiler mostly untouched at the moment, since the codegen is very naive.

The tt.gather op is implemented by writing the source tensor into shared memory and then performing a gather out of shared memory, thus it requires scratch space to be allocated. In a follow-up, I will implement an optimized layout rule for the op that ensures the gather axis fits into a single warp, allowing the gather to be implemented using warp shuffles.

There are other avenues for optimization as well: tt.gather(tt.load), where the load only has one use, can be lowered into a DMA from global memory to shared memory, with the gather then reading directly from shared memory.
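As a usage illustration, here is a minimal sketch of calling the builtin from a kernel. This is not taken from the PR; the tl.gather(src, index, axis) call signature and the block shapes are assumptions based on the description above.

import torch
import triton
import triton.language as tl


@triton.jit
def gather_rows_kernel(src_ptr, idx_ptr, out_ptr,
                       M: tl.constexpr, N: tl.constexpr):
    # Load an MxN source block and an MxN index block, gather along axis 0
    # (out[i, j] = src[idx[i, j], j]), then store the result.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs = offs_m[:, None] * N + offs_n[None, :]
    src = tl.load(src_ptr + offs)
    idx = tl.load(idx_ptr + offs)
    out = tl.gather(src, idx, axis=0)
    tl.store(out_ptr + offs, out)


src = torch.randn(16, 8, device="cuda")
idx = torch.randint(0, 16, (16, 8), device="cuda", dtype=torch.int32)
out = torch.empty_like(src)
gather_rows_kernel[(1,)](src, idx, out, M=16, N=8)
# With torch.gather semantics, the result should match the eager reference.
assert torch.equal(out, torch.gather(src, 0, idx.long()))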