Skip to content

Commit

Permalink
catch errors
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Dec 16, 2024
1 parent 307df81 commit 7336fce
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions optimum_benchmark/backends/py_txi/backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Union
Expand Down Expand Up @@ -36,7 +36,10 @@ def load(self) -> None:
self.logger.info("\t+ Loading pretrained model")
self.load_model_from_pretrained()

self.tmpdir.cleanup()
try:
self.tmpdir.cleanup()
except Exception:
shutil.rmtree(self.tmpdir.name)

def download_pretrained_model(self) -> None:
model_snapshot_folder = snapshot_download(self.config.model, **self.config.model_kwargs)
Expand All @@ -49,6 +52,7 @@ def download_pretrained_model(self) -> None:
def create_no_weights_model(self) -> None:
model_path = Path(hf_hub_download(self.config.model, filename="config.json", cache_dir=self.tmpdir.name)).parent
save_model(model=torch.nn.Linear(1, 1), filename=model_path / "model.safetensors", metadata={"format": "pt"})

self.pretrained_processor.save_pretrained(save_directory=model_path)
self.pretrained_config.save_pretrained(save_directory=model_path)

Expand All @@ -57,6 +61,7 @@ def create_no_weights_model(self) -> None:
self.pretrained_model = self.automodel_loader.from_pretrained(
model_path, **self.config.model_kwargs, device_map="auto", _fast_init=False
)

save_model(model=self.pretrained_model, filename=model_path / "model.safetensors", metadata={"format": "pt"})
del self.pretrained_model
torch.cuda.empty_cache()
Expand Down

0 comments on commit 7336fce

Please sign in to comment.