Skip to content

Commit

Permalink
Make function sample_inputs return an initialized tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
sogartar committed Dec 10, 2024
1 parent b5e4477 commit 6b7abec
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions sharktank/sharktank/models/clip/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,14 +488,20 @@ def set_input_embeddings(self, value):
self.text_model.embeddings.token_embedding = value

def sample_inputs(self, batch_size: int) -> OrderedDict[str, AnyTensor]:
input_ids = (
torch.arange(
start=0,
end=batch_size * self.config.max_position_embeddings,
dtype=torch.long,
)
% self.config.vocab_size
)
input_ids = input_ids.reshape([batch_size, self.config.max_position_embeddings])
return OrderedDict(
[
(
"input_ids",
torch.empty(
size=[batch_size, self.config.max_position_embeddings],
dtype=torch.long,
),
input_ids,
)
]
)
Expand Down

0 comments on commit 6b7abec

Please sign in to comment.