Candle equivalent to masked_scatter? #2208
-
Hi, I'm currently porting some PyTorch code over to Candle, and ran into a stumbling block. The source makes use of:

```python
final_embedding = final_embedding.masked_scatter(
    image_mask.unsqueeze(-1).expand_as(final_embedding), scaled_image_features
)
```

shape of final_embedding:
shape of image_mask:
shape of scaled features:

I would also be happy to route around this issue using a similar equivalent. Example of `masked_scatter`:

```python
self = torch.tensor([1, 2, 3, 4, 5, 6, 7])
mask = torch.tensor([1, 0, 0, 1, 0, 1, 1], dtype=torch.bool)
src = torch.tensor([100, 200, 300, 400])
self.masked_scatter(mask, src)
# => tensor([100, 2, 3, 200, 5, 300, 400])
```
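For clarity, the semantics being asked for are: fill the masked positions of the destination, in order, with consecutive elements of `src`. A minimal pure-Python sketch of that (the function name and list-based representation are illustrative, not any library's API):

```python
def masked_scatter(dest, mask, src):
    # Walk dest and mask together; wherever mask is truthy, consume
    # the next element of src, otherwise keep the original value.
    it = iter(src)
    return [next(it) if m else d for d, m in zip(dest, mask)]

dest = [1, 2, 3, 4, 5, 6, 7]
mask = [1, 0, 0, 1, 0, 1, 1]
src = [100, 200, 300, 400]
print(masked_scatter(dest, mask, src))
# [100, 2, 3, 200, 5, 300, 400]
```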
-
The CUDA kernel looks to be this:
The CPU kernel looks to be this:
-
If the mask has a "regular" shape (e.g. the image embeddings are before or after the text ones in some dimensions), maybe you can just use a `Tensor::cat` instead? Otherwise I would suggest processing this manually on the CPU to start with, to make sure that the model works well.
Longer term, maybe add a custom op, or maybe simpler: use `index_select` where the indexes are computed based on the mask itself, + multiply by a 0/1 mask to ensure that only the relevant bits are kept.
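To make the gather-based suggestion concrete, here is a minimal pure-Python sketch (plain lists stand in for tensors; the function name is illustrative, not Candle API). The index for each destination slot is a prefix sum of the mask, which in a tensor library could be built with a cumulative sum, followed by an `index_select`-style gather and a 0/1-mask blend:

```python
def masked_scatter_via_gather(dest, mask, src):
    # gather_idx[i] = number of True entries in mask[:i] (a prefix sum),
    # so the k-th masked slot points at src[k]. Clamp to a valid index
    # so the gather stays in bounds even for unmasked slots past the end.
    gather_idx = []
    count = 0
    for m in mask:
        gather_idx.append(min(count, len(src) - 1))
        count += m
    gathered = [src[j] for j in gather_idx]  # the index_select analogue
    # Blend: keep gathered values where mask is 1, original dest elsewhere
    # (the "multiply by a 0/1 mask" step).
    return [g * m + d * (1 - m) for g, m, d in zip(gathered, mask, dest)]

dest = [1, 2, 3, 4, 5, 6, 7]
mask = [1, 0, 0, 1, 0, 1, 1]
src = [100, 200, 300, 400]
print(masked_scatter_via_gather(dest, mask, src))
# [100, 2, 3, 200, 5, 300, 400]
```

This reproduces the `masked_scatter` example from the question while using only gather, cumulative-count, and elementwise multiply/add, all of which have Candle counterparts.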