diff --git a/tests/test_stablediffusion.py b/tests/test_stablediffusion.py index 660c17caea..b87c61c83b 100644 --- a/tests/test_stablediffusion.py +++ b/tests/test_stablediffusion.py @@ -133,10 +133,6 @@ def prepare_inputs_for_testing(self): ) ) def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): - if config_cls in [LoHaConfig, OFTConfig]: - # TODO: This test is flaky with PyTorch 2.1 on Windows, we need to figure out what is going on - self.skipTest("LoHaConfig and OFTConfig test is flaky") - # Instantiate model & adapters model = self.instantiate_sd_peft(model_id, config_cls, config_kwargs) @@ -146,7 +142,9 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): peft_output = np.array(model(**dummy_input).images[0]).astype(np.float32) # Merge adapter and model - model.text_encoder = model.text_encoder.merge_and_unload() + if config_cls not in [LoHaConfig, OFTConfig]: + # TODO: Merging the text_encoder is leading to issues on CPU with PyTorch 2.1 + model.text_encoder = model.text_encoder.merge_and_unload() model.unet = model.unet.merge_and_unload() # Generate output for peft merged StableDiffusion