Skip to content

Fixed the llama model #769

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions test/test_ao_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pytest
import torch
from torchao._models.llama.model import Transformer

_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


def init_model(name="stories15M", device="cpu", precision=torch.bfloat16):
model = Transformer.from_name(name)
model.to(device=device, dtype=precision)
return model.eval()


@pytest.mark.parametrize("device", _AVAILABLE_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("is_training", [True, False])
def test_ao_llama_model_inference_mode(device, batch_size, is_training):
random_model = init_model(device=device)
seq_len = 16
input_ids = torch.randint(0, 1024, (batch_size, seq_len)).to(device)
input_pos = None if is_training else torch.arange(seq_len).to(device)
with torch.device(device):
random_model.setup_caches(max_batch_size=batch_size, max_seq_length=seq_len, training=is_training)
for i in range(3):
out = random_model(input_ids, input_pos)
assert out is not None, "model failed to run"
30 changes: 22 additions & 8 deletions torchao/_models/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,19 +193,33 @@ def setup_caches(self, max_batch_size, max_seq_length, training: bool=False, kv_


def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
"""Forward pass of the model.

Args:
idx (`torch.LongTensor` of shape `(batch_size, seq_length)`):
Indices of input sequence tokens in the vocabulary.
input_pos (`torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings.
This argument is optional for training mode but required for
inference mode(when model.setup_caches(training=False) is used).

Returns:
Tensor: The output logits tensor.
"""
assert self.freqs_cis is not None, "Caches must be initialized first"

if input_pos is None:
mask = None
freqs_cis = self.freqs_cis[:idx.shape[1]]
elif not self.linear_causal_mask:
mask = self.causal_mask[None, None, input_pos]
elif len(input_pos)>1 and self.linear_causal_mask: # prefill for linear causal mask
mask = torch.tril(torch.ones(len(input_pos), self.max_seq_length, dtype=torch.bool, device=input_pos.device)).unsqueeze(0).unsqueeze(0)
else: # decode_one_token for linear causal mask
self.causal_mask[0,0,0,input_pos] = 1
mask = self.causal_mask
freqs_cis = self.freqs_cis[input_pos]
else:
if not self.linear_causal_mask:
mask = self.causal_mask[None, None, input_pos]
elif len(input_pos)>1 and self.linear_causal_mask: # prefill for linear causal mask
mask = torch.tril(torch.ones(len(input_pos), self.max_seq_length, dtype=torch.bool, device=input_pos.device)).unsqueeze(0).unsqueeze(0)
else: # decode_one_token for linear causal mask
self.causal_mask[0,0,0,input_pos] = 1
mask = self.causal_mask
freqs_cis = self.freqs_cis[input_pos]

x = self.tok_embeddings(idx)

Expand Down
Loading