Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Nov 26, 2024
1 parent f4ee5b4 commit 7c7d729
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 17 deletions.
7 changes: 2 additions & 5 deletions optimum_benchmark/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,11 @@ def __init__(self, config: BackendConfigT):

elif self.config.library == "llama_cpp":
self.logger.info("\t+ Benchmarking a LlamaCpp model")
# TOD: need a custom method to extract shapes from gguf
self.model_shapes = extract_transformers_shapes_from_artifacts(
self.pretrained_config, self.pretrained_processor
)
self.pretrained_processor = None
self.generation_config = None
self.pretrained_config = None
self.generation_config = None
self.automodel_loader = None
self.model_shapes = {}

else:
self.logger.info("\t+ Benchmarking a Transformers model")
Expand Down
9 changes: 2 additions & 7 deletions optimum_benchmark/backends/llama_cpp/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,10 @@ def llama_cpp_kwargs(self) -> Dict[str, Any]:
"echo": False,
}

def prepare_input_shapes(self, input_shapes: Dict[str, Any]) -> Dict[str, Any]:
    """Validate and pass through the benchmark input shapes.

    LlamaCpp only supports unbatched decoding for text generation, so a
    ``batch_size`` other than 1 is rejected up front; for every other task
    the shapes are returned unchanged.

    Args:
        input_shapes: Mapping of shape names to sizes; must contain
            ``"batch_size"`` when the configured task is text-generation.

    Returns:
        The same ``input_shapes`` mapping, unmodified.

    Raises:
        ValueError: If the task is ``"text-generation"`` and
            ``input_shapes["batch_size"]`` is not 1.
    """
    # Short-circuit keeps the "batch_size" lookup restricted to the
    # text-generation task, matching the original nested-if behavior.
    if self.config.task == "text-generation" and input_shapes["batch_size"] != 1:
        raise ValueError("Batch size must be 1 for LlamaCpp text generation")

    return input_shapes

def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if self.config.task == "text-generation":
if inputs["input_ids"].shape[0] != 1:
raise ValueError("Batch size must be 1 for LlamaCpp text generation")
return {"tokens": inputs["input_ids"].squeeze(0).tolist()}

elif self.config.task == "feature-extraction":
Expand Down
12 changes: 7 additions & 5 deletions optimum_benchmark/launchers/torchrun/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ class TorchrunLauncher(Launcher[TorchrunConfig]):
def __init__(self, config: TorchrunConfig):
super().__init__(config)

if sys.platform == "win32":
self.logger.info("\t+ Disabline libuv on Windows")
os.environ["USE_LIBUV"] = "0"

if get_start_method(allow_none=True) != self.config.start_method:
self.logger.info(f"\t+ Setting multiprocessing start method to {self.config.start_method}")
set_start_method(self.config.start_method, force=True)
Expand Down Expand Up @@ -164,8 +160,14 @@ def entrypoint(worker: Callable[..., BenchmarkReport], worker_args: List[Any], l
device = torch.device("cuda", rank)
torch.cuda.set_device(device)

if sys.platform == "win32":
logger.info("\t+ Disabling libuv for Windows")
init_method = "env://?use_libuv=0"
else:
init_method = "env://"

logger.info("\t+ Initializing torch.distributed process group")
torch.distributed.init_process_group()
torch.distributed.init_process_group(init_method=init_method)

try:
report = worker(*worker_args)
Expand Down

0 comments on commit 7c7d729

Please sign in to comment.