Commit 8c18489

Fixed the llama model (#769)
* fixed input_pos is None
* add test
* update the test
* update the docstring
* update the docstring

Signed-off-by: yiliu30 <[email protected]>
1 parent e15e509 commit 8c18489

File tree

test/test_ao_models.py
torchao/_models/llama/model.py

2 files changed (+48 −8 lines)

test/test_ao_models.py

+26

@@ -0,0 +1,26 @@
+import pytest
+import torch
+from torchao._models.llama.model import Transformer
+
+_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+
+
+def init_model(name="stories15M", device="cpu", precision=torch.bfloat16):
+    model = Transformer.from_name(name)
+    model.to(device=device, dtype=precision)
+    return model.eval()
+
+
+@pytest.mark.parametrize("device", _AVAILABLE_DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 4])
+@pytest.mark.parametrize("is_training", [True, False])
+def test_ao_llama_model_inference_mode(device, batch_size, is_training):
+    random_model = init_model(device=device)
+    seq_len = 16
+    input_ids = torch.randint(0, 1024, (batch_size, seq_len)).to(device)
+    input_pos = None if is_training else torch.arange(seq_len).to(device)
+    with torch.device(device):
+        random_model.setup_caches(max_batch_size=batch_size, max_seq_length=seq_len, training=is_training)
+    for i in range(3):
+        out = random_model(input_ids, input_pos)
+        assert out is not None, "model failed to run"
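
For context, a minimal sketch of the calling convention this test exercises, distilled from the test body above (illustrative only; device placement and bfloat16 precision are omitted for brevity):

    import torch
    from torchao._models.llama.model import Transformer

    batch_size, seq_len = 1, 16
    input_ids = torch.randint(0, 1024, (batch_size, seq_len))

    # Training mode: input_pos stays None (the case this commit fixes).
    train_model = Transformer.from_name("stories15M").eval()  # small config, random weights
    train_model.setup_caches(max_batch_size=batch_size, max_seq_length=seq_len, training=True)
    logits = train_model(input_ids)  # input_pos defaults to None

    # Inference mode: explicit positions are required.
    infer_model = Transformer.from_name("stories15M").eval()
    infer_model.setup_caches(max_batch_size=batch_size, max_seq_length=seq_len, training=False)
    logits = infer_model(input_ids, torch.arange(seq_len))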

torchao/_models/llama/model.py

+22 −8

@@ -193,19 +193,33 @@ def setup_caches(self, max_batch_size, max_seq_length, training: bool=False, kv_
 
 
     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+        """Forward pass of the model.
+
+        Args:
+            idx (`torch.LongTensor` of shape `(batch_size, seq_length)`):
+                Indices of input sequence tokens in the vocabulary.
+            input_pos (`torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
+                Indices of the position of each input sequence token in the position embeddings.
+                This argument is optional in training mode but required in
+                inference mode (when model.setup_caches(training=False) is used).
+
+        Returns:
+            Tensor: The output logits tensor.
+        """
         assert self.freqs_cis is not None, "Caches must be initialized first"
 
         if input_pos is None:
             mask = None
             freqs_cis = self.freqs_cis[:idx.shape[1]]
-        elif not self.linear_causal_mask:
-            mask = self.causal_mask[None, None, input_pos]
-        elif len(input_pos)>1 and self.linear_causal_mask: # prefill for linear causal mask
-            mask = torch.tril(torch.ones(len(input_pos), self.max_seq_length, dtype=torch.bool, device=input_pos.device)).unsqueeze(0).unsqueeze(0)
-        else: # decode_one_token for linear causal mask
-            self.causal_mask[0,0,0,input_pos] = 1
-            mask = self.causal_mask
-        freqs_cis = self.freqs_cis[input_pos]
+        else:
+            if not self.linear_causal_mask:
+                mask = self.causal_mask[None, None, input_pos]
+            elif len(input_pos)>1 and self.linear_causal_mask: # prefill for linear causal mask
+                mask = torch.tril(torch.ones(len(input_pos), self.max_seq_length, dtype=torch.bool, device=input_pos.device)).unsqueeze(0).unsqueeze(0)
+            else: # decode_one_token for linear causal mask
+                self.causal_mask[0,0,0,input_pos] = 1
+                mask = self.causal_mask
+            freqs_cis = self.freqs_cis[input_pos]
 
         x = self.tok_embeddings(idx)
 
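
Why the restructuring matters: in the old code the final freqs_cis = self.freqs_cis[input_pos] sat outside the if/elif chain, so it also ran on the training path, where input_pos is None. Indexing a tensor with None does not raise; it acts like unsqueeze(0) and silently replaces the correct slice with a wrongly shaped one. A minimal standalone sketch of that failure mode (plain tensors standing in for the model's freqs_cis buffer):

    import torch

    # Stand-in for the precomputed rotary-embedding buffer.
    freqs_cis = torch.randn(128, 32)
    seq_len = 16

    # Training path: the intended slice is the first seq_len rows.
    intended = freqs_cis[:seq_len]
    print(intended.shape)    # torch.Size([16, 32])

    # The old unconditional assignment then re-indexed with None,
    # which inserts a new leading dimension instead of raising:
    clobbered = freqs_cis[None]
    print(clobbered.shape)   # torch.Size([1, 128, 32]) -- wrong shape and content

Moving the assignment under else: keeps it on the inference path only, which is exactly the case test_ao_llama_model_inference_mode covers with is_training=True.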
