Skip to content

Commit

Permalink
Put back in EOS=2 case as it actually gets hit
Browse files Browse the repository at this point in the history
  • Loading branch information
sogartar committed Dec 9, 2024
1 parent cb0b703 commit 0d30d62
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions sharktank/sharktank/models/clip/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,14 +427,28 @@ def forward(
)
last_hidden_state = self.final_layer_norm(last_hidden_state)

# We don't support this variant.
assert self.eos_token_id != 2

pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0]),
# We need to get the first position of the `eos_token_id` value (`pad_token_ids` might be equal to `eos_token_id`)
(input_ids == self.eos_token_id).int().argmax(dim=-1),
]
if self.eos_token_id == 2:
# The `eos_token_id` was incorrect before PR
# https://github.com/huggingface/transformers/pull/24773
# Let's keep what has been done here.
# A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
# ------------------------------------------------------------
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
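# NOTE(review): this comment describes casting `input_ids` to torch.int for onnx compatibility (argmax doesn't support int64 inputs with opset 14), but the indexing below applies no cast — confirm this is intended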
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0]),
input_ids.argmax(dim=-1),
]
else:
# The config gets updated `eos_token_id` from PR
# https://github.com/huggingface/transformers/pull/24773
# (so the use of extra new tokens is possible)
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0]),
# We need to get the first position of the `eos_token_id` value (`pad_token_ids` might be equal to `eos_token_id`)
(input_ids == self.eos_token_id).int().argmax(dim=-1),
]

if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
Expand Down

0 comments on commit 0d30d62

Please sign in to comment.