Commit

fix tests
JingyaHuang committed Feb 14, 2025
1 parent 84bfa03 commit 9bba0ea
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions optimum/exporters/neuron/__main__.py
@@ -18,6 +18,7 @@
 import inspect
 import os
 from argparse import ArgumentParser
+from dataclasses import fields
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

@@ -364,7 +365,8 @@ def get_submodels_and_neuron_configs(
         library_name=library_name,
     )
     input_shapes = check_mandatory_input_shapes(neuron_config_constructor, task, input_shapes)
-    neuron_config = neuron_config_constructor(model.config, dynamic_batch_size=dynamic_batch_size, **input_shapes)
+    input_shapes = InputShapesArguments(**input_shapes)
+    neuron_config = neuron_config_constructor(model.config, dynamic_batch_size=dynamic_batch_size, input_shapes=input_shapes)
     model_name = getattr(model, "name_or_path", None) or model_name_or_path
     model_name = model_name.split("/")[-1] if model_name else model.config.model_type
     output_model_names = {model_name: "model.neuron"}
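
Note on the change above: a minimal, self-contained sketch of the pattern, assuming a hypothetical shapes dataclass (the real InputShapesArguments is defined in optimum and its fields may differ). Instead of splatting the raw shape dict into the constructor as **kwargs, the shapes are packed into a single dataclass instance and passed as one input_shapes argument.

from dataclasses import dataclass
from typing import Optional


@dataclass
class ExampleInputShapes:
    # stand-in for InputShapesArguments; field names here are illustrative
    batch_size: Optional[int] = None
    sequence_length: Optional[int] = None


def example_config_constructor(model_config, dynamic_batch_size=False, input_shapes=None):
    # a neuron config constructor would read the shapes from the dataclass here
    return {"dynamic_batch_size": dynamic_batch_size, "shapes": input_shapes}


raw_shapes = {"batch_size": 1, "sequence_length": 128}
shapes = ExampleInputShapes(**raw_shapes)
neuron_config = example_config_constructor(model_config=None, dynamic_batch_size=False, input_shapes=shapes)
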
@@ -507,11 +509,11 @@ def load_models_and_neuron_configs(
     local_files_only: bool,
     token: Optional[Union[bool, str]],
     submodels: Optional[Dict[str, Union[Path, str]]],
-    ip_adapter_args: IPAdapterArguments,
     torch_dtype: Optional[Union[str, torch.dtype]] = None,
     tensor_parallel_size: int = 1,
     controlnet_ids: Optional[Union[str, List[str]]] = None,
     lora_args: Optional[LoRAAdapterArguments] = None,
+    ip_adapter_args: Optional[IPAdapterArguments] = None,
     output_attentions: bool = False,
     output_hidden_states: bool = False,
     **input_shapes,
@@ -533,7 +535,7 @@
     if model is None:
         model = TasksManager.get_model_from_task(**model_kwargs)
         # Load IP-Adapter if it exists
-        if ip_adapter_args.model_id is not None:
+        if ip_adapter_args is not None and not all(getattr(ip_adapter_args, field.name) is None for field in fields(ip_adapter_args)):
             model.load_ip_adapter(
                 ip_adapter_args.model_id, subfolder=ip_adapter_args.subfolder, weight_name=ip_adapter_args.weight_name
             )
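
Note on the new guard: a minimal sketch of the check, using a hypothetical ExampleIPAdapterArguments dataclass (field names mirror the model_id/subfolder/weight_name attributes referenced above; the real IPAdapterArguments may differ). dataclasses.fields() is used to decide whether every attribute of the optional arguments object is still None, i.e. whether an IP-Adapter was actually requested.

from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class ExampleIPAdapterArguments:
    # stand-in for IPAdapterArguments; mirrors the attributes used in the diff
    model_id: Optional[str] = None
    subfolder: Optional[str] = None
    weight_name: Optional[str] = None


def should_load_ip_adapter(args: Optional[ExampleIPAdapterArguments]) -> bool:
    # skip when no arguments object was passed, or when every field is still None
    return args is not None and not all(getattr(args, f.name) is None for f in fields(args))


assert should_load_ip_adapter(None) is False
assert should_load_ip_adapter(ExampleIPAdapterArguments()) is False
assert should_load_ip_adapter(ExampleIPAdapterArguments(model_id="some-org/some-ip-adapter")) is True
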
