diff --git a/lmms_eval/models/internvl2.py b/lmms_eval/models/internvl2.py index 5f4365d0..509df583 100644 --- a/lmms_eval/models/internvl2.py +++ b/lmms_eval/models/internvl2.py @@ -145,6 +145,7 @@ def __init__( accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs]) + self.accelerator = accelerator if accelerator.num_processes > 1: self._device = torch.device(f"cuda:{accelerator.local_process_index}") self.device_map = f"cuda:{accelerator.local_process_index}" diff --git a/lmms_eval/models/llava.py b/lmms_eval/models/llava.py index 7d6420ba..8537f999 100755 --- a/lmms_eval/models/llava.py +++ b/lmms_eval/models/llava.py @@ -69,6 +69,7 @@ def __init__( accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs]) + self.accelerator = accelerator if accelerator.num_processes > 1: self._device = torch.device(f"cuda:{accelerator.local_process_index}") self.device_map = f"cuda:{accelerator.local_process_index}" diff --git a/lmms_eval/models/longva.py b/lmms_eval/models/longva.py index c5bf6861..040c551d 100644 --- a/lmms_eval/models/longva.py +++ b/lmms_eval/models/longva.py @@ -76,6 +76,7 @@ def __init__( accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs]) + self.accelerator = accelerator if accelerator.num_processes > 1: self._device = torch.device(f"cuda:{accelerator.local_process_index}") self.device_map = f"cuda:{accelerator.local_process_index}"