Skip to content

Commit

Permalink
partial skips
Browse files Browse the repository at this point in the history
  • Loading branch information
Cemberk committed Jan 31, 2025
1 parent 4b439fe commit 1ec92df
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tests/models/granitemoe/test_modeling_granitemoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,12 @@ class GraniteMoeModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.7, 0.8]

@skipIfRocm(arch=['gfx1201','gfx942'])
def test_generate_with_static_cache(self):
    """Re-run the inherited static-cache generation test under a ROCm skip guard.

    Thin wrapper over ``GenerationTesterMixin.test_generate_with_static_cache``;
    it exists only so the ``skipIfRocm`` decorator can be attached for this model.
    Skipped on ROCm gfx1201 and gfx942 — reason not recorded in this hunk; presumably
    known failures on those architectures (TODO confirm against the skip tracker).
    """
    # The stale single-arch decorator shown in the diff's pre-image was dropped;
    # the trailing `pass` after the super() call was dead code and is removed.
    super().test_generate_with_static_cache()

@skipIfRocm(arch=['gfx1201','gfx942'])
def test_generate_from_inputs_embeds_with_static_cache(self):
    """Re-run the inherited inputs-embeds static-cache test under a ROCm skip guard.

    Thin wrapper over the mixin's
    ``test_generate_from_inputs_embeds_with_static_cache``; present only to carry
    the ``skipIfRocm`` decorator. Skipped on ROCm gfx1201 and gfx942 — skip reason
    not recorded in this hunk (TODO confirm).
    """
    # Stale pre-change decorator (diff artifact) and redundant trailing `pass` removed.
    super().test_generate_from_inputs_embeds_with_static_cache()
Expand Down
2 changes: 1 addition & 1 deletion tests/models/hubert/test_modeling_hubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)

@skipIfRocm(arch=['gfx90a','gfx942'])
def test_batched_inference(self):
    """Build Hubert config/inputs via the model tester and run its batch-inference check.

    Skipped on ROCm gfx90a and gfx942 — reason not recorded in this hunk
    (NOTE(review): presumably known numerical/kernel failures; verify).
    """
    # The duplicated pre-change decorator from the diff's old side was dropped.
    config_and_inputs = self.model_tester.prepare_config_and_inputs()
    self.model_tester.create_and_check_batch_inference(*config_and_inputs)
Expand Down
3 changes: 2 additions & 1 deletion tests/models/mamba2/test_modeling_mamba2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from parameterized import parameterized

from transformers import AutoTokenizer, Mamba2Config, is_torch_available
from transformers.testing_utils import require_read_token, require_torch, require_torch_gpu, slow, torch_device
from transformers.testing_utils import require_read_token, require_torch, require_torch_gpu, slow, torch_device, skipIfRocm
from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available

from ...generation.test_utils import GenerationTesterMixin
Expand Down Expand Up @@ -233,6 +233,7 @@ def setUp(self):
self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
)

@skipIfRocm(arch='gfx942')
def test_mamba2_caching(self):
    """Exercise the model tester's Mamba2 caching check (skipped on ROCm gfx942)."""
    tester = self.model_tester
    tester.create_and_check_mamba2_caching(*tester.prepare_config_and_inputs())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
require_torch_sdpa,
slow,
torch_device,
skipIfRocm,
)

from ...test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
Expand Down Expand Up @@ -503,6 +504,10 @@ def test_sdpa_can_dispatch_composite_models(self):

@require_torch
class Wav2Vec2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
@skipIfRocm(arch='gfx942')
# Thin wrapper: re-runs the inherited EncoderDecoderMixin save/load-from-pretrained
# test unchanged; it exists only so the ROCm gfx942 skip decorator can be attached
# for this model. Skip reason is not recorded in this hunk — TODO confirm.
def test_save_and_load_from_pretrained(self):
    super().test_save_and_load_from_pretrained()

def get_pretrained_model_and_inputs(self):
model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
"facebook/wav2vec2-base-960h", "google-bert/bert-base-cased"
Expand Down

0 comments on commit 1ec92df

Please sign in to comment.