Skip to content

Commit

Permalink
common generation changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Cemberk committed Feb 3, 2025
1 parent 7ac7d58 commit e8d39c1
Showing 1 changed file with 33 additions and 19 deletions.
52 changes: 33 additions & 19 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2546,21 +2546,33 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
"return_tensors": "pt",
}

@skipIfRocm(arch='gfx1201')
@skipIfRocm(arch=['gfx1201','gfx942','gfx90a','gfx1100'])
def test_custom_logits_processor(self):
super().test_custom_logits_processor()
pass

@skipIfRocm(arch='gfx1201')
@skipIfRocm(arch=['gfx1201','gfx942','gfx90a','gfx1100'])
def test_max_new_tokens_encoder_decoder(self):
super().test_max_new_tokens_encoder_decoder()
pass

@skipIfRocm(arch='gfx1201')
@skipIfRocm(arch=['gfx1201','gfx942','gfx90a','gfx1100'])
def test_eos_token_id_int_and_list_beam_search(self):
super().test_eos_token_id_int_and_list_beam_search()
pass

@skipIfRocm(arch='gfx942')
def test_transition_scores_greedy_search_normalized(self):
super().test_transition_scores_greedy_search_normalized()

@skipIfRocm(arch='gfx942')
def test_transition_scores_greedy_search(self):
super().test_transition_scores_greedy_search()

@skipIfRocm(arch='gfx942')
def test_generate_input_features_as_encoder_kwarg(self):
super().test_generate_input_features_as_encoder_kwarg()

@slow
def test_diverse_beam_search(self):
# PT-only test: TF doesn't have a diverse beam search implementation
Expand Down Expand Up @@ -2596,6 +2608,7 @@ def test_diverse_beam_search(self):
],
)

@skipIfRocm(arch='gfx942')
def test_max_length_if_input_embeds(self):
# PT-only test: TF doesn't have StoppingCriteria
article = "Today a dragon flew over Paris."
Expand All @@ -2610,7 +2623,7 @@ def test_max_length_if_input_embeds(self):
out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, max_length=max_length)
self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])

@skipIfRocm(arch='gfx1201')
@skipIfRocm(arch=['gfx1201','gfx942','gfx90a','gfx1100'])
def test_min_length_if_input_embeds(self):
# PT-only test: TF doesn't have StoppingCriteria
article = "Today a dragon flew over Paris."
Expand Down Expand Up @@ -2663,7 +2676,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
)

# TODO (joao): replace `stop_sequence` in the pipeline by the more recent `generate` functionality
@skipIfRocm(arch='gfx1201')
@skipIfRocm(arch=['gfx1201','gfx942','gfx90a','gfx1100'])
def test_stop_sequence_stopping_criteria(self):
# PT-only test: TF doesn't have StoppingCriteria
prompt = """Hello I believe in"""
Expand Down Expand Up @@ -3199,6 +3212,7 @@ def test_decoder_start_id_from_config(self):
with self.assertRaises(ValueError):
outputs = bart_model.generate(input_ids, generation_config=GenerationConfig(do_sample=False))

@skipIfRocm(arch='gfx942')
def test_contrastive_search_batched(self):
# PT-only test: TF doesn't have constrained beam search
# Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs)
Expand Down Expand Up @@ -3246,7 +3260,7 @@ def test_logits_processor_not_inplace(self):
self.assertListEqual(out.logits[-1].tolist(), out.scores[-1].tolist())
self.assertNotEqual(out_with_temp.logits[-1].tolist(), out_with_temp.scores[-1].tolist())

@skipIfRocm(arch='gfx1201')
@skipIfRocm(arch=['gfx1201','gfx942','gfx90a','gfx1100'])
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
# Has TF equivalent: this test relies on random sampling
generation_kwargs = {
Expand Down Expand Up @@ -3275,7 +3289,7 @@ def test_eos_token_id_int_and_list_top_k_top_sampling(self):
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))

@skipIfRocm(arch='gfx1201')
@skipIfRocm(arch=['gfx1201','gfx942','gfx90a','gfx1100'])
def test_model_kwarg_encoder_signature_filtering(self):
# Has TF equivalent: ample use of framework-specific code
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
Expand Down Expand Up @@ -3313,7 +3327,7 @@ def forward(self, input_ids, **kwargs):
# FakeEncoder.forward() accepts **kwargs -> no filtering -> type error due to unexpected input "foo"
bart_model.generate(input_ids, foo="bar")

@skipIfRocm(arch='gfx1201')
@skipIfRocm(arch=['gfx1201','gfx942','gfx90a','gfx1100'])
def test_default_max_length_warning(self):
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
Expand Down Expand Up @@ -3371,7 +3385,7 @@ def test_default_assisted_generation(self):
self.assertEqual(config.assistant_confidence_threshold, 0.4)
self.assertEqual(config.is_assistant, False)

@skipIfRocm(arch='gfx1201')
@skipIfRocm(arch=['gfx1201','gfx942','gfx90a','gfx1100'])
def test_generated_length_assisted_generation(self):
# PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
Expand Down Expand Up @@ -3400,7 +3414,7 @@ def test_generated_length_assisted_generation(self):
)
self.assertTrue((input_length + 10) <= out.shape[-1] <= 20)

@skipIfRocm(arch='gfx1201')
@skipIfRocm(arch=['gfx1201','gfx942','gfx90a','gx1100'])
def test_model_kwarg_assisted_decoding_decoder_only(self):
# PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
Expand Down Expand Up @@ -3435,7 +3449,7 @@ def test_model_kwarg_assisted_decoding_decoder_only(self):
)
self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist())

@skipIfRocm(arch='gfx1201')
@skipIfRocm(arch=['gfx1201','gfx942','gfx90a','gfx1100'])
def test_model_kwarg_assisted_decoding_encoder_decoder(self):
"""
Tests that the following scenario is compatible with assisted generation:
Expand Down Expand Up @@ -3502,7 +3516,7 @@ def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())

@skipIfRocm(arch='gfx1201')
@skipIfRocm(arch=['gfx1201','gfx942','gfx90a','gfx1100'])
def test_assisted_decoding_encoder_decoder_shared_encoder(self):
"""
Tests that the following scenario is compatible with assisted generation:
Expand Down Expand Up @@ -3581,7 +3595,7 @@ def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())

@skipIfRocm(arch='gfx1201')
@skipIfRocm(arch=['gfx1201','gfx942','gfx90a','gfx1100'])
def test_assisted_decoding_num_assistant_tokens_heuristic_schedule(self):
# This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly.

Expand Down Expand Up @@ -3830,7 +3844,7 @@ def test_special_tokens_fall_back_to_model_default(self):
self.assertTrue(test_bos_id == gen_output[0, 0])
self.assertTrue(generation_config.bos_token_id is None)

@skipIfRocm(arch='gfx1201')
@skipIfRocm(arch=['gfx1201','gfx942','gfx90a'])
def test_speculative_decoding_equals_regular_decoding(self):
draft_name = "double7/vicuna-68m"
target_name = "Qwen/Qwen2-0.5B-Instruct"
Expand Down Expand Up @@ -3861,7 +3875,7 @@ def test_speculative_decoding_equals_regular_decoding(self):

@pytest.mark.generate
@require_torch_multi_gpu
@skipIfRocm(arch='gfx1201')
@skipIfRocm(arch=['gfx1201','gfx942','gfx90a'])
def test_generate_with_static_cache_multi_gpu(self):
"""
Tests if the static cache has been set correctly and if generate works correctly when we are using multi-gpus.
Expand Down Expand Up @@ -3897,7 +3911,7 @@ def test_generate_with_static_cache_multi_gpu(self):

@pytest.mark.generate
@require_torch_multi_gpu
@skipIfRocm(arch='gfx1201')
@skipIfRocm(arch=['gfx1201','gfx942','gfx90a'])
def test_init_static_cache_multi_gpu(self):
"""
Tests if the static cache has been set correctly when we initialize it manually in a multi-gpu setup.
Expand Down Expand Up @@ -4079,7 +4093,7 @@ def test_padding_input_contrastive_search_t5(self):
self.assertEqual(generated_text_no_padding, generated_text_with_padding)
self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.")

@skipIfRocm(arch='gfx1201')
@skipIfRocm(arch=['gfx1201','gfx942'])
def test_prepare_inputs_for_generation_decoder_llm(self):
"""Tests GenerationMixin.prepare_inputs_for_generation against expected usage with decoder-only llms."""

Expand Down Expand Up @@ -4196,7 +4210,7 @@ def test_prepare_inputs_for_generation_encoder_decoder_llm(self):
self.assertTrue(model_inputs["encoder_outputs"] == "foo")
# See the decoder-only test for more corner cases. The code is the same, so we don't repeat it here.

@skipIfRocm(arch='gfx1201')
@skipIfRocm(arch=['gfx1201','gfx942','gfx90a'])
def test_generate_compile_fullgraph_tiny(self):
"""
Tests that we can call end-to-end generation with a tiny model (i.e. doesn't crash)
Expand All @@ -4220,7 +4234,7 @@ def test_generate_compile_fullgraph_tiny(self):
gen_out = compiled_generate(**model_inputs, generation_config=generation_config)
self.assertTrue(gen_out.shape[1] > model_inputs["input_ids"].shape[1]) # some text was generated

@skipIfRocm(arch='gfx1201')
@skipIfRocm(arch=['gfx1201','gfx942','gfx90a','gfx1100'])
def test_assisted_generation_early_exit(self):
"""
Tests that assisted generation with early exit works as expected. Under the hood, this has complex cache
Expand Down

0 comments on commit e8d39c1

Please sign in to comment.