diff --git a/optimum/exporters/neuron/__main__.py b/optimum/exporters/neuron/__main__.py index e8f27d66b..ce7d9a67b 100644 --- a/optimum/exporters/neuron/__main__.py +++ b/optimum/exporters/neuron/__main__.py @@ -18,6 +18,7 @@ import inspect import os from argparse import ArgumentParser +from dataclasses import fields from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union @@ -364,7 +365,8 @@ def get_submodels_and_neuron_configs( library_name=library_name, ) input_shapes = check_mandatory_input_shapes(neuron_config_constructor, task, input_shapes) - neuron_config = neuron_config_constructor(model.config, dynamic_batch_size=dynamic_batch_size, **input_shapes) + input_shapes = InputShapesArguments(**input_shapes) + neuron_config = neuron_config_constructor(model.config, dynamic_batch_size=dynamic_batch_size, input_shapes=input_shapes) model_name = getattr(model, "name_or_path", None) or model_name_or_path model_name = model_name.split("/")[-1] if model_name else model.config.model_type output_model_names = {model_name: "model.neuron"} @@ -507,11 +509,11 @@ def load_models_and_neuron_configs( local_files_only: bool, token: Optional[Union[bool, str]], submodels: Optional[Dict[str, Union[Path, str]]], - ip_adapter_args: IPAdapterArguments, torch_dtype: Optional[Union[str, torch.dtype]] = None, tensor_parallel_size: int = 1, controlnet_ids: Optional[Union[str, List[str]]] = None, lora_args: Optional[LoRAAdapterArguments] = None, + ip_adapter_args: Optional[IPAdapterArguments] = None, output_attentions: bool = False, output_hidden_states: bool = False, **input_shapes, @@ -533,7 +535,7 @@ def load_models_and_neuron_configs( if model is None: model = TasksManager.get_model_from_task(**model_kwargs) # Load IP-Adapter if it exists - if ip_adapter_args.model_id is not None: + if ip_adapter_args is not None and not all(getattr(ip_adapter_args, field.name) is None for field in fields(ip_adapter_args)): model.load_ip_adapter( ip_adapter_args.model_id, subfolder=ip_adapter_args.subfolder, weight_name=ip_adapter_args.weight_name )