Skip to content

Commit

Permalink
fix bugs with HF interface (#98)
Browse files Browse the repository at this point in the history
Co-authored-by: morgoth95 <[email protected]>
  • Loading branch information
diegofiori and morgoth95 authored Sep 12, 2022
1 parent 9b4d050 commit d013d2c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
2 changes: 1 addition & 1 deletion nebullvm/api/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def convert_hf_model(
tokenizer=tokenizer,
tokenizer_args=tokenizer_args,
)
input_example = tokenizer(input_data)
input_example = tokenizer(input_data, **tokenizer_args)
input_data = _HFTextDataset(
input_texts=input_data,
ys=ys,
Expand Down
15 changes: 13 additions & 2 deletions nebullvm/compressors/sparseml.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
import logging
from logging import Logger
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Callable, Tuple, Optional, Dict
Expand Down Expand Up @@ -29,10 +31,19 @@ def _load_with_torch_fx(path: Path):
return model


def _save_model(model: torch.nn.Module, path: Path):
def _save_model(model: torch.nn.Module, path: Path, logger: Logger = None):
try:
_save_with_torch_fx(model, path)
except Exception:
except Exception as ex:
message = (
f"Got an error while exporting with TorchFX. The model will be "
f"saved using the standard PyTorch save pickling method. Error "
f"got: {ex}"
)
if logger is None:
logging.warning(message)
else:
logger.warning(message)
torch.save(model, path / "model.pt")
return path / "model.pt"
else:
Expand Down

0 comments on commit d013d2c

Please sign in to comment.