Proposal: Support a new data access pattern. #138
-
Thank you @colawithsauce for your detailed write-up. I will get back to you as soon as possible. :)
-
Would you mind giving another explanation of what … means? What if …? Also, if we plug in …, what does this mean? It would be great if you could give a concrete example for a 2D tensor with explicit shapes.
-
I also have another question. All of the diagrams seem to assume that we're first dividing by …
-
Aside from the above questions that I have, one major limitation that triton-shared has at the moment that will make supporting this pattern hard is that we don't support generating multiple loads from a single … We cannot describe this memory access as a memref load with a single offset plus static strides and shapes. This pattern would require 4 loads with offsets {0, 1, 2, 3}.
-
@nhat-nguyen Let me explain.

A simple example

We can construct a simple example:

```python
tmp1 = tl.load(in_ptr + ((xindex // 1) % 4) * 1
             + ((xindex // 4) % 16) * 4)
```

This example shows a linear matrix load (after the pointer arithmetic, the index is still …).

Here is another example:

```python
tmp1 = tl.load(in_ptr + ((xindex // 64) % 1) * 1)
```

This example shows a repeat, and …

```python
x0 = xindex % 64
x2 = (xindex // 2048)
x3 = xindex
tmp0 = tl.load(in_ptr0 + (x0 + (64*x2)), None, eviction_policy='evict_last')
```

This example shows the condition that …

The latter example can be structured by adding a phantom dimension; the load can then be expressed by … We can do the same for the former example: we add a phantom dimension_0, and the original dimension_0 becomes dimension_1; a small numeric check of this is sketched below.
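A minimal NumPy sketch of the phantom-dimension idea, assuming a 4096-element `xindex` range purely for the check (the exact range is not taken from the thread and does not matter): the extra dimension gets stride 0, so the computed offsets stay unchanged while every term now has the `((xindex // num) % size) * stride` shape.

```python
import numpy as np

xindex = np.arange(4096)   # assumed range; any multiple of 2048 behaves the same
x0 = xindex % 64
x2 = xindex // 2048

# offsets exactly as the inductor kernel computes them: in_ptr0 + (x0 + 64*x2)
orig = x0 + 64 * x2

# the same offsets with a phantom dimension of size 2048 // 64 = 32 and stride 0
structured = (((xindex // 1) % 64) * 1        # dim 0:   num=1,    size=64, stride=1
              + ((xindex // 64) % 32) * 0     # phantom: num=64,   size=32, stride=0
              + ((xindex // 2048) % 2) * 64)  # dim 2:   num=2048, size=2,  stride=64

assert np.array_equal(orig, structured)
```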
A complex example

```python
import torch

def fn(x, y):
    return torch.permute(x, (0, 2, 1, 3)) + y

fnc = torch.compile(fn)
bsz = 4
num_head = 32
seq_len = 2048
head_dim = 128
x = torch.randn([bsz, num_head, seq_len, head_dim]).cuda()
y = torch.randn([bsz, seq_len, num_head, head_dim]).cuda()
z = fnc(x, y)
print(z[0,0,0,0])
```

And the triton DSL it generates is:

```python
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 33554432
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x4 = xindex
    x0 = xindex % 128
    x1 = (xindex // 128) % 2048
    x2 = (xindex // 262144) % 32
    x3 = (xindex // 8388608)
    tmp0 = tl.load(in_ptr0 + (x4), None)
    tmp1 = tl.load(in_ptr1 + (x0 + (128*x2) + (4096*x1) + (8388608*x3)), None)
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + (x4), tmp2, None)
```

We expand the operands of the load operation (substituting x0, x1, x2, x3):

```python
# xindex = [0, 1, 2, ....]
tmp1 = tl.load(in_ptr1 + ((xindex // 1) % 128) * 1
             + ((xindex // (128*2048)) % 32) * 128
             + ((xindex // 128) % 2048) * 4096
             + ((xindex // 8388608)) * 8388608)
```

It seems we have violated the rule that `num` must increase across dimensions. However, if we reorder the terms:

```python
# xindex = [0, 1, 2, ....]
tmp1 = tl.load(in_ptr1 + ((xindex // 1) % 128) * 1
             + ((xindex // 128) % 2048) * 4096
             + ((xindex // (128*2048)) % 32) * 128
             + ((xindex // 8388608)) * 8388608)
```

Now it follows the rule, and we can represent this in the form of …

The …
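To sanity-check the reordered form, here is a scaled-down NumPy sketch (the shapes are shrunk from the real [4, 2048, 32, 128] case so it runs instantly; the structure of the expression is the same). The 1D gather is exactly the flattened, permuted strided view of the `in_ptr1` buffer.

```python
import numpy as np

# scaled-down shapes with the same structure as [bsz, seq_len, num_head, head_dim]
bsz, num_head, seq_len, head_dim = 2, 3, 4, 5
y = np.arange(bsz * seq_len * num_head * head_dim)   # stands in for in_ptr1's flat buffer
xindex = np.arange(y.size)

# element strides of y's contiguous [bsz, seq_len, num_head, head_dim] layout
s_head, s_nh, s_sl, s_b = 1, head_dim, num_head * head_dim, seq_len * num_head * head_dim

# the reordered address expression, term by term (num grows monotonically)
offsets = (((xindex // 1) % head_dim) * s_head
           + ((xindex // head_dim) % seq_len) * s_sl
           + ((xindex // (head_dim * seq_len)) % num_head) * s_nh
           + (xindex // (head_dim * seq_len * num_head)) * s_b)

# the 1D load is the flattened view with sizes [bsz, num_head, seq_len, head_dim]
# and strides [s_b, s_nh, s_sl, s_head], i.e. the permuted view of y
view = np.lib.stride_tricks.as_strided(
    y,
    shape=(bsz, num_head, seq_len, head_dim),
    strides=tuple(s * y.itemsize for s in (s_b, s_nh, s_sl, s_head)))
assert np.array_equal(y[offsets], view.reshape(-1))
```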
-
@colawithsauce Thanks for the replies. It will take me some time to digest it all. I started another thread here to ask a different question. In your attachments, specifically the TestBroadcast_TwoAxis file, there are two triton kernels generated by torch inductor. Both of the kernels have …
-
@colawithsauce Hey, sorry for not getting back to you last week. I have not had time to fully digest your formulas, but I think I'm able to understand it at a high level. Let me know if the following is correct. So, the gist of the problem here is that even though the triton IR is loading a 1d tensor, we are able to describe this 1d tensor using a combination of sizes, strides, and offsets from your formula, which would resemble a 2d tensor. Now, from all of your pytorch code, it looks like all of these are pretty basic operations (implicit broadcast, reduce, ...). So we are definitely interested in having support for these cases. You mentioned that your group is working on an implementation already; that is great! We would appreciate the contribution here to make triton-shared more complete and robust. One technical suggestion I have is, because all of the code in …
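For what it's worth, a rough sketch of that high-level description (a hypothetical helper, not triton-shared code): given the `(num, size, stride)` triples read off the permute example's load, sorting by `num` and checking the proposal's rule yields the sizes and strides of the logical multi-dimensional view.

```python
def structure_from_terms(terms):
    """terms: list of (num, size, stride) read off an address expression of the form
        sum_i ((xindex // num_i) % size_i) * stride_i
    Returns (sizes, strides) of the logical view, outermost dimension first."""
    terms = sorted(terms, key=lambda t: t[0])
    for (num_prev, size_prev, _), (num_cur, _, _) in zip(terms, terms[1:]):
        # the proposal's ordering rule; a phantom (stride-0) dimension would be
        # needed whenever num_cur > num_prev * size_prev
        assert num_cur >= num_prev * size_prev
        assert num_cur % (num_prev * size_prev) == 0
    sizes = [size for _, size, _ in reversed(terms)]
    strides = [stride for _, _, stride in reversed(terms)]
    return sizes, strides

# triples from the permute kernel's tmp1 load; the outermost size 4 is xnumel // 8388608
print(structure_from_terms(
    [(1, 128, 1), (128, 2048, 4096), (262144, 32, 128), (8388608, 4, 8388608)]))
# -> ([4, 32, 2048, 128], [8388608, 128, 4096, 1])
```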
-
@colawithsauce torch-inductor is getting some improvements in their codegen and won't generate as many div and mod operations as before. I haven't tried it out yet but thought you might be interested in: pytorch/pytorch#125077. There's a related discussion over at #16 too.
-
Proposal: Support a new data access pattern.
The pattern is a pointer expression of the form `a_ptr + ((xindex // num_0) % size_0) * stride_0 + ((xindex // num_1) % size_1) * stride_1 + ...`.

In `PtrAnalysis`, `addState` does not support the situation where both operands carry a modulo. We think the shapes of the tensor in this case can still be determined statically. When do we end up adding two modulo states? In many cases it happens when the input data is a high-dimensional tensor and the user accesses it with the pattern above.
For each dimension, the pointer arithmetic follows the pattern `((xindex // num) % size) * stride`, where `xindex` must be an arange array (for example `[0, 1, 2, 3, 4, ...]`), `size` is the size of this dimension, `stride` is how many elements must be skipped to fetch the next element in this dimension, and `num` removes the lower-dimensional information, with `num{i} >= num{i-1} * size{i-1}` and `num{i} == size{i-1} * k, k == 0, 1, 2, 3, ...`; `a_ptr` is a pointer or a tensor of pointers (most of the time it is simply a pointer, implicitly broadcast).

If an access pattern follows the rules given above, we call it 'irregular' (just a name for convenience). Here is an example and a counter-example of this pattern; a concrete 2D sketch follows them:
Above is an example of the 'irregular' pattern: every permutation exists exactly once in this case, and we can simply structure this access with a two-dimensional matrix, as we will show later in this chapter.
Above is a counter-example: the picture shows an addState of two modulo states where the second modulo is essentially random, so it is hard to structure.
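A concrete 2D sketch with explicit shapes (a NumPy illustration; the (4, 6) shape is chosen arbitrarily and is not from the attachments): reading the transpose of a row-major (4, 6) tensor through two terms of this form.

```python
import numpy as np

A = np.arange(24).reshape(4, 6)      # row-major 2D tensor, element strides (6, 1)
buf = A.reshape(-1)                  # the flat buffer the pointer arithmetic indexes into
xindex = np.arange(A.size)

# two terms of the form ((xindex // num) % size) * stride:
#   (num=1, size=4, stride=6)  and  (num=4, size=6, stride=1)
offsets = ((xindex // 1) % 4) * 6 + ((xindex // 4) % 6) * 1

# although the load itself is 1D, (sizes, strides) = ([6, 4], [1, 6]) describe it
# as a 2D view of the buffer: here, simply A transposed
assert np.array_equal(buf[offsets], A.T.reshape(-1))
```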
The ttir generated by this data access pattern might look like this:

…

There are two modulo (`arith.remsi`) operations in the ttir, which currently trigger an assertion failure. The data access pattern can be represented by:

…

However, we can change our view of this transformation from the picture above to the following picture.
In this picture, the `ptr` of the load operation is treated logically as a 2D tensor (although it is 1D physically). The second operands of these two `arith.remsi` operations (`%2`, `%3`) indicate the size of each dimension.

Is this pattern common in real-world programming?
We think it is at least common in `torch.compile`-generated triton code (and at least for our use case). We tested on some pytorch code [1] and found that the operands of the load operations in the triton DSL generated by `torch.compile` are highly structured.

Here are some examples:
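One representative inductor-generated load (the same broadcast case quoted in the replies above; the full set of examples is in the attachments [1]):

```python
# the index arithmetic on xindex is pure integer div/mod/mul, and the load
# offset combines several such terms
x0 = xindex % 64
x2 = (xindex // 2048)
tmp0 = tl.load(in_ptr0 + (x0 + (64*x2)), None, eviction_policy='evict_last')
```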
This seems like a counter-example; however, notice the `load` operations: the indices are not used together in a single `load`, but in separate `load` operations instead.

Conclusion
In conclusion, we propose a pattern that is common in real-world programming and is not yet supported by triton-shared. We define this pattern and its meaning, and we give some code examples for it.

In our opinion, this pattern is common and implementable. We are wondering whether your team has explored this idea and found it unimplementable, whether there is something we have not noticed, or whether anything in our post is unclear. We would love to know your opinion on this idea. We are now working on this method and trying to write some code to add support for this pattern.
Thanks for your attention!
attachments
[1]: torch codes
TestBroadcast_TwoAxis.pdf
TestReduce.pdf
TestPermute.pdf
TestBroadcast.pdf