diff --git a/test/test_ao_models.py b/test/test_ao_models.py new file mode 100644 index 0000000000..6680000be9 --- /dev/null +++ b/test/test_ao_models.py @@ -0,0 +1,26 @@ +import pytest +import torch +from torchao._models.llama.model import Transformer + +_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + + +def init_model(name="stories15M", device="cpu", precision=torch.bfloat16): + model = Transformer.from_name(name) + model.to(device=device, dtype=precision) + return model.eval() + + +@pytest.mark.parametrize("device", _AVAILABLE_DEVICES) +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("is_training", [True, False]) +def test_ao_llama_model_inference_mode(device, batch_size, is_training): + random_model = init_model(device=device) + seq_len = 16 + input_ids = torch.randint(0, 1024, (batch_size, seq_len)).to(device) + input_pos = None if is_training else torch.arange(seq_len).to(device) + with torch.device(device): + random_model.setup_caches(max_batch_size=batch_size, max_seq_length=seq_len, training=is_training) + for i in range(3): + out = random_model(input_ids, input_pos) + assert out is not None, "model failed to run" diff --git a/torchao/_models/llama/model.py b/torchao/_models/llama/model.py index ab3a51eef3..2f6ac5cb50 100644 --- a/torchao/_models/llama/model.py +++ b/torchao/_models/llama/model.py @@ -193,19 +193,33 @@ def setup_caches(self, max_batch_size, max_seq_length, training: bool=False, kv_ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + """Forward pass of the model. + + Args: + idx (`torch.LongTensor` of shape `(batch_size, seq_length)`): + Indices of input sequence tokens in the vocabulary. + input_pos (`torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. + This argument is optional for training mode but required for + inference mode(when model.setup_caches(training=False) is used). + + Returns: + Tensor: The output logits tensor. + """ assert self.freqs_cis is not None, "Caches must be initialized first" if input_pos is None: mask = None freqs_cis = self.freqs_cis[:idx.shape[1]] - elif not self.linear_causal_mask: - mask = self.causal_mask[None, None, input_pos] - elif len(input_pos)>1 and self.linear_causal_mask: # prefill for linear causal mask - mask = torch.tril(torch.ones(len(input_pos), self.max_seq_length, dtype=torch.bool, device=input_pos.device)).unsqueeze(0).unsqueeze(0) - else: # decode_one_token for linear causal mask - self.causal_mask[0,0,0,input_pos] = 1 - mask = self.causal_mask - freqs_cis = self.freqs_cis[input_pos] + else: + if not self.linear_causal_mask: + mask = self.causal_mask[None, None, input_pos] + elif len(input_pos)>1 and self.linear_causal_mask: # prefill for linear causal mask + mask = torch.tril(torch.ones(len(input_pos), self.max_seq_length, dtype=torch.bool, device=input_pos.device)).unsqueeze(0).unsqueeze(0) + else: # decode_one_token for linear causal mask + self.causal_mask[0,0,0,input_pos] = 1 + mask = self.causal_mask + freqs_cis = self.freqs_cis[input_pos] x = self.tok_embeddings(idx)