Skip to content

Commit

Permalink
[fix] Extend when a model repository/directory already has an expor…
Browse files Browse the repository at this point in the history
…ted OV model (#1000)

* Also accept e.g. "openvino_model_qint8_quantized.xml"

* Add test case

* Add missing subfolder call to test
  • Loading branch information
tomaarsen authored Nov 18, 2024
1 parent c7d6227 commit e3031f0
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
2 changes: 1 addition & 1 deletion optimum/intel/openvino/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def from_pretrained(

ov_files = _find_files_matching_pattern(
model_dir,
pattern=r"(.*)?openvino(.*)?\_model.xml$",
pattern=r"(.*)?openvino(.*)?\_model(.*)?.xml$",
subfolder=subfolder,
use_auth_token=token,
revision=revision,
Expand Down
24 changes: 21 additions & 3 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def test_infer_export_when_loading(self):

def test_find_files_matching_pattern(self):
model_id = "echarlaix/tiny-random-PhiForCausalLM"
pattern = r"(.*)?openvino(.*)?\_model.xml$"
pattern = r"(.*)?openvino(.*)?\_model(.*)?.xml$"
# hub model
for revision in ("main", "ov", "itrex"):
ov_files = _find_files_matching_pattern(
Expand All @@ -452,7 +452,7 @@ def test_find_files_matching_pattern(self):

@parameterized.expand(("stable-diffusion", "stable-diffusion-openvino"))
def test_find_files_matching_pattern_sd(self, model_arch):
pattern = r"(.*)?openvino(.*)?\_model.xml$"
pattern = r"(.*)?openvino(.*)?\_model(.*)?.xml$"
model_id = MODEL_NAMES[model_arch]
# hub model
ov_files = _find_files_matching_pattern(model_id, pattern=pattern)
Expand All @@ -470,7 +470,7 @@ def test_find_files_matching_pattern_sd(self, model_arch):
def test_find_files_matching_pattern_with_config_in_root(self, subfolder):
# Notably, the model has a config.json file in the root directory and not in the subfolder
model_id = "sentence-transformers-testing/stsb-bert-tiny-openvino"
pattern = r"(.*)?openvino(.*)?\_model.xml$"
pattern = r"(.*)?openvino(.*)?\_model(.*)?.xml$"
# hub model
ov_files = _find_files_matching_pattern(model_id, pattern=pattern, subfolder=subfolder)
self.assertTrue(len(ov_files) == 1 if subfolder == "openvino" else len(ov_files) == 0)
Expand All @@ -483,6 +483,24 @@ def test_find_files_matching_pattern_with_config_in_root(self, subfolder):
ov_files = _find_files_matching_pattern(local_dir, pattern=pattern, subfolder=subfolder)
self.assertTrue(len(ov_files) == 1 if subfolder == "openvino" else len(ov_files) == 0)

def test_find_files_matching_pattern_with_quantized_ov_model(self):
# This model only has "openvino/openvino_model_qint8_quantized.xml" and "openvino/openvino_model_qint8_quantized.bin"
# We want to ensure that this model is found, so the `export` isn't forced to True
model_id = "sentence-transformers-testing/stsb-bert-tiny-openvino-quantized-only"
subfolder = "openvino"
pattern = r"(.*)?openvino(.*)?\_model(.*)?.xml$"
# hub model
ov_files = _find_files_matching_pattern(model_id, pattern=pattern, subfolder=subfolder)
self.assertTrue(len(ov_files) == 1)

# local model
api = HfApi()
with tempfile.TemporaryDirectory() as tmpdirname:
local_dir = Path(tmpdirname) / "model"
api.snapshot_download(repo_id=model_id, local_dir=local_dir)
ov_files = _find_files_matching_pattern(local_dir, pattern=pattern, subfolder=subfolder)
self.assertTrue(len(ov_files) == 1)


class PipelineTest(unittest.TestCase):
def test_load_model_from_hub(self):
Expand Down

0 comments on commit e3031f0

Please sign in to comment.