Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Nov 26, 2024
1 parent f4ee5b4 commit d7e589c
Showing 1 changed file with 9 additions and 15 deletions.
24 changes: 9 additions & 15 deletions optimum_benchmark/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class Backend(Generic[BackendConfigT], ABC):
pretrained_processor: Optional[PretrainedProcessor]

def __init__(self, config: BackendConfigT):
assert config.name == self.NAME, f"Backend name {self.NAME} doesn't match config name {config.name}"

self.config = config

self.logger = getLogger(self.NAME)
Expand All @@ -70,14 +72,13 @@ def __init__(self, config: BackendConfigT):

elif self.config.library == "llama_cpp":
self.logger.info("\t+ Benchmarking a LlamaCpp model")
# TOD: need a custom method to extract shapes from gguf
self.automodel_loader = None
self.pretrained_config = None
self.generation_config = None
self.pretrained_processor = None
self.model_shapes = extract_transformers_shapes_from_artifacts(
self.pretrained_config, self.pretrained_processor
)
self.pretrained_processor = None
self.generation_config = None
self.pretrained_config = None
self.automodel_loader = None

else:
self.logger.info("\t+ Benchmarking a Transformers model")
Expand Down Expand Up @@ -109,18 +110,11 @@ def create_no_weights_model(self) -> None:
self.logger.info("\t+ Saving no weights model's config")
self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)

def prepare_input_shapes(self, input_shapes: Dict[str, Any]) -> Dict[str, Any]:
"""
This method is used to prepare and register the input shapes before using them by the model.
It can be used to pad the inputs to the correct shape, or compile it to the correct format.
"""
return input_shapes
@property
def split_between_processes(self) -> bool:
return False

def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""
This method is used to prepare and register the inputs before passing them to the model.
It can be used to move the inputs to the correct device, or rename their keys.
"""
return inputs

def load(self) -> None:
Expand Down

0 comments on commit d7e589c

Please sign in to comment.