# The answer is given to SLAM as input during training, wouldn't that make the model cheat? #132

Yahya-Saleh opened this issue Sep 19, 2024 · 1 comment

Yahya-Saleh commented Sep 19, 2024

### System Info

PyTorch version: 2.4.0+cu121

### Information

- The official example scripts
- My own modified scripts

### 🐛 Describe the bug

I was fine-tuning SLAM on my own dataset by modifying one of the provided example bash scripts: finetune_wavlm_large_linear_vicuna_7b.sh. I got good evaluation results but was struggling to reproduce them with the inference script: decode_hubert_xtralarge_linear_vicuna_7b.sh.

After some digging I noticed that during training and evaluation the expected answer (in the ASR case, the ground-truth transcription) is concatenated with the prompt and passed as input: `example = prompt + answer` -> `"input_ids": example_ids`.

https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/datasets/speech_dataset.py

```python
answer = self.answer_template.format(target)
example = prompt + answer  # FIX(MZY): avoid putting a bos token before answer.
example_ids = self.tokenizer.encode(example)  # [prompt,answer]
example_ids.append(self.tokenizer.eos_token_id)  # [prompt,answer,eos]
example_ids = torch.tensor(example_ids, dtype=torch.int64)
example_ids = torch.cat((audio_pseudo, example_ids))  # [audio,prompt,answer,eos]

labels_ids = copy.deepcopy(example_ids)  # [audio,prompt,answer,eos]
labels_ids[:audio_length + prompt_length] = -1  # [-1,-1,answer,eos]
example_mask = example_ids.ge(-1)  # FIX(GZF): [True,True,True,True]

label_mask = labels_ids.ge(0)  # [False,False,True,True]
example_ids[~example_mask] = 0  # [audio,prompt,answer,eos]
labels_ids[~label_mask] = self.IGNORE_INDEX  # [-100,-100,answer,eos]

return {
    "input_ids": example_ids,
    "labels": labels_ids,
    "attention_mask": example_mask,
    "audio": audio_raw if self.input_type == "raw" else None,
    "audio_mel": audio_mel if self.input_type == "mel" else None,
    "audio_length": audio_length,
    "prompt_length": prompt_length,
}
```
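
To make the masking concrete, here is a minimal, self-contained sketch of what the returned `input_ids`/`labels` pair looks like. The token IDs and sequence lengths below are made up for illustration; the real code uses the tokenizer output and audio encoder placeholders:

```python
import torch

IGNORE_INDEX = -100  # same convention as SLAM-LLM and Hugging Face losses

# Toy example: 2 audio placeholder tokens, 2 prompt tokens, 2 answer tokens + eos.
audio_pseudo = torch.tensor([0, 0], dtype=torch.int64)       # placeholders for audio embeddings
prompt_ids   = torch.tensor([11, 12], dtype=torch.int64)     # made-up prompt token IDs
answer_ids   = torch.tensor([21, 22, 2], dtype=torch.int64)  # made-up answer IDs, 2 = eos

input_ids = torch.cat((audio_pseudo, prompt_ids, answer_ids))  # [audio,prompt,answer,eos]
labels = input_ids.clone()
labels[: len(audio_pseudo) + len(prompt_ids)] = IGNORE_INDEX   # loss only on answer + eos

print(input_ids)  # tensor([ 0,  0, 11, 12, 21, 22,  2])
print(labels)     # tensor([-100, -100, -100, -100,  21,  22,   2])
```

Only the answer and eos positions carry real labels; everything before them is excluded from the loss.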
     
If the answer is passed as input, wouldn't that teach the model to repeat the last part of its input instead of performing ASR?

Thank you in advance for the help!


### Error logs

There was no error, just unexpected behavior: the evaluation and inference results differ.

### Expected behavior

I would expect the training input to contain only the prompt, without the answer to the prompt.
@world1tree commented:

```python
if self.inference_mode:
    prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64)
    example_ids = torch.cat((audio_pseudo, prompt_ids))  # [audio,prompt]
    example_mask = example_ids.ge(-1)  # [True,True]

    return {
        "input_ids": example_ids,
        "attention_mask": example_mask,
        "audio": audio_raw if self.input_type == "raw" else None,
        "audio_mel": audio_mel if self.input_type == "mel" else None,
        "audio_length": audio_length,
        "key": key,
        "target": target,
        "prompt_length": prompt_length,
    }

answer = self.answer_template.format(target)
example = prompt + answer  # FIX(MZY): avoid putting a bos token before answer.
example_ids = self.tokenizer.encode(example)  # [prompt,answer]
example_ids.append(self.tokenizer.eos_token_id)  # [prompt,answer,eos]
example_ids = torch.tensor(example_ids, dtype=torch.int64)  # (continues as in the training branch above)
```

As you can see from the code above, the answer is only needed as a supervision signal during training; no answer is given during evaluation (`inference_mode=True`).
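
For completeness, here is my own explanatory sketch (not SLAM-LLM code): putting the answer into `input_ids` during training is standard teacher forcing for a causal LM. Attention is causal, so position t can only attend to positions <= t, and the loss is shifted so each answer token is predicted from the tokens before it; the `-100` labels on audio and prompt positions are excluded via `ignore_index`. The function name and toy shapes below are mine:

```python
import torch
import torch.nn.functional as F

def causal_lm_loss(logits, labels, ignore_index=-100):
    # Standard shifted teacher-forcing loss (Hugging Face convention):
    # the logits at position t predict the token at position t+1.
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=ignore_index,  # audio/prompt positions contribute nothing
    )

# Toy shapes: batch=1, seq_len=7, vocab=32; labels as in the dataset example above.
logits = torch.randn(1, 7, 32)
labels = torch.tensor([[-100, -100, -100, -100, 21, 22, 2]])
print(causal_lm_loss(logits, labels))
```

So the model never gets to copy an answer token when predicting it: during training each answer token is conditioned only on earlier tokens, and at inference time the answer tokens are simply absent and must be generated autoregressively.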
