Commit 2d29bc7

testing change
Cemberk committed Feb 20, 2025
1 parent 269c793 · commit 2d29bc7
Showing 2 changed files with 11 additions and 0 deletions.
8 changes: 8 additions & 0 deletions tests/models/grounding_dino/test_modeling_grounding_dino.py
@@ -217,6 +217,14 @@ class GroundingDinoModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         else {}
     )
 
+    @skipIfRocm
+    def test_batching_equivalence(self):
+        super().test_batching_equivalence()
+
+    @skipIfRocm
+    def test_training(self):
+        super().test_training()
+
     # special case for head models
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
3 changes: 3 additions & 0 deletions tests/test_modeling_common.py
@@ -917,15 +917,18 @@ def test_training(self):
         loss = model(**inputs).loss
         loss.backward()
 
+    @skipIfRocm(arch=['gfx90a','gfx942'])
     def test_training_gradient_checkpointing(self):
         # Scenario - 1 default behaviour
         self.check_training_gradient_checkpointing()
 
+    @skipIfRocm(arch=['gfx90a','gfx942'])
     def test_training_gradient_checkpointing_use_reentrant(self):
         # Scenario - 2 with `use_reentrant=True` - this is the default value that is used in pytorch's
         # torch.utils.checkpoint.checkpoint
         self.check_training_gradient_checkpointing(gradient_checkpointing_kwargs={"use_reentrant": True})
 
+    @skipIfRocm(arch=['gfx90a','gfx942'])
     def test_training_gradient_checkpointing_use_reentrant_false(self):
         # Scenario - 3 with `use_reentrant=False` pytorch suggests users to use this value for
         # future releases: https://pytorch.org/docs/stable/checkpoint.html
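The `skipIfRocm` helper itself is not defined in this diff. Below is a minimal sketch of what a decorator with this interface could look like, assuming ROCm builds of PyTorch report a version string via `torch.version.hip` and expose the GPU architecture through `torch.cuda.get_device_properties(0).gcnArchName`; both the implementation and the `gcnArchName` field are assumptions for illustration, not the actual helper used in this repository.

# Sketch only: the real skipIfRocm used in this commit may differ.
import functools
import unittest

import torch


def skipIfRocm(func=None, *, arch=None):
    """Skip a test on ROCm; if `arch` is given, only on those GPU architectures."""

    def decorator(test_func):
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            # torch.version.hip is None on CUDA builds and a version string on ROCm.
            skip = torch.version.hip is not None
            if skip and arch is not None and torch.cuda.is_available():
                # gcnArchName (e.g. "gfx90a:sramecc+:xnack-") is assumed to be
                # exposed by ROCm builds of PyTorch; skip only on listed archs.
                device_arch = torch.cuda.get_device_properties(0).gcnArchName
                skip = any(a in device_arch for a in arch)
            if skip:
                raise unittest.SkipTest("test skipped on ROCm")
            return test_func(*args, **kwargs)

        return wrapper

    # Supports both the bare @skipIfRocm form and @skipIfRocm(arch=[...]).
    return decorator(func) if func is not None else decorator

With this shape, the bare `@skipIfRocm` used in test_modeling_grounding_dino.py disables a test on every ROCm device, while `@skipIfRocm(arch=['gfx90a','gfx942'])` in test_modeling_common.py limits the skip to gfx90a (MI200-series) and gfx942 (MI300-series) GPUs.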
