Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Jul 30, 2024
1 parent ebd8df3 commit 0344b62
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
2 changes: 1 addition & 1 deletion optimum/intel/utils/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _find_files_matching_pattern(
token = use_auth_token

library_name = TasksManager.infer_library_from_model(
model_name_or_path, subfolder=subfolder, revision=revision, token=token
str(model_name_or_path), subfolder=subfolder, revision=revision, token=token
)
if library_name == "diffusers":
subfolder = os.path.join(subfolder, "unet")
Expand Down
17 changes: 8 additions & 9 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,23 +281,22 @@ def test_find_files_matching_pattern(self):
model_id = "echarlaix/tiny-random-PhiForCausalLM"
pattern = r"(.*)?openvino(.*)?\_model.xml"
# hub model
for revision in ("main", "ov"):
ov_files = _find_files_matching_pattern(model_id, pattern=pattern, revision=revision)
for revision in ("main", "ov", "itrex"):
ov_files = _find_files_matching_pattern(
model_id, pattern=pattern, revision=revision, subfolder="openvino" if revision == "itrex" else ""
)
self.assertTrue(len(ov_files) == 0 if revision == "main" else len(ov_files) > 0)
ov_files = _find_files_matching_pattern(model_id, pattern=pattern, subfolder="openvino")
self.assertTrue(len(ov_files) > 0)

# local model
api = HfApi()
with tempfile.TemporaryDirectory() as tmpdirname:
for revision in ("main", "ov"):
for revision in ("main", "ov", "itrex"):
local_dir = Path(tmpdirname) / revision
api.snapshot_download(repo_id=model_id, local_dir=local_dir, revision=revision)
ov_files = _find_files_matching_pattern(local_dir, pattern=pattern, revision=revision)
ov_files = _find_files_matching_pattern(
local_dir, pattern=pattern, revision=revision, subfolder="openvino" if revision == "itrex" else ""
)
self.assertTrue(len(ov_files) == 0 if revision == "main" else len(ov_files) > 0)
if revision == "main":
ov_files = _find_files_matching_pattern(local_dir, pattern=pattern, subfolder="openvino")
self.assertTrue(len(ov_files) > 0)

@parameterized.expand(("stable-diffusion", "stable-diffusion-openvino"))
def test_find_files_matching_pattern_sd(self, model_arch):
Expand Down

0 comments on commit 0344b62

Please sign in to comment.