Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optimum model #106

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion run_evals_accelerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,15 @@ def get_parser():
default=False,
help="set to True if your model has been trained with peft; you also need to provide the base model name",
)
weight_type_group.add_argument(
"--optimum_weights",
action="store_true",
default=False,
help="set to True if your model needs to be loaded with optimum",
)
parser.add_argument(
"--base_model", type=str, default=None, help="name of the base model to be used for delta or adapter weights"
)

task_type_group.add_argument("--model_args")
parser.add_argument("--model_dtype", type=str, default=None)
parser.add_argument(
Expand Down
13 changes: 13 additions & 0 deletions src/lighteval/models/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
from lighteval.utils import (
NO_AUTOGPTQ_ERROR_MSG,
NO_BNB_ERROR_MSG,
NO_OPTIMUM_ERROR_MSG,
NO_PEFT_ERROR_MSG,
is_accelerate_available,
is_autogptq_available,
is_bnb_available,
is_optimum_available,
is_peft_available,
)

Expand Down Expand Up @@ -190,6 +192,15 @@ def init_configs(self, env_config: EnvConfig):
return self._init_configs(self.base_model, env_config)


@dataclass
class OptimumModelConfig(BaseModelConfig):
    """A BaseModelConfig for checkpoints that must be loaded through `optimum`.

    Construction fails fast with an informative ImportError when the
    `optimum` package is not installed in the current environment.
    """

    def __post_init__(self):
        # Guard the optional `optimum` dependency before running base validation.
        if is_optimum_available():
            return super().__post_init__()
        raise ImportError(NO_OPTIMUM_ERROR_MSG)


@dataclass
class TGIModelConfig:
inference_server_address: str
Expand Down Expand Up @@ -298,6 +309,8 @@ def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]
if args.base_model is None:
raise ValueError("You need to specify a base model when using adapter weights")
return AdapterModelConfig(**args_dict)
if args.optimum_weights:
return OptimumModelConfig(**args_dict)
if args.base_model is not None:
raise ValueError("You can't specify a base model if you are not using delta/adapter weights")
return BaseModelConfig(**args_dict)
17 changes: 14 additions & 3 deletions src/lighteval/models/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
EnvConfig,
InferenceEndpointModelConfig,
InferenceModelConfig,
OptimumModelConfig,
TGIModelConfig,
)
from lighteval.models.optimum_model import OptimumModel
from lighteval.models.tgi_model import ModelClient
from lighteval.utils import NO_TGI_ERROR_MSG, is_accelerate_available, is_tgi_available

Expand All @@ -32,9 +34,16 @@ class ModelInfo:


def load_model( # noqa: C901
config: Union[BaseModelConfig, AdapterModelConfig, DeltaModelConfig, TGIModelConfig, InferenceEndpointModelConfig],
config: Union[
BaseModelConfig,
AdapterModelConfig,
DeltaModelConfig,
TGIModelConfig,
InferenceEndpointModelConfig,
OptimumModelConfig,
],
env_config: EnvConfig,
) -> Tuple[Union[BaseModel, AdapterModel, DeltaModel, ModelClient], ModelInfo]:
) -> Tuple[Union[BaseModel, AdapterModel, DeltaModel, ModelClient, OptimumModel], ModelInfo]:
"""Will load either a model from an inference server or a model from a checkpoint. depending
on the arguments passed to the program.

Expand Down Expand Up @@ -93,12 +102,14 @@ def load_model_with_inference_endpoints(config: InferenceEndpointModelConfig, en


def load_model_with_accelerate_or_default(
config: Union[AdapterModelConfig, BaseModelConfig, DeltaModelConfig], env_config: EnvConfig
config: Union[AdapterModelConfig, BaseModelConfig, DeltaModelConfig, OptimumModelConfig], env_config: EnvConfig
):
if isinstance(config, AdapterModelConfig):
model = AdapterModel(config=config, env_config=env_config)
elif isinstance(config, DeltaModelConfig):
model = DeltaModel(config=config, env_config=env_config)
elif isinstance(config, OptimumModelConfig):
model = OptimumModel(config=config, env_config=env_config)
else:
model = BaseModel(config=config, env_config=env_config)

Expand Down
30 changes: 30 additions & 0 deletions src/lighteval/models/optimum_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from lighteval.models.base_model import BaseModel
from lighteval.models.model_config import EnvConfig, OptimumModelConfig
from lighteval.models.utils import _get_dtype
from lighteval.utils import is_optimum_available


if is_optimum_available():
# from optimum import OptimumConfig
from optimum.intel.openvino import OVModelForCausalLM


class OptimumModel(BaseModel):
    """A BaseModel whose underlying checkpoint is loaded through `optimum`
    (currently only OpenVINO's ``OVModelForCausalLM``)."""

    def _create_auto_model(self, config: OptimumModelConfig, env_config: EnvConfig):
        """Load and return the optimum model described by ``config``.

        Args:
            config: model configuration (pretrained name/path, revision,
                subfolder, dtype, quantization, trust_remote_code, ...).
            env_config: environment settings (cache directory, auth token).

        Returns:
            The loaded ``OVModelForCausalLM`` instance.
        """
        # TODO : Get loading class from optimum config (add support for ORTModelForCausalLM / INCModelForCausalLM / IPEXModelForCausalLM)
        # optimum_config = OptimumConfig.from_pretrained(config.pretrained)

        # init_model_parallel normalizes config.model_parallel; the returned
        # max_memory/device_map are not forwarded to from_pretrained, so they
        # are underscore-prefixed to mark them as deliberately unused.
        # NOTE(review): confirm whether OpenVINO loading should honor them.
        config.model_parallel, _max_memory, _device_map = self.init_model_parallel(config.model_parallel)
        torch_dtype = _get_dtype(config.dtype, self._config)

        model = OVModelForCausalLM.from_pretrained(
            config.pretrained,
            # Mirrors the convention of addressing a subfolder via the revision string.
            revision=config.revision + (f"/{config.subfolder}" if config.subfolder is not None else ""),
            torch_dtype=torch_dtype,
            trust_remote_code=config.trust_remote_code,
            cache_dir=env_config.cache_dir,
            use_auth_token=env_config.token,
            quantization_config=config.quantization_config,
        )

        return model
3 changes: 3 additions & 0 deletions src/lighteval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ def is_optimum_available() -> bool:
return importlib.util.find_spec("optimum") is not None


NO_OPTIMUM_ERROR_MSG = "You are trying to load a model with `optimum`, which is not available in your local environment. Please install it using pip."


def is_bnb_available() -> bool:
    """Return True when the `bitsandbytes` package is importable."""
    spec = importlib.util.find_spec("bitsandbytes")
    return spec is not None

Expand Down
Loading